diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile new file mode 100644 index 00000000..2be22e4d --- /dev/null +++ b/.devops/main-vulkan.Dockerfile @@ -0,0 +1,20 @@ +FROM ubuntu:24.04 AS build +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y build-essential wget cmake git libvulkan-dev glslc \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +COPY .. . +RUN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" + +FROM ubuntu:24.04 AS runtime +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y curl ffmpeg libsdl2-dev wget cmake git libvulkan1 mesa-vulkan-drivers \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +COPY --from=build /app /app +ENV PATH=/app/build/bin:$PATH +ENTRYPOINT [ "bash", "-c" ] diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index ff420f2b..83473e46 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -13,10 +13,10 @@ jobs: ubuntu-22: runs-on: ubuntu-22.04 steps: - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: go-version: '^1.23' - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - run: | cd bindings/go make test diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 680862fb..c3f158e2 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -17,5 +17,5 @@ jobs: - uses: ruby/setup-ruby@v1 with: ruby-version: '3.2' - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - run: rake test diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5c1cf93b..8ce887fd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -67,7 +67,7 @@ jobs: steps: - name: Checkout with full history - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -127,7 +127,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -159,7 +159,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -174,10 +174,6 @@ jobs: sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - apt update apt install -y build-essential libsdl2-dev cmake git cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a @@ -195,7 +191,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -210,10 +206,6 @@ jobs: sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - apt update apt install -y build-essential libsdl2-dev cmake git cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp @@ -231,7 +223,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -263,7 +255,7 @@ jobs: # # steps: # - name: Clone -# uses: actions/checkout@v4 +# uses: actions/checkout@v6 # # - name: Build # uses: 
cross-platform-actions/action@v0.27.0 @@ -289,7 +281,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -323,7 +315,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -338,10 +330,6 @@ jobs: sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - apt update apt install -y build-essential cmake libsdl2-dev git cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a @@ -361,7 +349,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -376,10 +364,6 @@ jobs: sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - apt update apt install -y build-essential cmake libsdl2-dev git cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp @@ -402,7 +386,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -417,10 +401,6 @@ jobs: sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - apt update apt install -y clang build-essential cmake libsdl2-dev git cmake . 
-DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang @@ -440,7 +420,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -480,7 +460,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: add oneAPI to apt shell: bash @@ -504,7 +484,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Build id: cmake_build @@ -532,7 +512,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: add oneAPI to apt shell: bash @@ -556,7 +536,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Build id: cmake_build @@ -581,7 +561,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup ${{ matrix.sys }} uses: msys2/setup-msys2@v2 @@ -636,7 +616,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Add msbuild to PATH uses: microsoft/setup-msbuild@v2 @@ -666,31 +646,31 @@ jobs: - name: Upload SDL2.dll if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ${{ matrix.s2arc }}_SDL2.dll path: build/bin/${{ matrix.build }}/SDL2.dll - name: Upload whisper dll - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whisper_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/whisper.dll - name: Upload ggml dll - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ggml_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml.dll - name: Upload ggml base dll - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ggml_base_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml-base.dll - name: Upload ggml cpu dll - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ggml_cpu_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml-cpu.dll @@ -702,7 +682,7 @@ jobs: - name: Upload binaries if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whisper-bin-${{ matrix.arch }}.zip path: whisper-bin-${{ matrix.arch }}.zip @@ -731,10 +711,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); @@ -788,7 +768,7 @@ jobs: - name: Upload binaries if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whisper-blas-bin-${{ matrix.arch }}.zip path: whisper-blas-bin-${{ matrix.arch }}.zip @@ -812,7 +792,7 @@ jobs: sdl2_ver: 2.28.5 steps: - name: Clone repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install Ninja id: install_ninja @@ -997,7 +977,7 @@ jobs: - name: Upload binaries if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ 
matrix.arch }}.zip @@ -1013,7 +993,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup emsdk uses: mymindstorm/setup-emsdk@v14 @@ -1036,7 +1016,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Configure run: | @@ -1078,7 +1058,7 @@ jobs: - name: Upload artifacts if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip @@ -1090,12 +1070,12 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: whisper - name: Install Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: zulu java-version: 21 @@ -1119,10 +1099,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: set up JDK 11 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: java-version: '11' distribution: 'temurin' @@ -1145,36 +1125,36 @@ jobs: needs: ['windows'] runs-on: windows-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Install Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: zulu java-version: 20 - name: Download Whisper Windows lib - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: whisper_x64.dll - name: Download GGML Windows lib - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: ggml_x64.dll - name: Download GGML Base Windows lib - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: ggml_base_x64.dll - name: Download GGML CPU Windows lib - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: ggml_cpu_x64.dll - name: Download SDL2.dll - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: x64_SDL2.dll @@ -1221,7 +1201,7 @@ jobs: Compress-Archive -Path "bindings/java/build/libs/whispercpp-*.jar" -DestinationPath "whispercpp.jar.zip" - name: Upload jar - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whispercpp.jar.zip path: whispercpp.jar.zip @@ -1245,7 +1225,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test quantize run: | @@ -1269,7 +1249,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -1282,7 +1262,7 @@ jobs: # Downloads all the artifacts from the previous jobs - name: Download artifacts id: download-artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: path: ./artifact @@ -1332,7 +1312,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set environment variables id: set_vars @@ -1358,7 +1338,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Build shell: bash @@ -1378,7 +1358,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1403,7 +1383,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1428,7 +1408,7 @@ jobs: steps: - name: Clone id: checkout - 
uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1453,7 +1433,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1478,7 +1458,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1503,7 +1483,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1517,7 +1497,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1531,7 +1511,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1545,7 +1525,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1558,7 +1538,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1571,7 +1551,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 0e2fb1f2..6c0de0ec 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -22,10 +22,11 @@ jobs: - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" } - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" } - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" } + - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64" } steps: - name: Check out the repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -67,7 +68,7 @@ jobs: echo "tags=$TAGS" >> $GITHUB_OUTPUT - name: Build and push Docker image (tagged) - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: . 
push: ${{ github.event_name == 'push' }} diff --git a/.github/workflows/examples-wasm.yml b/.github/workflows/examples-wasm.yml index ebbbdfe2..927438cd 100644 --- a/.github/workflows/examples-wasm.yml +++ b/.github/workflows/examples-wasm.yml @@ -22,10 +22,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Pages - uses: actions/configure-pages@v4 + uses: actions/configure-pages@v5 - name: Setup emsdk uses: mymindstorm/setup-emsdk@v14 @@ -88,7 +88,7 @@ jobs: find staging -type f | sort - name: Upload artifact - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@v4 with: path: ./staging diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 74ef8e0f..1c9ade5a 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -17,7 +17,7 @@ jobs: node-version: [ 16.x, 18.x ] steps: - name: Clone - uses: actions/checkout@v1 + uses: actions/checkout@v6 - name: Dependencies run: | @@ -27,7 +27,7 @@ jobs: sudo apt-get install libsdl2-dev - name: Use Node.js ${{ matrix.node-version }} - uses: actions/setup-node@v1 + uses: actions/setup-node@v6 with: node-version: ${{ matrix.node-version }} cache: 'npm' diff --git a/LICENSE b/LICENSE index acb96ce7..e7dca554 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023-2024 The ggml authors +Copyright (c) 2023-2026 The ggml authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index d7dea5f7..e4e1383b 100644 --- a/README.md +++ b/README.md @@ -443,11 +443,12 @@ ffmpeg -i samples/jfk.wav jfk.opus ### Images -We have two Docker images available for this project: +We have multiple Docker images available for this project: 1. `ghcr.io/ggml-org/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`) 2. `ghcr.io/ggml-org/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`) 3. `ghcr.io/ggml-org/whisper.cpp:main-musa`: Same as `main` but compiled with MUSA support. (platforms: `linux/amd64`) +4. `ghcr.io/ggml-org/whisper.cpp:main-vulkan`: Same as `main` but compiled with Vulkan support. (platforms: `linux/amd64`) ### Usage @@ -456,15 +457,27 @@ We have two Docker images available for this project: docker run -it --rm \ -v path/to/models:/models \ whisper.cpp:main "./models/download-ggml-model.sh base /models" + # transcribe an audio file docker run -it --rm \ -v path/to/models:/models \ -v path/to/audios:/audios \ whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f /audios/jfk.wav" + # transcribe an audio file in samples folder docker run -it --rm \ -v path/to/models:/models \ whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f ./samples/jfk.wav" + +# run the web server +docker run -it --rm -p "8080:8080" \ + -v path/to/models:/models \ + whisper.cpp:main "whisper-server --host 0.0.0.0 -m /models/ggml-base.bin" + +# run the bench tool on the small.en model using 4 threads +docker run -it --rm \ + -v path/to/models:/models \ + whisper.cpp:main "whisper-bench -m /models/ggml-small.en.bin -t 4" ``` ## Installing with Conan @@ -750,7 +763,7 @@ argument to `whisper-cli`. In addition to this option a VAD model is also required.
The way this works is that first the audio samples are passed through -the VAD model which will detect speech segments. Using this information the +the VAD model which will detect speech segments. Using this information, only the speech segments that are detected are extracted from the original audio input and passed to whisper for processing. This reduces the amount of audio data that needs to be processed by whisper and can significantly speed up the diff --git a/bindings/go/examples/go-model-download/main.go b/bindings/go/examples/go-model-download/main.go index 728c6df5..e72262eb 100644 --- a/bindings/go/examples/go-model-download/main.go +++ b/bindings/go/examples/go-model-download/main.go @@ -282,13 +282,20 @@ func Download(ctx context.Context, p io.Writer, model, out string) (string, erro default: // Read body n, err := resp.Body.Read(data) + if n > 0 { + if m, err := w.Write(data[:n]); err != nil { + return path, err + } else { + count += int64(m) + } + } + if err != nil { - DownloadReport(p, pct, count, resp.ContentLength) + if err == io.EOF { + DownloadReport(p, pct, count, resp.ContentLength) + return path, nil + } return path, err - } else if m, err := w.Write(data[:n]); err != nil { - return path, err - } else { - count += int64(m) } } } diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index ea202753..c6280a69 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -247,6 +247,58 @@ whisper.transcribe("path/to/audio.wav", params) ``` +### Tokens ### + +Each segment has tokens. + +To enable token timestamps, you need to set `Whisper::Params#token_timestamps = true`. Then, retrieve tokens from segments using `Whisper::Segment#each_token`. + +```ruby +whisper = Whisper::Context.new("base.en") +params = Whisper::Params.new(token_timestamps: true) +whisper + .transcribe("path/to/audio.wav", params) + .each_segment do |segment| + segment.each_token do |token| + token => {start_time:, end_time:, text:, probability:} + st = "%05.2fs" % (start_time / 1000.0) + et = "%05.2fs" % (end_time / 1000.0) + prob = "%.1f%%" % (probability * 100) + puts "[#{st} --> #{et}] #{text} (#{prob})" + end + end +``` + +``` +[00.00s --> 00.00s] [_BEG_] (84.2%) +[00.32s --> 00.37s] And (71.2%) +[00.37s --> 00.53s] so (98.5%) +[00.69s --> 00.85s] my (70.7%) +[00.85s --> 01.59s] fellow (99.5%) +[01.59s --> 02.10s] Americans (90.1%) +[02.85s --> 03.30s] , (28.4%) +[03.30s --> 04.14s] ask (79.8%) +[04.14s --> 04.28s] not (78.9%) +[05.03s --> 05.35s] what (93.3%) +[05.41s --> 05.74s] your (98.8%) +[05.74s --> 06.41s] country (99.6%) +[06.41s --> 06.74s] can (97.7%) +[06.74s --> 06.92s] do (99.0%) +[07.00s --> 07.00s] for (95.8%) +[07.01s --> 07.52s] you (98.5%) +[07.81s --> 08.05s] , (49.3%) +[08.19s --> 08.37s] ask (65.6%) +[08.37s --> 08.75s] what (98.8%) +[08.91s --> 09.04s] you (98.2%) +[09.04s --> 09.32s] can (96.9%) +[09.32s --> 09.38s] do (90.3%) +[09.44s --> 09.76s] for (91.8%) +[09.76s --> 09.99s] your (98.2%) +[10.02s --> 10.36s] country (99.6%) +[10.51s --> 10.99s] . (87.0%) +[11.00s --> 11.00s] [_TT_550] (7.6%) +``` + ### Models ### You can see model information: @@ -323,7 +375,38 @@ whisper end ``` -The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. +The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. 
+ +If you can prepare audio data as a C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. + +```ruby +require "torchaudio" +require "arrow-numo-narray" +require "whisper" + +waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") +# Convert Torch::Tensor to Arrow::Array via Numo::NArray +samples = waveform.squeeze.numo.to_arrow.to_arrow_array + +whisper = Whisper::Context.new("base") +whisper + # Arrow::Array exports MemoryView + .full(Whisper::Params.new, samples) +``` + +Custom context params +--------------------- + +You can customize `Whisper::Context`'s behavior using `Whisper::Context::Params`. + +```ruby +context_params = Whisper::Context::Params.new( + use_gpu: false, + flash_attn: false, + # etc +) +whisper = Whisper::Context.new("base", context_params) +``` Using VAD separately from ASR ----------------------------- @@ -334,13 +417,27 @@ VAD feature itself is useful. You can use it separately from ASR: vad = Whisper::VAD::Context.new("silero-v6.2.0") vad .detect("path/to/audio.wav", Whisper::VAD::Params.new) - .each_with_index do |segment, index| + .each.with_index do |segment, index| segment => {start_time: st, end_time: ed} # `Segment` responds to `#deconstruct_keys` puts "[%{nth}: %{st} --> %{ed}]" % {nth: index + 1, st:, ed:} end ``` +You may also use the low-level API `Whisper::VAD::Context#segments_from_samples` in the same way as `Whisper::Context#full`: + +```ruby +# Ruby Array +reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000)) +samples = reader.enum_for(:each_buffer).map(&:samples).flatten + +# Or, object which exports MemoryView +waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") +samples = waveform.squeeze.numo.to_arrow.to_arrow_array + +segments = vad.segments_from_samples(Whisper::VAD::Params.new, samples) +``` + Development ----------- diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 8a5ac674..acff501a 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -7,6 +7,7 @@ options = Options.new(cmake).to_s have_library("gomp") rescue nil libs = Dependencies.new(cmake, options).to_s +$CFLAGS << " -O3 -march=native" $INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples" $LOCAL_LIBS << " #{libs}" $cleanfiles << " build #{libs}" diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index ac677e9e..ba71d4ba 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,5 +1,3 @@ -#include -#include #include "ruby_whisper.h" VALUE mWhisper; @@ -35,7 +33,8 @@ static bool is_log_callback_finalized = false; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); -extern void init_ruby_whisper_context(VALUE *mWhisper); +extern VALUE init_ruby_whisper_context(VALUE *mWhisper); +extern void init_ruby_whisper_context_params(VALUE *cContext); extern void init_ruby_whisper_params(VALUE *mWhisper); extern void init_ruby_whisper_error(VALUE *mWhisper); extern void init_ruby_whisper_segment(VALUE *mWhisper); @@ -164,6 +163,22 @@ void Init_whisper() { rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG)); rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT)); + rb_define_const(mWhisper, "AHEADS_NONE", INT2NUM(WHISPER_AHEADS_NONE)); + rb_define_const(mWhisper, "AHEADS_N_TOP_MOST", INT2NUM(WHISPER_AHEADS_N_TOP_MOST)); + rb_define_const(mWhisper, "AHEADS_CUSTOM", INT2NUM(WHISPER_AHEADS_CUSTOM)); +
rb_define_const(mWhisper, "AHEADS_TINY_EN", INT2NUM(WHISPER_AHEADS_TINY_EN)); + rb_define_const(mWhisper, "AHEADS_TINY", INT2NUM(WHISPER_AHEADS_TINY)); + rb_define_const(mWhisper, "AHEADS_BASE_EN", INT2NUM(WHISPER_AHEADS_BASE_EN)); + rb_define_const(mWhisper, "AHEADS_BASE", INT2NUM(WHISPER_AHEADS_BASE)); + rb_define_const(mWhisper, "AHEADS_SMALL_EN", INT2NUM(WHISPER_AHEADS_SMALL_EN)); + rb_define_const(mWhisper, "AHEADS_SMALL", INT2NUM(WHISPER_AHEADS_SMALL)); + rb_define_const(mWhisper, "AHEADS_MEDIUM_EN", INT2NUM(WHISPER_AHEADS_MEDIUM_EN)); + rb_define_const(mWhisper, "AHEADS_MEDIUM", INT2NUM(WHISPER_AHEADS_MEDIUM)); + rb_define_const(mWhisper, "AHEADS_LARGE_V1", INT2NUM(WHISPER_AHEADS_LARGE_V1)); + rb_define_const(mWhisper, "AHEADS_LARGE_V2", INT2NUM(WHISPER_AHEADS_LARGE_V2)); + rb_define_const(mWhisper, "AHEADS_LARGE_V3", INT2NUM(WHISPER_AHEADS_LARGE_V3)); + rb_define_const(mWhisper, "AHEADS_LARGE_V3_TURBO", INT2NUM(WHISPER_AHEADS_LARGE_V3_TURBO)); + rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); @@ -172,7 +187,8 @@ void Init_whisper() { rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); - init_ruby_whisper_context(&mWhisper); + cContext = init_ruby_whisper_context(&mWhisper); + init_ruby_whisper_context_params(&cContext); init_ruby_whisper_params(&mWhisper); init_ruby_whisper_error(&mWhisper); init_ruby_whisper_segment(&mWhisper); diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 3f5660c3..8dfd103c 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -1,6 +1,8 @@ #ifndef RUBY_WHISPER_H #define RUBY_WHISPER_H +#include +#include #include "whisper.h" typedef struct { @@ -14,6 +16,10 @@ typedef struct { struct whisper_context *context; } ruby_whisper; +typedef struct ruby_whisper_context_params { + struct whisper_context_params params; +} ruby_whisper_context_params; + typedef struct { struct whisper_full_params params; bool diarize; @@ -35,7 +41,7 @@ typedef struct { typedef struct { whisper_token_data *token_data; - const char *text; + VALUE text; } ruby_whisper_token; typedef struct { @@ -55,6 +61,13 @@ typedef struct { struct whisper_vad_context *context; } ruby_whisper_vad_context; +typedef struct parsed_samples_t { + float *samples; + int n_samples; + rb_memory_view_t memview; + bool memview_exported; +} parsed_samples_t; + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -62,13 +75,28 @@ typedef struct { } \ } while (0) -#define GetToken(obj, rwt) do { \ +#define GetContextParams(obj, rwcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_context_params, &ruby_whisper_context_params_type, (rwcp)); \ +} while (0) + +#define GetToken(obj, rwt) do { \ TypedData_Get_Struct((obj), ruby_whisper_token, &ruby_whisper_token_type, (rwt)); \ if ((rwt)->token_data == NULL) { \ rb_raise(rb_eRuntimeError, "Not initialized"); \ } \ } while (0) +#define GetVADContext(obj, rwvc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_vad_context, &ruby_whisper_vad_context_type, (rwvc)); \ + if ((rwvc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define 
GetVADParams(obj, rwvp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_vad_params, &ruby_whisper_vad_params_type, (rwvp)); \ +} while (0) + #define GetVADSegments(obj, rwvss) do { \ TypedData_Get_Struct((obj), ruby_whisper_vad_segments, &ruby_whisper_vad_segments_type, (rwvss)); \ if ((rwvss)->segments == NULL) { \ diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index a7b5f851..c39d43bd 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -1,5 +1,3 @@ -#include -#include #include "ruby_whisper.h" extern ID id_to_s; @@ -20,6 +18,7 @@ extern VALUE eError; extern VALUE cModel; extern const rb_data_type_t ruby_whisper_params_type; +extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); @@ -27,6 +26,27 @@ extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context); ID transcribe_option_names[1]; +typedef struct fill_samples_args { + float *dest; + VALUE *src; + int n_samples; +} fill_samples_args; + +typedef struct full_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} full_args; + +typedef struct full_parallel_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; + int n_processors; +} full_parallel_args; + static void ruby_whisper_free(ruby_whisper *rw) { @@ -124,16 +144,25 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; VALUE whisper_model_file_path; + VALUE context_params; + struct whisper_context_params params; // TODO: we can support init from buffer here too maybe another ruby object to expose - rb_scan_args(argc, argv, "01", &whisper_model_file_path); + rb_scan_args(argc, argv, "11", &whisper_model_file_path, &context_params); TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path); if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } - rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); + if (NIL_P(context_params)) { + params = whisper_context_default_params(); + } else { + ruby_whisper_context_params *rwcp; + GetContextParams(context_params, rwcp); + params = rwcp->params; + } + rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), params); if (rw->context == NULL) { rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); } @@ -272,6 +301,147 @@ VALUE ruby_whisper_model_type(VALUE self) return rb_str_new2(whisper_model_type_readable(rw->context)); } +static bool +check_memory_view(rb_memory_view_t *memview) +{ + if (memview->format != NULL && strcmp(memview->format, "f") != 0) { + rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format); + return false; + } + if (memview->format != NULL && memview->ndim != 1) { + rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim); + return false; + } + + return true; +} + +static VALUE +fill_samples(VALUE rb_args) +{ + fill_samples_args *args = (fill_samples_args *)rb_args; + + if (RB_TYPE_P(*args->src, T_ARRAY)) { + for (int i = 0; i < args->n_samples; i++) 
{ + args->dest[i] = RFLOAT_VALUE(rb_ary_entry(*args->src, i)); + } + } else { + // TODO: use rb_block_call + VALUE iter = rb_funcall(*args->src, id_to_enum, 1, rb_str_new2("each")); + for (int i = 0; i < args->n_samples; i++) { + // TODO: check if iter is exhausted and raise ArgumentError appropriately + VALUE sample = rb_funcall(iter, id_next, 0); + args->dest[i] = RFLOAT_VALUE(sample); + } + } + + return Qnil; +} + +struct parsed_samples_t +parse_samples(VALUE *samples, VALUE *n_samples) +{ + bool memview_available = rb_memory_view_available_p(*samples); + struct parsed_samples_t parsed = {0}; + parsed.memview_exported = false; + const bool is_array = RB_TYPE_P(*samples, T_ARRAY); + + if (!NIL_P(*n_samples)) { + parsed.n_samples = NUM2INT(*n_samples); + if (is_array) { + if (RARRAY_LEN(*samples) < parsed.n_samples) { + rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(*samples), parsed.n_samples); + } + } + // Should check when samples.respond_to?(:length)? + } else { + if (is_array) { + if (RARRAY_LEN(*samples) > INT_MAX) { + rb_raise(rb_eArgError, "samples are too long"); + } + parsed.n_samples = (int)RARRAY_LEN(*samples); + } else if (memview_available) { + bool memview_got = rb_memory_view_get(*samples, &parsed.memview, RUBY_MEMORY_VIEW_SIMPLE); + if (memview_got) { + parsed.memview_exported = check_memory_view(&parsed.memview); + if (!parsed.memview_exported) { + rb_memory_view_release(&parsed.memview); + parsed.memview = (rb_memory_view_t){0}; + } + } + if (parsed.memview_exported) { + ssize_t n_samples_size = parsed.memview.byte_size / parsed.memview.item_size; + if (n_samples_size > INT_MAX) { + rb_memory_view_release(&parsed.memview); + rb_raise(rb_eArgError, "samples are too long: %zd", n_samples_size); + } + parsed.n_samples = (int)n_samples_size; + } else { + rb_warn("unable to get a memory view. 
falls back to Ruby object"); + if (rb_respond_to(*samples, id_length)) { + parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); + } else { + rb_raise(rb_eArgError, "samples must respond to :length"); + } + } + } else if (rb_respond_to(*samples, id_length)) { + parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); + } else { + rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of float when n_samples is not given"); + } + } + + if (parsed.memview_exported) { + parsed.samples = (float *)parsed.memview.data; + } else { + parsed.samples = ALLOC_N(float, parsed.n_samples); + fill_samples_args args = { + parsed.samples, + samples, + parsed.n_samples, + }; + int state; + rb_protect(fill_samples, (VALUE)&args, &state); + if (state) { + xfree(parsed.samples); + rb_jump_tag(state); + } + } + + return parsed; +} + +VALUE +release_samples(VALUE rb_parsed_args) +{ + parsed_samples_t *parsed_args = (parsed_samples_t *)rb_parsed_args; + + if (parsed_args->memview_exported) { + rb_memory_view_release(&parsed_args->memview); + } else { + xfree(parsed_args->samples); + } + *parsed_args = (parsed_samples_t){0}; + + return Qnil; +} + +static VALUE +full_body(VALUE rb_args) +{ + full_args *args = (full_args *)rb_args; + + ruby_whisper *rw; + ruby_whisper_params *rwp; + GetContext(*args->context, rw); + TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); + + prepare_transcription(rwp, args->context); + int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); + + return INT2NUM(result); +} + /* * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text * Not thread safe for same context @@ -289,65 +459,17 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); } - ruby_whisper *rw; - ruby_whisper_params *rwp; - GetContext(self, rw); - VALUE params = argv[0]; - TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - VALUE samples = argv[1]; - int n_samples; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); - if (argc == 3) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); - } - } - // Should check when samples.respond_to?(:length)? 
- } else { - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)RARRAY_LEN(samples); - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); - } - ssize_t n_samples_size = view.byte_size / view.item_size; - if (n_samples_size > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)n_samples_size; - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); - } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); - } - } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; - } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // TODO: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError appropriately - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } - } - } - prepare_transcription(rwp, &self); - const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); if (0 == result) { return self; } else { @@ -355,6 +477,22 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } +static VALUE +full_parallel_body(VALUE rb_args) +{ + full_parallel_args *args = (full_parallel_args *)rb_args; + + ruby_whisper *rw; + ruby_whisper_params *rwp; + GetContext(*args->context, rw); + TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); + + prepare_transcription(rwp, args->context); + int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); + + return INT2NUM(result); +} + /* * Split the input audio in chunks and process each chunk separately using whisper_full_with_state() * Result is stored in the default state of the context @@ -372,19 +510,11 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) { if (argc < 2 || argc > 4) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..4)", argc); } - ruby_whisper *rw; - ruby_whisper_params *rwp; - GetContext(self, rw); - VALUE params = argv[0]; - TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - VALUE samples = argv[1]; - int n_samples; + VALUE n_samples = argc == 2 ? 
Qnil : argv[2]; int n_processors; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); switch (argc) { case 2: n_processors = 1; @@ -396,56 +526,16 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) n_processors = NUM2INT(argv[3]); break; } - if (argc >= 3 && !NIL_P(argv[2])) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); - } - } - // Should check when samples.respond_to?(:length)? - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); - } - ssize_t n_samples_size = view.byte_size / view.item_size; - if (n_samples_size > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)n_samples_size; - } else { - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)RARRAY_LEN(samples); - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); - } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); - } - } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; - } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // FIXME: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } - } - } - prepare_transcription(rwp, &self); - const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + const full_parallel_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + n_processors, + }; + const VALUE rb_result = rb_ensure(full_parallel_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); if (0 == result) { return self; } else { @@ -631,7 +721,7 @@ ruby_whisper_get_model(VALUE self) return rb_whisper_model_s_new(self); } -void +VALUE init_ruby_whisper_context(VALUE *mWhisper) { cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject); @@ -669,4 +759,6 @@ init_ruby_whisper_context(VALUE *mWhisper) rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); rb_define_method(cContext, "model", ruby_whisper_get_model, 0); + + return cContext; } diff --git a/bindings/ruby/ext/ruby_whisper_context_params.c b/bindings/ruby/ext/ruby_whisper_context_params.c new file mode 100644 index 00000000..87df21d4 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_context_params.c @@ -0,0 +1,163 @@ +#include "ruby_whisper.h" + +#define NUM_PARAMS 6 + +#define DEF_BOOLEAN_ATTR_METHOD(name) \ +static VALUE \ +ruby_whisper_context_params_get_ ## name(VALUE self) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + return rwcp->params.name ? 
Qtrue : Qfalse; \ +} \ +static VALUE \ +ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + rwcp->params.name = RTEST(value); \ + return value; \ +} + +#define DEF_INT_ATTR_METHOD(name) \ +static VALUE \ +ruby_whisper_context_params_get_ ## name(VALUE self) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + return INT2NUM(rwcp->params.name); \ +} \ +static VALUE \ +ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + rwcp->params.name = NUM2INT(value); \ + return value; \ +} + +#define DEFINE_PARAM(param_name, nth) \ + id_ ## param_name = rb_intern(#param_name); \ + param_names[nth] = id_ ## param_name; \ + rb_define_method(cContextParams, #param_name, ruby_whisper_context_params_get_ ## param_name, 0); \ + rb_define_method(cContextParams, #param_name "=", ruby_whisper_context_params_set_ ## param_name, 1); + +VALUE cContextParams; + +static ID param_names[NUM_PARAMS]; +static ID id_use_gpu; +static ID id_flash_attn; +static ID id_gpu_device; +static ID id_dtw_token_timestamps; +static ID id_dtw_aheads_preset; +static ID id_dtw_n_top; + +static size_t +ruby_whisper_context_params_memsize(const void *p) +{ + const ruby_whisper_context_params *rwcp = (ruby_whisper_context_params *)p; + if (!rwcp) { + return 0; + } + return sizeof(ruby_whisper_context_params); +} + +const rb_data_type_t ruby_whisper_context_params_type = { + "ruby_whisper_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_context_params_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_context_params *rwcp; + return TypedData_Make_Struct(klass, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp); +} + +DEF_BOOLEAN_ATTR_METHOD(use_gpu); +DEF_BOOLEAN_ATTR_METHOD(flash_attn); +DEF_INT_ATTR_METHOD(gpu_device); +DEF_BOOLEAN_ATTR_METHOD(dtw_token_timestamps); +DEF_INT_ATTR_METHOD(dtw_aheads_preset); + +static VALUE +ruby_whisper_context_params_get_dtw_n_top(VALUE self) { + ruby_whisper_context_params *rwcp; + GetContextParams(self, rwcp); + + int dtw_n_top = rwcp->params.dtw_n_top; + + return dtw_n_top == -1 ? Qnil : INT2NUM(dtw_n_top); +} + +static VALUE +ruby_whisper_context_params_set_dtw_n_top(VALUE self, VALUE value) { + ruby_whisper_context_params *rwcp; + GetContextParams(self, rwcp); + + rwcp->params.dtw_n_top = NIL_P(value) ? 
-1 : NUM2INT(value); + + return value; +} + +#define SET_PARAM_IF_SAME(param_name) \ + if (id == id_ ## param_name) { \ + ruby_whisper_context_params_set_ ## param_name(self, value); \ + continue; \ + } + +static VALUE +ruby_whisper_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper_context_params *rwcp; + TypedData_Get_Struct(self, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp); + rwcp->params = whisper_context_default_params(); + + VALUE kw_hash; + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + VALUE values[NUM_PARAMS] = {Qundef}; + rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values); + + ID id; + VALUE value; + for (int i = 0; i < NUM_PARAMS; i++) { + id = param_names[i]; + value = values[i]; + if (value == Qundef) { + continue; + } + SET_PARAM_IF_SAME(use_gpu) + SET_PARAM_IF_SAME(flash_attn) + SET_PARAM_IF_SAME(gpu_device) + SET_PARAM_IF_SAME(dtw_token_timestamps) + SET_PARAM_IF_SAME(dtw_aheads_preset) + SET_PARAM_IF_SAME(dtw_n_top) + } + + return Qnil; +} + +#undef SET_PARAM_IF_SAME + +void +init_ruby_whisper_context_params(VALUE *cContext) +{ + cContextParams = rb_define_class_under(*cContext, "Params", rb_cObject); + + rb_define_alloc_func(cContextParams, ruby_whisper_context_params_s_allocate); + rb_define_method(cContextParams, "initialize", ruby_whisper_context_params_initialize, -1); + + DEFINE_PARAM(use_gpu, 0) + DEFINE_PARAM(flash_attn, 1) + DEFINE_PARAM(gpu_device, 2) + DEFINE_PARAM(dtw_token_timestamps, 3) + DEFINE_PARAM(dtw_aheads_preset, 4) + DEFINE_PARAM(dtw_n_top, 5) +} + +#undef DEFINE_PARAM +#undef DEF_INT_ATTR_METHOD +#undef DEF_BOOLEAN_ATTR_METHOD +#undef NUM_PARAMS diff --git a/bindings/ruby/ext/ruby_whisper_model.c b/bindings/ruby/ext/ruby_whisper_model.c index b196a8b5..0e91fb3f 100644 --- a/bindings/ruby/ext/ruby_whisper_model.c +++ b/bindings/ruby/ext/ruby_whisper_model.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 4dfe2575..61eb1733 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define BOOL_PARAMS_SETTER(self, prop, value) \ diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index 5229cb53..ee0d66c4 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 6 diff --git a/bindings/ruby/ext/ruby_whisper_token.c b/bindings/ruby/ext/ruby_whisper_token.c index ea4f4e63..73f5a547 100644 --- a/bindings/ruby/ext/ruby_whisper_token.c +++ b/bindings/ruby/ext/ruby_whisper_token.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 11 @@ -25,12 +24,34 @@ ruby_whisper_token_memsize(const void *p) if (!rwt) { return 0; } - return sizeof(rwt); + size_t size = sizeof(*rwt); + if (rwt->token_data) { + size += sizeof(*rwt->token_data); + } + return size; +} + +static void +ruby_whisper_token_mark(void *p) +{ + ruby_whisper_token *rwt = (ruby_whisper_token *)p; + rb_gc_mark(rwt->text); +} + +static void +ruby_whisper_token_free(void *p) +{ + ruby_whisper_token *rwt = (ruby_whisper_token *)p; + if (rwt->token_data) { + xfree(rwt->token_data); + rwt->token_data = NULL; + } + xfree(rwt); } static const rb_data_type_t 
ruby_whisper_token_type = { "ruby_whisper_token", - {0, RUBY_DEFAULT_FREE, ruby_whisper_token_memsize,}, + {ruby_whisper_token_mark, ruby_whisper_token_free, ruby_whisper_token_memsize,}, 0, 0, 0 }; @@ -41,19 +62,19 @@ ruby_whisper_token_allocate(VALUE klass) ruby_whisper_token *rwt; VALUE token = TypedData_Make_Struct(klass, ruby_whisper_token, &ruby_whisper_token_type, rwt); rwt->token_data = NULL; - rwt->text = NULL; + rwt->text = Qnil; return token; } VALUE ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int i_token) { - whisper_token_data token_data = whisper_full_get_token_data(context, i_segment, i_token); const VALUE token = ruby_whisper_token_allocate(cToken); ruby_whisper_token *rwt; TypedData_Get_Struct(token, ruby_whisper_token, &ruby_whisper_token_type, rwt); - rwt->token_data = &token_data; - rwt->text = whisper_full_get_token_text(context, i_segment, i_token); + rwt->token_data = ALLOC(whisper_token_data); + *(rwt->token_data) = whisper_full_get_token_data(context, i_segment, i_token); + rwt->text = rb_str_new2(whisper_full_get_token_text(context, i_segment, i_token)); return token; } @@ -183,10 +204,9 @@ ruby_whisper_token_get_text(VALUE self) { ruby_whisper_token *rwt; GetToken(self, rwt); - return rb_str_new2(rwt->text); + return rwt->text; } - /* * Start time of the token. * diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 594b2db9..c00fbcd1 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #include "common-whisper.h" #include @@ -13,6 +12,7 @@ extern const rb_data_type_t ruby_whisper_params_type; extern ID id_to_s; extern ID id_call; +extern ID id_to_path; extern ID transcribe_option_names[1]; extern void @@ -50,6 +50,9 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rb_raise(rb_eRuntimeError, "Expected file path to wave file"); } + if (rb_respond_to(wave_file_path, id_to_path)) { + wave_file_path = rb_funcall(wave_file_path, id_to_path, 0); + } std::string fname_inp = StringValueCStr(wave_file_path); std::vector pcmf32; // mono-channel F32 PCM diff --git a/bindings/ruby/ext/ruby_whisper_vad_context.c b/bindings/ruby/ext/ruby_whisper_vad_context.c index bf2ed2ba..97c9736b 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_context.c +++ b/bindings/ruby/ext/ruby_whisper_vad_context.c @@ -1,12 +1,23 @@ -#include #include "ruby_whisper.h" extern ID id_to_s; extern VALUE cVADContext; +extern const rb_data_type_t ruby_whisper_vad_params_type; extern VALUE ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params); extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE parsed); + +extern VALUE ruby_whisper_vad_segments_s_init(struct whisper_vad_segments *segments); + +typedef struct segments_from_samples_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} segments_from_samples_args; static size_t ruby_whisper_vad_context_memsize(const void *p) @@ -66,10 +77,46 @@ ruby_whisper_vad_context_initialize(VALUE self, VALUE model_path) return Qnil; } +static VALUE +segments_from_samples_body(VALUE rb_args) +{ + segments_from_samples_args *args = (segments_from_samples_args *)rb_args; + + ruby_whisper_vad_context *rwvc; + ruby_whisper_vad_params *rwvp; + GetVADContext(*args->context, rwvc); + GetVADParams(*args->params, 
rwvp); + + struct whisper_vad_segments *segments = whisper_vad_segments_from_samples(rwvc->context, rwvp->params, args->samples, args->n_samples); + + return ruby_whisper_vad_segments_s_init(segments); +} + +static VALUE +ruby_whisper_vad_segments_from_samples(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + segments_from_samples_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE segments = rb_ensure(segments_from_samples_body, (VALUE)&args, release_samples, (VALUE)&parsed); + + return segments; +} + void init_ruby_whisper_vad_context(VALUE *mVAD) { cVADContext = rb_define_class_under(*mVAD, "Context", rb_cObject); rb_define_alloc_func(cVADContext, ruby_whisper_vad_context_s_allocate); rb_define_method(cVADContext, "initialize", ruby_whisper_vad_context_initialize, 1); + rb_define_method(cVADContext, "segments_from_samples", ruby_whisper_vad_segments_from_samples, -1); rb_define_method(cVADContext, "detect", ruby_whisper_vad_detect, 2); } diff --git a/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp b/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp index 58609f87..802b0222 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp +++ b/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #include "common-whisper.h" #include @@ -8,6 +7,8 @@ extern "C" { #endif +extern ID id_to_path; + extern VALUE cVADSegments; extern const rb_data_type_t ruby_whisper_vad_context_type; @@ -25,12 +26,12 @@ ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params) { std::vector> pcmf32s; whisper_vad_segments *segments; - TypedData_Get_Struct(self, ruby_whisper_vad_context, &ruby_whisper_vad_context_type, rwvc); - if (rwvc->context == NULL) { - rb_raise(rb_eRuntimeError, "Doesn't have referenxe to context internally"); - } + GetVADContext(self, rwvc); TypedData_Get_Struct(params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + if (rb_respond_to(file_path, id_to_path)) { + file_path = rb_funcall(file_path, id_to_path, 0); + } cpp_file_path = StringValueCStr(file_path); if (!read_audio_data(cpp_file_path, pcmf32, pcmf32s, false)) { diff --git a/bindings/ruby/ext/ruby_whisper_vad_params.c b/bindings/ruby/ext/ruby_whisper_vad_params.c index f254bfa2..28256650 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_params.c +++ b/bindings/ruby/ext/ruby_whisper_vad_params.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define DEFINE_PARAM(param_name, nth) \ diff --git a/bindings/ruby/ext/ruby_whisper_vad_segment.c b/bindings/ruby/ext/ruby_whisper_vad_segment.c index 49ff0aad..84a007bb 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_segment.c +++ b/bindings/ruby/ext/ruby_whisper_vad_segment.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 2 diff --git a/bindings/ruby/ext/ruby_whisper_vad_segments.c b/bindings/ruby/ext/ruby_whisper_vad_segments.c index 1bb37593..db62fdb6 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_segments.c +++ b/bindings/ruby/ext/ruby_whisper_vad_segments.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" extern ID id___method__; diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 1137e3f3..9ade451c 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ 
-17,6 +17,21 @@ module Whisper LOG_LEVEL_ERROR: Integer LOG_LEVEL_DEBUG: Integer LOG_LEVEL_CONT: Integer + AHEADS_NONE: Integer + AHEADS_N_TOP_MOST: Integer + AHEADS_CUSTOM: Integer + AHEADS_TINY_EN: Integer + AHEADS_TINY: Integer + AHEADS_BASE_EN: Integer + AHEADS_BASE: Integer + AHEADS_SMALL_EN: Integer + AHEADS_SMALL: Integer + AHEADS_MEDIUM_EN: Integer + AHEADS_MEDIUM: Integer + AHEADS_LARGE_V1: Integer + AHEADS_LARGE_V2: Integer + AHEADS_LARGE_V3: Integer + AHEADS_LARGE_V3_TURBO: Integer def self.lang_max_id: () -> Integer def self.lang_id: (string name) -> Integer @@ -37,8 +52,8 @@ module Whisper # puts text # end # - def transcribe: (string, Params, ?n_processors: Integer) -> self - | (string, Params, ?n_processors: Integer) { (String) -> void } -> self + def transcribe: (path, Params, ?n_processors: Integer) -> self + | (path, Params, ?n_processors: Integer) { (String) -> void } -> self def model_n_vocab: () -> Integer def model_n_audio_ctx: () -> Integer @@ -120,6 +135,30 @@ module Whisper def to_srt: () -> String def to_webvtt: () -> String + + class Params + def self.new: ( + use_gpu: boolish, + flash_attn: boolish, + gpu_device: Integer, + dtw_token_timestamps: boolish, + dtw_aheads_preset: Integer, + dtw_n_top: Integer | nil, + ) -> instance + + def use_gpu=: (boolish) -> boolish + def use_gpu: () -> (true | false) + def flash_attn=: (boolish) -> boolish + def flash_attn: () -> (true | false) + def gpu_device=: (Integer) -> Integer + def gpu_device: () -> Integer + def dtw_token_timestamps=: (boolish) -> boolish + def dtw_token_timestamps: () -> (true | false) + def dtw_aheads_preset=: (Integer) -> Integer + def dtw_aheads_preset: () -> Integer + def dtw_n_top=: (Integer | nil) -> (Integer | nil) + def dtw_n_top: () -> (Integer | nil) + end end class Params @@ -603,6 +642,8 @@ module Whisper class Context def self.new: (String | path | ::URI::HTTP model_name_or_path) -> instance + def segments_from_samples: (Params, Array[Float] samples, ?Integer n_samples) -> Segments + | (Params, _Samples, ?Integer n_samples) -> Segments def detect: (path wav_file_path, Params) -> Segments end diff --git a/bindings/ruby/test/test_context_params.rb b/bindings/ruby/test/test_context_params.rb new file mode 100644 index 00000000..8d19fdc9 --- /dev/null +++ b/bindings/ruby/test/test_context_params.rb @@ -0,0 +1,82 @@ +require_relative "helper" + +class TestContextParams < TestBase + PARAM_NAMES = [ + :use_gpu, + :flash_attn, + :gpu_device, + :dtw_token_timestamps, + :dtw_aheads_preset, + :dtw_n_top + ] + + def test_new + params = Whisper::Context::Params.new + assert_instance_of Whisper::Context::Params, params + end + + def test_attributes + params = Whisper::Context::Params.new + + assert_true params.use_gpu + params.use_gpu = false + assert_false params.use_gpu + + assert_true params.flash_attn + params.flash_attn = false + assert_false params.flash_attn + + assert_equal 0, params.gpu_device + params.gpu_device = 1 + assert_equal 1, params.gpu_device + + assert_false params.dtw_token_timestamps + params.dtw_token_timestamps = true + assert_true params.dtw_token_timestamps + + assert_equal Whisper::AHEADS_NONE, params.dtw_aheads_preset + params.dtw_aheads_preset =Whisper::AHEADS_BASE + assert_equal Whisper::AHEADS_BASE, params.dtw_aheads_preset + + assert_nil params.dtw_n_top + params.dtw_n_top = 6 + assert_equal 6, params.dtw_n_top + params.dtw_n_top = nil + assert_nil params.dtw_n_top + end + + def test_new_with_kw_args + params = Whisper::Context::Params.new(use_gpu: false) + assert_false 
params.use_gpu + end + + def test_new_with_kw_args_non_existent + assert_raise ArgumentError do + Whisper::Context::Params.new(non_existent: "value") + end + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_params = Whisper::Context::Params.new + default_value = default_params.send(param) + value = if param == :dtw_n_top + 6 + else + case default_value + in true | false + !default_value + in Integer + default_value + 1 + end + end + params = Whisper::Context::Params.new(param => value) + assert_equal value, params.send(param) + + PARAM_NAMES.reject {|name| name == param}.each do |name| + expected = default_params.send(name) + actual = params.send(name) + assert_equal expected, actual + end + end +end
diff --git a/bindings/ruby/test/test_token.rb b/bindings/ruby/test/test_token.rb index e5834b1b..a23f6813 100644 --- a/bindings/ruby/test/test_token.rb +++ b/bindings/ruby/test/test_token.rb @@ -56,6 +56,17 @@ class TestToken < TestBase @segment.each_token.collect(&:text) end + def test_token_timestamps + params = Whisper::Params.new(token_timestamps: true) + whisper.transcribe(TestBase::AUDIO, params) + prev = -1 + whisper.each_segment.first.each_token do |token| + assert token.start_time >= prev + assert token.end_time >= token.start_time + prev = token.end_time + end + end + def test_deconstruct_keys_with_nil keys = %i[id tid probability log_probability pt ptsum t_dtw voice_length start_time end_time text] expected = keys.collect {|key| [key, @token.send(key)] }.to_h
diff --git a/bindings/ruby/test/test_vad_context.rb b/bindings/ruby/test/test_vad_context.rb index 704916db..b4558d34 100644 --- a/bindings/ruby/test/test_vad_context.rb +++ b/bindings/ruby/test/test_vad_context.rb @@ -9,6 +9,25 @@ class TestVADContext < TestBase def test_detect context = Whisper::VAD::Context.new("silero-v6.2.0") segments = context.detect(AUDIO, Whisper::VAD::Params.new) + assert_segments segments + end + + def test_invalid_model_type + assert_raise TypeError do + Whisper::VAD::Context.new(Object.new) + end + end + + def test_allocate + vad = Whisper::VAD::Context.allocate + assert_raise do + vad.detect(AUDIO, Whisper::VAD::Params.new) + end + end + + private + + def assert_segments(segments) assert_instance_of Whisper::VAD::Segments, segments i = 0 @@ -35,16 +54,47 @@ class TestVADContext < TestBase assert_equal 4, segments.length end - def test_invalid_model_type - assert_raise TypeError do - Whisper::VAD::Context.new(Object.new) + sub_test_case "from samples" do + def setup + super + @vad = Whisper::VAD::Context.new("silero-v6.2.0") + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} end - end - def test_allocate - vad = Whisper::VAD::Context.allocate - assert_raise do - vad.detect(AUDIO, Whisper::VAD::Params.new) + def test_segments_from_samples + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples, @samples.length) + assert_segments segments + end + + def test_segments_from_samples_without_length + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples) + assert_segments segments + end + + def test_segments_from_samples_enumerator + samples = @samples.each + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples, @samples.length) + assert_segments segments + end + + def test_segments_from_samples_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @vad.segments_from_samples(Whisper::VAD::Params.new, samples) +
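
The `test_context_params.rb` additions above reduce to the following usage; a minimal sketch (not part of the diff), assuming the gem is already loaded and using only the keyword arguments and default values asserted in those tests. The GPU index and preset chosen here are illustrative:

```ruby
# Construct context params with keyword arguments; unspecified keys keep their defaults
# (use_gpu/flash_attn true, gpu_device 0, dtw_aheads_preset AHEADS_NONE, dtw_n_top nil).
params = Whisper::Context::Params.new(
  use_gpu:              true,
  gpu_device:           1,                       # illustrative: select a secondary GPU
  dtw_token_timestamps: true,
  dtw_aheads_preset:    Whisper::AHEADS_BASE_EN  # alignment-heads preset for DTW timestamps
)

params.flash_attn = false   # attributes remain individually assignable after construction
params.gpu_device           # => 1
params.dtw_n_top            # => nil unless set explicitly
```

How the params object is then handed to `Whisper::Context` is not shown in this hunk, so the sketch stops at construction and inspection, exactly what the tests verify.
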
end + end + + def test_segments_from_samples_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @vad.segments_from_samples(Whisper::VAD::Params.new, samples, 11) + end + end + + def test_segments_from_samples_with_memory_view + samples = JFKReader.new(AUDIO) + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples) + assert_segments segments + end + end end
diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 96e248ac..29071210 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -1,6 +1,7 @@ require_relative "helper" require "stringio" require "etc" +require "pathname" # Exists to detect memory-related bug Whisper.log_set ->(level, buffer, user_data) {}, nil @@ -20,6 +21,15 @@ class TestWhisper < TestBase } end + def test_whisper_pathname + @whisper = Whisper::Context.new("base.en") + params = Whisper::Params.new + + @whisper.transcribe(Pathname(AUDIO), params) {|text| + assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) + } + end + def test_transcribe_non_parallel @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new @@ -207,6 +217,16 @@ class TestWhisper < TestBase assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text) end + def test_full_with_memory_view_gc + samples = JFKReader.new(AUDIO) + @whisper.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + def test_full_parallel nprocessors = 2 + @whisper.full_parallel(@params, @samples, @samples.length, nprocessors)
diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2e05769a..88b94e7e 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -3,7 +3,7 @@ require_relative "extsources" Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. 
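
The "from samples" cases above exercise the new `Whisper::VAD::Context#segments_from_samples` binding. A minimal sketch (not part of the diff): the file name and the 78-byte header skip are illustrative, mirroring how the tests read the sample WAV, and only the segment count is inspected since per-segment accessors are not shown in this hunk:

```ruby
# Load a VAD model and turn 16-bit PCM into Float samples in [-1, 1].
vad     = Whisper::VAD::Context.new("silero-v6.2.0")
samples = File.read("jfk.wav", nil, 78).unpack("s<*").collect { |s| s.to_f / 2**15 }

segments = vad.segments_from_samples(Whisper::VAD::Params.new, samples)  # length taken from the Array
segments.length                                                          # => number of detected speech segments

# A non-Array source such as an Enumerator needs an explicit sample count:
vad.segments_from_samples(Whisper::VAD::Params.new, samples.each, samples.length)
```

As the tests assert, omitting the count is only allowed when the length can be inferred (Array or memory-view object); an Enumerator without an explicit count raises `ArgumentError`, and a count larger than the source raises `StopIteration`.
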
Fisher"] - s.version = '1.3.5' + s.version = '1.3.6' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] diff --git a/build-xcframework.sh b/build-xcframework.sh index bbf2764d..4d462bbf 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -559,7 +559,7 @@ xcodebuild -create-xcframework \ -framework $(pwd)/build-ios-device/framework/whisper.framework \ -debug-symbols $(pwd)/build-ios-device/dSYMs/whisper.dSYM \ -framework $(pwd)/build-macos/framework/whisper.framework \ - -debug-symbols $(pwd)/build-macos/dSYMS/whisper.dSYM \ + -debug-symbols $(pwd)/build-macos/dSYMs/whisper.dSYM \ -framework $(pwd)/build-visionos/framework/whisper.framework \ -debug-symbols $(pwd)/build-visionos/dSYMs/whisper.dSYM \ -framework $(pwd)/build-visionos-sim/framework/whisper.framework \ diff --git a/close-issue.yml b/close-issue.yml index 276a217d..f661de1c 100644 --- a/close-issue.yml +++ b/close-issue.yml @@ -15,7 +15,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v10 with: exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap" days-before-issue-stale: 30 diff --git a/cmake/whisper-config.cmake.in b/cmake/whisper-config.cmake.in index 6a3fa227..b70c1e5a 100644 --- a/cmake/whisper-config.cmake.in +++ b/cmake/whisper-config.cmake.in @@ -3,60 +3,25 @@ set(WHISPER_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) set(WHISPER_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) set(WHISPER_SHARED_LIB @BUILD_SHARED_LIBS@) -set(GGML_BLAS @GGML_BLAS@) -set(GGML_CUDA @GGML_CUDA@) -set(GGML_METAL @GGML_METAL@) -set(GGML_HIPBLAS @GGML_HIPBLAS@) -set(GGML_ACCELERATE @GGML_ACCELERATE@) - @PACKAGE_INIT@ set_and_check(WHISPER_INCLUDE_DIR "@PACKAGE_WHISPER_INCLUDE_INSTALL_DIR@") set_and_check(WHISPER_LIB_DIR "@PACKAGE_WHISPER_LIB_INSTALL_DIR@") set_and_check(WHISPER_BIN_DIR "@PACKAGE_WHISPER_BIN_INSTALL_DIR@") -# Ensure transient dependencies satisfied - -find_package(Threads REQUIRED) - -if (APPLE AND GGML_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) -endif() - -if (GGML_BLAS) - find_package(BLAS REQUIRED) -endif() - -if (GGML_CUDA) - find_package(CUDAToolkit REQUIRED) -endif() - -if (GGML_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) -endif() - -if (GGML_HIPBLAS) - find_package(hip REQUIRED) - find_package(hipblas REQUIRED) - find_package(rocblas REQUIRED) -endif() +find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake) find_library(whisper_LIBRARY whisper REQUIRED - HINTS ${WHISPER_LIB_DIR}) - -set(_whisper_link_deps "Threads::Threads" "@WHISPER_EXTRA_LIBS@") -set(_whisper_transient_defines "@WHISPER_TRANSIENT_DEFINES@") + HINTS ${WHISPER_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) add_library(whisper UNKNOWN IMPORTED) - set_target_properties(whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${WHISPER_INCLUDE_DIR}" - INTERFACE_LINK_LIBRARIES "${_whisper_link_deps}" - INTERFACE_COMPILE_DEFINITIONS "${_whisper_transient_defines}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" IMPORTED_LOCATION "${whisper_LIBRARY}" INTERFACE_COMPILE_FEATURES cxx_std_11 diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 9a54742f..4e84c1b2 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -77,6 +77,7 @@ struct whisper_params { 
bool log_score = false; bool use_gpu = true; bool flash_attn = true; + int32_t gpu_device = 0; bool suppress_nst = false; bool carry_initial_prompt = false; @@ -129,6 +130,10 @@ static char * requires_value_error(const std::string & arg) { } static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { + if (const char * env_device = std::getenv("WHISPER_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -195,6 +200,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } @@ -276,6 +282,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); @@ -1003,6 +1010,7 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index c42b644f..6f02a250 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -73,6 +73,7 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_IQ1_M: case GGML_FTYPE_MOSTLY_BF16: case GGML_FTYPE_MOSTLY_MXFP4: + case GGML_FTYPE_MOSTLY_NVFP4: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -213,6 +214,7 @@ bool ggml_common_quantize_0( case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); diff --git a/examples/miniaudio.h b/examples/miniaudio.h index c74bebeb..24e676bb 100644 --- a/examples/miniaudio.h +++ b/examples/miniaudio.h @@ -1,6 +1,6 @@ /* Audio playback and capture library. Choice of public domain or MIT-0. See license statements at the end of this file. -miniaudio - v0.11.22 - 2025-02-24 +miniaudio - v0.11.24 - 2026-01-17 David Reid - mackron@gmail.com @@ -12,18 +12,10 @@ GitHub: https://github.com/mackron/miniaudio /* 1. 
Introduction =============== -To use miniaudio, include "miniaudio.h": - - ```c - #include "miniaudio.h" - ``` - -The implementation is contained in "miniaudio.c". Just compile this like any other source file. You -can include miniaudio.c if you want to compile your project as a single translation unit: - - ```c - #include "miniaudio.c" - ``` +To use miniaudio, just include "miniaudio.h" like any other header and add "miniaudio.c" to your +source tree. If you don't want to add it to your source tree you can compile and link to it like +any other library. Note that ABI compatibility is not guaranteed between versions, even with bug +fix releases, so take care if compiling as a shared object. miniaudio includes both low level and high level APIs. The low level API is good for those who want to do all of their mixing themselves and only require a light weight interface to the underlying @@ -303,7 +295,7 @@ The engine encapsulates both the resource manager and the node graph to create a use high level API. The resource manager and node graph APIs are covered in more later sections of this manual. -The code below shows how you can initialize an engine using it's default configuration. +The code below shows how you can initialize an engine using its default configuration. ```c ma_result result; @@ -391,7 +383,7 @@ Sounds are not started by default. Start a sound with `ma_sound_start()` and sto `ma_sound_stop()`. When a sound is stopped, it is not rewound to the start. Use `ma_sound_seek_to_pcm_frame(&sound, 0)` to seek back to the start of a sound. By default, starting and stopping sounds happens immediately, but sometimes it might be convenient to schedule the sound -the be started and/or stopped at a specific time. This can be done with the following functions: +to be started and/or stopped at a specific time. This can be done with the following functions: ```c ma_sound_set_start_time_in_pcm_frames() @@ -463,6 +455,11 @@ is at the end, use `ma_sound_at_end()`. Looping of a sound can be controlled wit miniaudio should work cleanly out of the box without the need to download or install any dependencies. See below for platform-specific details. +This library has been designed to be added directly to your source tree which is the preferred way +of using it, but you can compile it as a normal library if that's your preference. Be careful if +compiling as a shared object because miniaudio is not ABI compatible between any release, including +bug fix releases. It's recommended you link statically. + Note that GCC and Clang require `-msse2`, `-mavx2`, etc. for SIMD optimizations. If you get errors about undefined references to `__sync_val_compare_and_swap_8`, `__atomic_load_8`, @@ -532,7 +529,7 @@ you'll need to disable run-time linking with `MA_NO_RUNTIME_LINKING` and link wi The Emscripten build emits Web Audio JavaScript directly and should compile cleanly out of the box. You cannot use `-std=c*` compiler flags, nor `-ansi`. -You can enable the use of AudioWorkets by defining `MA_ENABLE_AUDIO_WORKLETS` and then compiling +You can enable the use of AudioWorklets by defining `MA_ENABLE_AUDIO_WORKLETS` and then compiling with the following options: -sAUDIO_WORKLET=1 -sWASM_WORKERS=1 -sASYNCIFY @@ -881,7 +878,7 @@ read data within a certain range of the underlying data. To do this you can use This is useful if you have a sound bank where many sounds are stored in the same file and you want the data source to only play one of those sub-sounds. 
Note that once the range is set, everything -that takes a position, such as cursors and loop points, should always be relatvie to the start of +that takes a position, such as cursors and loop points, should always be relative to the start of the range. When the range is set, any previously defined loop point will be reset. Custom loop points can also be used with data sources. By default, data sources will loop after @@ -889,7 +886,7 @@ they reach the end of the data source, but if you need to loop at a specific loc the following: ```c - result = ma_data_set_loop_point_in_pcm_frames(pDataSource, loopBegInFrames, loopEndInFrames); + result = ma_data_source_set_loop_point_in_pcm_frames(pDataSource, loopBegInFrames, loopEndInFrames); if (result != MA_SUCCESS) { return result; // Failed to set the loop point. } @@ -3750,7 +3747,7 @@ extern "C" { #define MA_VERSION_MAJOR 0 #define MA_VERSION_MINOR 11 -#define MA_VERSION_REVISION 22 +#define MA_VERSION_REVISION 24 #define MA_VERSION_STRING MA_XSTRINGIFY(MA_VERSION_MAJOR) "." MA_XSTRINGIFY(MA_VERSION_MINOR) "." MA_XSTRINGIFY(MA_VERSION_REVISION) #if defined(_MSC_VER) && !defined(__clang__) @@ -3857,37 +3854,65 @@ typedef ma_uint16 wchar_t; #define MA_SIZE_MAX 0xFFFFFFFF /* When SIZE_MAX is not defined by the standard library just default to the maximum 32-bit unsigned integer. */ #endif +#define MA_UINT64_MAX (((ma_uint64)0xFFFFFFFF << 32) | (ma_uint64)0xFFFFFFFF) /* Weird shifting syntax is for VC6 compatibility. */ + /* Platform/backend detection. */ -#if defined(_WIN32) || defined(__COSMOPOLITAN__) +#if defined(_WIN32) #define MA_WIN32 #if defined(MA_FORCE_UWP) || (defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PC_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PC_APP) || (defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) #define MA_WIN32_UWP #elif defined(WINAPI_FAMILY) && (defined(WINAPI_FAMILY_GAMES) && WINAPI_FAMILY == WINAPI_FAMILY_GAMES) #define MA_WIN32_GDK + #elif defined(NXDK) + #define MA_WIN32_NXDK #else #define MA_WIN32_DESKTOP #endif + + /* The original Xbox. */ + #if defined(NXDK) /* <-- Add other Xbox compiler toolchains here, and then add a toolchain-specific define in case we need to discriminate between them later. */ + #define MA_XBOX + + #if defined(NXDK) + #define MA_XBOX_NXDK + #endif + #endif #endif -#if !defined(_WIN32) /* If it's not Win32, assume POSIX. */ +#if defined(__MSDOS__) || defined(MSDOS) || defined(_MSDOS) || defined(__DOS__) + #define MA_DOS + + /* No threading allowed on DOS. */ + #ifndef MA_NO_THREADING + #define MA_NO_THREADING + #endif + + /* No runtime linking allowed on DOS. */ + #ifndef MA_NO_RUNTIME_LINKING + #define MA_NO_RUNTIME_LINKING + #endif +#endif +#if !defined(MA_WIN32) && !defined(MA_DOS) /* If it's not Win32, assume POSIX. */ #define MA_POSIX - /* - Use the MA_NO_PTHREAD_IN_HEADER option at your own risk. This is intentionally undocumented. - You can use this to avoid including pthread.h in the header section. The downside is that it - results in some fixed sized structures being declared for the various types that are used in - miniaudio. The risk here is that these types might be too small for a given platform. This - risk is yours to take and no support will be offered if you enable this option. - */ - #ifndef MA_NO_PTHREAD_IN_HEADER - #include /* Unfortunate #include, but needed for pthread_t, pthread_mutex_t and pthread_cond_t types. 
*/ - typedef pthread_t ma_pthread_t; - typedef pthread_mutex_t ma_pthread_mutex_t; - typedef pthread_cond_t ma_pthread_cond_t; - #else - typedef ma_uintptr ma_pthread_t; - typedef union ma_pthread_mutex_t { char __data[40]; ma_uint64 __alignment; } ma_pthread_mutex_t; - typedef union ma_pthread_cond_t { char __data[48]; ma_uint64 __alignment; } ma_pthread_cond_t; + #if !defined(MA_NO_THREADING) + /* + Use the MA_NO_PTHREAD_IN_HEADER option at your own risk. This is intentionally undocumented. + You can use this to avoid including pthread.h in the header section. The downside is that it + results in some fixed sized structures being declared for the various types that are used in + miniaudio. The risk here is that these types might be too small for a given platform. This + risk is yours to take and no support will be offered if you enable this option. + */ + #ifndef MA_NO_PTHREAD_IN_HEADER + #include /* Unfortunate #include, but needed for pthread_t, pthread_mutex_t and pthread_cond_t types. */ + typedef pthread_t ma_pthread_t; + typedef pthread_mutex_t ma_pthread_mutex_t; + typedef pthread_cond_t ma_pthread_cond_t; + #else + typedef ma_uintptr ma_pthread_t; + typedef union ma_pthread_mutex_t { char __data[40]; ma_uint64 __alignment; } ma_pthread_mutex_t; + typedef union ma_pthread_cond_t { char __data[48]; ma_uint64 __alignment; } ma_pthread_cond_t; + #endif #endif #if defined(__unix__) @@ -3914,8 +3939,11 @@ typedef ma_uint16 wchar_t; #if defined(__PROSPERO__) #define MA_PROSPERO #endif - #if defined(__NX__) - #define MA_NX + #if defined(__3DS__) + #define MA_3DS + #endif + #if defined(__SWITCH__) || defined(__NX__) + #define MA_SWITCH #endif #if defined(__BEOS__) || defined(__HAIKU__) #define MA_BEOS @@ -3925,12 +3953,13 @@ typedef ma_uint16 wchar_t; #endif #endif -#if defined(__has_c_attribute) - #if __has_c_attribute(fallthrough) - #define MA_FALLTHROUGH [[fallthrough]] - #endif +#if !defined(MA_FALLTHROUGH) && defined(__cplusplus) && __cplusplus >= 201703L + #define MA_FALLTHROUGH [[fallthrough]] #endif -#if !defined(MA_FALLTHROUGH) && defined(__has_attribute) && (defined(__clang__) || defined(__GNUC__)) +#if !defined(MA_FALLTHROUGH) && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 202000L + #define MA_FALLTHROUGH [[fallthrough]] +#endif +#if !defined(MA_FALLTHROUGH) && defined(__has_attribute) #if __has_attribute(fallthrough) #define MA_FALLTHROUGH __attribute__((fallthrough)) #endif @@ -3967,7 +3996,7 @@ typedef ma_uint16 wchar_t; #define MA_NO_INLINE __attribute__((noinline)) #else #define MA_INLINE MA_GNUC_INLINE_HINT - #define MA_NO_INLINE __attribute__((noinline)) + #define MA_NO_INLINE #endif #elif defined(__WATCOMC__) #define MA_INLINE __inline @@ -4153,9 +4182,13 @@ typedef enum MA_CHANNEL_AUX_29 = 49, MA_CHANNEL_AUX_30 = 50, MA_CHANNEL_AUX_31 = 51, + + /* Count. */ + MA_CHANNEL_POSITION_COUNT, + + /* Aliases. */ MA_CHANNEL_LEFT = MA_CHANNEL_FRONT_LEFT, MA_CHANNEL_RIGHT = MA_CHANNEL_FRONT_RIGHT, - MA_CHANNEL_POSITION_COUNT = (MA_CHANNEL_AUX_31 + 1) } _ma_channel_position; /* Do not use `_ma_channel_position` directly. Use `ma_channel` instead. */ typedef enum @@ -4350,7 +4383,7 @@ typedef struct typedef struct { - ma_int32 state; + ma_uint32 state; } ma_lcg; @@ -6569,22 +6602,18 @@ This section contains the APIs for device playback and capture. 
Here is where yo ************************************************************************************************************************************************************/ #ifndef MA_NO_DEVICE_IO /* Some backends are only supported on certain platforms. */ -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) #define MA_SUPPORT_WASAPI #if defined(MA_WIN32_DESKTOP) /* DirectSound and WinMM backends are only supported on desktops. */ #define MA_SUPPORT_DSOUND #define MA_SUPPORT_WINMM - - /* Don't enable JACK here if compiling with Cosmopolitan. It'll be enabled in the Linux section below. */ - #if !defined(__COSMOPOLITAN__) - #define MA_SUPPORT_JACK /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */ - #endif + #define MA_SUPPORT_JACK /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */ #endif #endif #if defined(MA_UNIX) && !defined(MA_ORBIS) && !defined(MA_PROSPERO) #if defined(MA_LINUX) - #if !defined(MA_ANDROID) && !defined(__COSMOPOLITAN__) /* ALSA is not supported on Android. */ + #if !defined(MA_ANDROID) && !defined(MA_EMSCRIPTEN) /* ALSA is not supported on Android. */ #define MA_SUPPORT_ALSA #endif #endif @@ -7426,6 +7455,7 @@ struct ma_context ma_proc snd_pcm_hw_params_set_rate_resample; ma_proc snd_pcm_hw_params_set_rate; ma_proc snd_pcm_hw_params_set_rate_near; + ma_proc snd_pcm_hw_params_set_rate_minmax; ma_proc snd_pcm_hw_params_set_buffer_size_near; ma_proc snd_pcm_hw_params_set_periods_near; ma_proc snd_pcm_hw_params_set_access; @@ -7986,6 +8016,7 @@ struct ma_device /*AAudioStream**/ ma_ptr pStreamPlayback; /*AAudioStream**/ ma_ptr pStreamCapture; ma_mutex rerouteLock; + ma_atomic_bool32 isTearingDown; ma_aaudio_usage usage; ma_aaudio_content_type contentType; ma_aaudio_input_preset inputPreset; @@ -9644,7 +9675,7 @@ Parameters ---------- pBackends (out, optional) A pointer to the buffer that will receive the enabled backends. Set to NULL to retrieve the backend count. Setting - the capacity of the buffer to `MA_BUFFER_COUNT` will guarantee it's large enough for all backends. + the capacity of the buffer to `MA_BACKEND_COUNT` will guarantee it's large enough for all backends. backendCap (in) The capacity of the `pBackends` buffer. @@ -10489,6 +10520,7 @@ typedef struct ma_decoding_backend_vtable** ppCustomDecodingBackendVTables; ma_uint32 customDecodingBackendCount; void* pCustomDecodingBackendUserData; + ma_resampler_config resampling; } ma_resource_manager_config; MA_API ma_resource_manager_config ma_resource_manager_config_init(void); @@ -10816,6 +10848,7 @@ MA_API ma_result ma_node_graph_read_pcm_frames(ma_node_graph* pNodeGraph, void* MA_API ma_uint32 ma_node_graph_get_channels(const ma_node_graph* pNodeGraph); MA_API ma_uint64 ma_node_graph_get_time(const ma_node_graph* pNodeGraph); MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 globalTime); +MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph); @@ -11123,6 +11156,7 @@ typedef struct ma_bool8 isPitchDisabled; /* Pitching can be explicitly disabled with MA_SOUND_FLAG_NO_PITCH to optimize processing. */ ma_bool8 isSpatializationDisabled; /* Spatialization can be explicitly disabled with MA_SOUND_FLAG_NO_SPATIALIZATION. */ ma_uint8 pinnedListenerIndex; /* The index of the listener this node should always use for spatialization. If set to MA_LISTENER_INDEX_CLOSEST the engine will use the closest listener. 
*/ + ma_resampler_config resampling; } ma_engine_node_config; MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_engine_node_type type, ma_uint32 flags); @@ -11137,7 +11171,7 @@ typedef struct ma_uint32 volumeSmoothTimeInPCMFrames; ma_mono_expansion_mode monoExpansionMode; ma_fader fader; - ma_linear_resampler resampler; /* For pitch shift. */ + ma_resampler resampler; /* For pitch shift. */ ma_spatializer spatializer; ma_panner panner; ma_gainer volumeGainer; /* This will only be used if volumeSmoothTimeInPCMFrames is > 0. */ @@ -11193,6 +11227,7 @@ typedef struct ma_uint64 loopPointEndInPCMFrames; ma_sound_end_proc endCallback; /* Fired when the sound reaches the end. Will be fired from the audio thread. Do not restart, uninitialize or otherwise change the state of the sound from here. Instead fire an event or set a variable to indicate to a different thread to change the start of the sound. Will not be fired in response to a scheduled stop with ma_sound_set_stop_time_*(). */ void* pEndCallbackUserData; + ma_resampler_config pitchResampling; #ifndef MA_NO_RESOURCE_MANAGER ma_resource_manager_pipeline_notifications initNotifications; #endif @@ -11211,7 +11246,10 @@ struct ma_sound MA_ATOMIC(4, ma_bool32) atEnd; ma_sound_end_proc endCallback; void* pEndCallbackUserData; - ma_bool8 ownsDataSource; + float* pProcessingCache; /* Will be null if pDataSource is null. */ + ma_uint32 processingCacheFramesRemaining; + ma_uint32 processingCacheCap; + ma_bool8 ownsDataSource; /* We're declaring a resource manager data source object here to save us a malloc when loading a @@ -11255,7 +11293,7 @@ typedef struct ma_log* pLog; /* When set to NULL, will use the context's log. */ ma_uint32 listenerCount; /* Must be between 1 and MA_ENGINE_MAX_LISTENERS. */ ma_uint32 channels; /* The number of channels to use when mixing and spatializing. When set to 0, will use the native channel count of the device. */ - ma_uint32 sampleRate; /* The sample rate. When set to 0 will use the native channel count of the device. */ + ma_uint32 sampleRate; /* The sample rate. When set to 0 will use the native sample rate of the device. */ ma_uint32 periodSizeInFrames; /* If set to something other than 0, updates will always be exactly this size. The underlying device may be a different size, but from the perspective of the mixer that won't matter.*/ ma_uint32 periodSizeInMilliseconds; /* Used if periodSizeInFrames is unset. */ ma_uint32 gainSmoothTimeInFrames; /* The number of frames to interpolate the gain of spatialized sounds across. If set to 0, will use gainSmoothTimeInMilliseconds. */ @@ -11269,6 +11307,8 @@ typedef struct ma_vfs* pResourceManagerVFS; /* A pointer to a pre-allocated VFS object to use with the resource manager. This is ignored if pResourceManager is not NULL. */ ma_engine_process_proc onProcess; /* Fired at the end of each call to ma_engine_read_pcm_frames(). For engine's that manage their own internal device (the default configuration), this will be fired from the audio thread, and you do not need to call ma_engine_read_pcm_frames() manually in order to trigger this. */ void* pProcessUserData; /* User data that's passed into onProcess. */ + ma_resampler_config resourceManagerResampling; /* The resampling config to use with the resource manager. */ + ma_resampler_config pitchResampling; /* The resampling config for the pitch and Doppler effects. You will typically want this to be a fast resampler. For high quality stuff, it's recommended that you pre-resample. 
*/ } ma_engine_config; MA_API ma_engine_config ma_engine_config_init(void); @@ -11298,6 +11338,7 @@ struct ma_engine ma_mono_expansion_mode monoExpansionMode; ma_engine_process_proc onProcess; void* pProcessUserData; + ma_resampler_config pitchResamplingConfig; }; MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEngine); @@ -11358,8 +11399,12 @@ MA_API ma_engine* ma_sound_get_engine(const ma_sound* pSound); MA_API ma_data_source* ma_sound_get_data_source(const ma_sound* pSound); MA_API ma_result ma_sound_start(ma_sound* pSound); MA_API ma_result ma_sound_stop(ma_sound* pSound); -MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. */ -MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. */ +MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */ +MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */ +MA_API void ma_sound_reset_start_time(ma_sound* pSound); +MA_API void ma_sound_reset_stop_time(ma_sound* pSound); +MA_API void ma_sound_reset_fade(ma_sound* pSound); +MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound); /* Resets fades and scheduled stop time. Does not seek back to the start. */ MA_API void ma_sound_set_volume(ma_sound* pSound, float volume); MA_API float ma_sound_get_volume(const ma_sound* pSound); MA_API void ma_sound_set_pan(ma_sound* pSound, float pan); @@ -11419,11 +11464,11 @@ MA_API ma_bool32 ma_sound_is_looping(const ma_sound* pSound); MA_API ma_bool32 ma_sound_at_end(const ma_sound* pSound); MA_API ma_result ma_sound_seek_to_pcm_frame(ma_sound* pSound, ma_uint64 frameIndex); /* Just a wrapper around ma_data_source_seek_to_pcm_frame(). 
*/ MA_API ma_result ma_sound_seek_to_second(ma_sound* pSound, float seekPointInSeconds); /* Abstraction to ma_sound_seek_to_pcm_frame() */ -MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap); -MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* pCursor); -MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* pLength); -MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor); -MA_API ma_result ma_sound_get_length_in_seconds(ma_sound* pSound, float* pLength); +MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap); +MA_API ma_result ma_sound_get_cursor_in_pcm_frames(const ma_sound* pSound, ma_uint64* pCursor); +MA_API ma_result ma_sound_get_length_in_pcm_frames(const ma_sound* pSound, ma_uint64* pLength); +MA_API ma_result ma_sound_get_cursor_in_seconds(const ma_sound* pSound, float* pCursor); +MA_API ma_result ma_sound_get_length_in_seconds(const ma_sound* pSound, float* pLength); MA_API ma_result ma_sound_set_end_callback(ma_sound* pSound, ma_sound_end_proc callback, void* pUserData); MA_API ma_result ma_sound_group_init(ma_engine* pEngine, ma_uint32 flags, ma_sound_group* pParentGroup, ma_sound_group* pGroup); @@ -11544,17 +11589,23 @@ IMPLEMENTATION #endif #if !defined(MA_WIN32) -#include -#include /* select() (used for ma_sleep()). */ -#include + #if !defined(MA_NO_THREADING) + #include + #include /* For pthreads. */ + #endif + + #include /* select() (used for ma_sleep()). */ + #include /* For nanosleep() */ + #include #endif -#ifdef MA_NX -#include /* For nanosleep() */ +/* For fstat(), etc. */ +#if defined(MA_XBOX_NXDK) + #include /* Suggestion for NXDK: Add a sys/stat.h wrapper for compatibility. */ +#else + #include #endif -#include /* For fstat(), etc. */ - #ifdef MA_EMSCRIPTEN #include #endif @@ -11606,7 +11657,7 @@ IMPLEMENTATION #endif /* Intrinsics Support */ -#if (defined(MA_X64) || defined(MA_X86)) && !defined(__COSMOPOLITAN__) +#if defined(MA_X64) || defined(MA_X86) #if defined(_MSC_VER) && !defined(__clang__) /* MSVC. */ #if _MSC_VER >= 1400 && !defined(MA_NO_SSE2) /* 2005 */ @@ -11861,7 +11912,7 @@ static MA_INLINE ma_bool32 ma_has_neon(void) #endif #ifndef MA_RESTRICT - #if defined(__clang__) || defined(__GNUC__) || defined(_MSC_VER) + #if defined(__clang__) || defined(_MSC_VER) || (defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ >= 95))) #define MA_RESTRICT __restrict #else #define MA_RESTRICT @@ -11955,7 +12006,7 @@ static void ma_sleep__posix(ma_uint32 milliseconds) (void)milliseconds; MA_ASSERT(MA_FALSE); /* The Emscripten build should never sleep. 
*/ #else - #if (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L) || defined(MA_NX) + #if (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L) || defined(MA_SWITCH) struct timespec ts; ts.tv_sec = milliseconds / 1000; ts.tv_nsec = milliseconds % 1000 * 1000000; @@ -11997,7 +12048,7 @@ static MA_INLINE void ma_yield(void) #endif #endif #else - __asm__ __volatile__ ("pause"); + __asm__ __volatile__ ("rep; nop"); #endif #elif (defined(__arm__) && defined(__ARM_ARCH) && __ARM_ARCH >= 7) || defined(_M_ARM64) || (defined(_M_ARM) && _M_ARM >= 7) || defined(__ARM_ARCH_6K__) || defined(__ARM_ARCH_6T2__) /* ARM */ @@ -12020,7 +12071,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) { unsigned int prevState; - #if defined(_MSC_VER) + #if defined(_MSC_VER) && !defined(MA_XBOX_NXDK) { /* Older versions of Visual Studio don't support the "safe" versions of _controlfp_s(). I don't @@ -12043,7 +12094,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { prevState = _mm_getcsr(); _mm_setcsr(prevState | MA_MM_DENORMALS_ZERO_MASK | MA_MM_FLUSH_ZERO_MASK); @@ -12067,7 +12118,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) static MA_INLINE void ma_restore_denormals(unsigned int prevState) { - #if defined(_MSC_VER) + #if defined(_MSC_VER) && !defined(MA_XBOX_NXDK) { /* Older versions of Visual Studio do not support _controlfp_s(). See ma_disable_denormals(). */ #if _MSC_VER <= 1200 @@ -12083,7 +12134,7 @@ static MA_INLINE void ma_restore_denormals(unsigned int prevState) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { _mm_setcsr(prevState); } @@ -12719,6 +12770,29 @@ MA_API MA_NO_INLINE int ma_strcmp(const char* str1, const char* str2) return ((unsigned char*)str1)[0] - ((unsigned char*)str2)[0]; } +MA_API MA_NO_INLINE int ma_wcscmp(const wchar_t* str1, const wchar_t* str2) +{ + if (str1 == str2) return 0; + + /* These checks differ from the standard implementation. It's not important, but I prefer it just for sanity. 
*/ + if (str1 == NULL) return -1; + if (str2 == NULL) return 1; + + for (;;) { + if (str1[0] == L'\0') { + break; + } + if (str1[0] != str2[0]) { + break; + } + + str1 += 1; + str2 += 1; + } + + return ((unsigned short*)str1)[0] - ((unsigned short*)str2)[0]; +} + MA_API MA_NO_INLINE int ma_strappend(char* dst, size_t dstSize, const char* srcA, const char* srcB) { int result; @@ -12736,6 +12810,22 @@ MA_API MA_NO_INLINE int ma_strappend(char* dst, size_t dstSize, const char* srcA return result; } +MA_API MA_NO_INLINE size_t ma_wcslen(const wchar_t* str) +{ + const wchar_t* end; + + if (str == NULL) { + return 0; + } + + end = str; + while (end[0] != '\0') { + end += 1; + } + + return end - str; +} + MA_API MA_NO_INLINE char* ma_copy_string(const char* src, const ma_allocation_callbacks* pAllocationCallbacks) { size_t sz; @@ -12758,7 +12848,7 @@ MA_API MA_NO_INLINE char* ma_copy_string(const char* src, const ma_allocation_ca MA_API MA_NO_INLINE wchar_t* ma_copy_string_w(const wchar_t* src, const ma_allocation_callbacks* pAllocationCallbacks) { - size_t sz = wcslen(src)+1; + size_t sz = ma_wcslen(src)+1; wchar_t* dst = (wchar_t*)ma_malloc(sz * sizeof(*dst), pAllocationCallbacks); if (dst == NULL) { return NULL; @@ -13189,7 +13279,7 @@ MA_API ma_result ma_fopen(FILE** ppFile, const char* pFilePath, const char* pOpe return MA_INVALID_ARGS; } -#if defined(_MSC_VER) && _MSC_VER >= 1400 +#if (defined(_MSC_VER) && _MSC_VER >= 1400) && !defined(MA_XBOX_NXDK) err = fopen_s(ppFile, pFilePath, pOpenMode); if (err != 0) { return ma_result_from_errno(err); @@ -13231,7 +13321,7 @@ _wfopen() isn't always available in all compilation environments. This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs() fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support. */ -#if defined(_WIN32) +#if defined(_WIN32) && !defined(MA_XBOX_NXDK) #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS)) #define MA_HAS_WFOPEN #endif @@ -13247,29 +13337,34 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ return MA_INVALID_ARGS; } -#if defined(MA_HAS_WFOPEN) + #if defined(MA_HAS_WFOPEN) { /* Use _wfopen() on Windows. */ - #if defined(_MSC_VER) && _MSC_VER >= 1400 - errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); - if (err != 0) { - return ma_result_from_errno(err); + #if defined(_MSC_VER) && _MSC_VER >= 1400 + { + errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return ma_result_from_errno(err); + } } - #else - *ppFile = _wfopen(pFilePath, pOpenMode); - if (*ppFile == NULL) { - return ma_result_from_errno(errno); + #else + { + *ppFile = _wfopen(pFilePath, pOpenMode); + if (*ppFile == NULL) { + return ma_result_from_errno(errno); + } } - #endif + #endif + (void)pAllocationCallbacks; } -#else - /* - Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can - think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for - maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. - */ + #elif !defined(MA_XBOX_NXDK) && !defined(MA_DOS) /* If your compiler does not support wcsrtombs(), add it here. 
*/ { + /* + Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can + think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for + maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. + */ mbstate_t mbs; size_t lenMB; const wchar_t* pFilePathTemp = pFilePath; @@ -13310,11 +13405,16 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ ma_free(pFilePathMB, pAllocationCallbacks); } + #else + { + /* Getting here means there is no way to open the file with a wide character string. */ + *ppFile = NULL; + } + #endif if (*ppFile == NULL) { return MA_ERROR; } -#endif return MA_SUCCESS; } @@ -13323,7 +13423,7 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ static MA_INLINE void ma_copy_memory_64(void* dst, const void* src, ma_uint64 sizeInBytes) { -#if 0xFFFFFFFFFFFFFFFF <= MA_SIZE_MAX +#if MA_SIZE_MAX > 0xFFFFFFFF MA_COPY_MEMORY(dst, src, (size_t)sizeInBytes); #else while (sizeInBytes > 0) { @@ -13343,7 +13443,7 @@ static MA_INLINE void ma_copy_memory_64(void* dst, const void* src, ma_uint64 si static MA_INLINE void ma_zero_memory_64(void* dst, ma_uint64 sizeInBytes) { -#if 0xFFFFFFFFFFFFFFFF <= MA_SIZE_MAX +#if MA_SIZE_MAX > 0xFFFFFFFF MA_ZERO_MEMORY(dst, (size_t)sizeInBytes); #else while (sizeInBytes > 0) { @@ -13472,6 +13572,18 @@ static ma_result ma_allocation_callbacks_init_copy(ma_allocation_callbacks* pDst Logging **************************************************************************************************************************************************************/ +#ifndef ma_va_copy + #if !defined(_MSC_VER) || _MSC_VER >= 1800 + #if (defined(__GNUC__) && __GNUC__ < 3) + #define ma_va_copy(dst, src) ((dst) = (src)) /* This is untested. Not sure if this is correct for old GCC. */ + #else + #define ma_va_copy(dst, src) va_copy((dst), (src)) + #endif + #else + #define ma_va_copy(dst, src) ((dst) = (src)) + #endif +#endif + MA_API const char* ma_log_level_to_string(ma_uint32 logLevel) { switch (logLevel) @@ -13712,9 +13824,15 @@ MA_API ma_result ma_log_postv(ma_log* pLog, ma_uint32 level, const char* pFormat int length; char pFormattedMessageStack[1024]; char* pFormattedMessageHeap = NULL; + va_list args2; /* First try formatting into our fixed sized stack allocated buffer. If this is too small we'll fallback to a heap allocation. */ - length = vsnprintf(pFormattedMessageStack, sizeof(pFormattedMessageStack), pFormat, args); + ma_va_copy(args2, args); + { + length = vsnprintf(pFormattedMessageStack, sizeof(pFormattedMessageStack), pFormat, args2); + } + va_end(args2); + if (length < 0) { return MA_INVALID_OPERATION; /* An error occurred when trying to convert the buffer. */ } @@ -13755,17 +13873,10 @@ MA_API ma_result ma_log_postv(ma_log* pLog, ma_uint32 level, const char* pFormat char* pFormattedMessage = NULL; va_list args2; - #if _MSC_VER >= 1800 + ma_va_copy(args2, args); { - va_copy(args2, args); + formattedLen = ma_vscprintf(&pLog->allocationCallbacks, pFormat, args2); } - #else - { - args2 = args; - } - #endif - - formattedLen = ma_vscprintf(&pLog->allocationCallbacks, pFormat, args2); va_end(args2); if (formattedLen <= 0) { @@ -13964,7 +14075,7 @@ miniaudio's purposes. 
#define MA_LCG_A 48271 #define MA_LCG_C 0 -static ma_lcg g_maLCG = {MA_DEFAULT_LCG_SEED}; /* Non-zero initial seed. Use ma_seed() to use an explicit seed. */ +static ma_lcg g_maLCG = {MA_DEFAULT_LCG_SEED}; /* Non-zero initial seed. Use ma_lcg_seed() to use an explicit seed. */ static MA_INLINE void ma_lcg_seed(ma_lcg* pLCG, ma_int32 seed) { @@ -14013,7 +14124,7 @@ static MA_INLINE ma_int32 ma_lcg_rand_range_s32(ma_lcg* pLCG, ma_int32 lo, ma_in } - +#if 0 /* Currently unused. */ static MA_INLINE void ma_seed(ma_int32 seed) { ma_lcg_seed(&g_maLCG, seed); @@ -14038,6 +14149,7 @@ static MA_INLINE float ma_rand_f32(void) { return ma_lcg_rand_f32(&g_maLCG); } +#endif static MA_INLINE float ma_rand_range_f32(float lo, float hi) { @@ -14097,6 +14209,7 @@ Atomics **************************************************************************************************************************************************************/ /* c89atomic.h begin */ #ifndef ma_atomic_h +#define ma_atomic_h #if defined(__cplusplus) extern "C" { #endif @@ -14108,11 +14221,63 @@ extern "C" { #endif #endif typedef int ma_atomic_memory_order; -#define MA_ATOMIC_HAS_8 -#define MA_ATOMIC_HAS_16 -#define MA_ATOMIC_HAS_32 -#define MA_ATOMIC_HAS_64 -#if (defined(_MSC_VER) ) || defined(__WATCOMC__) || defined(__DMC__) +#if !defined(MA_ATOMIC_MODERN_MSVC) && \ + !defined(MA_ATOMIC_LEGACY_MSVC) && \ + !defined(MA_ATOMIC_LEGACY_MSVC_ASM) && \ + !defined(MA_ATOMIC_MODERN_GCC) && \ + !defined(MA_ATOMIC_LEGACY_GCC) && \ + !defined(MA_ATOMIC_LEGACY_GCC_ASM) + #if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__) || defined(__BORLANDC__) + #if (defined(_MSC_VER) && _MSC_VER > 1600) + #define MA_ATOMIC_MODERN_MSVC + #else + #if defined(MA_X64) + #define MA_ATOMIC_LEGACY_MSVC + #else + #define MA_ATOMIC_LEGACY_MSVC_ASM + #endif + #endif + #elif (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7))) || defined(__clang__) + #define MA_ATOMIC_MODERN_GCC + #else + #if defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 1)) + #define MA_ATOMIC_LEGACY_GCC + #else + #define MA_ATOMIC_LEGACY_GCC_ASM + #endif + #endif +#endif +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) + #include + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + #define MA_ATOMIC_MSVC_ARM_INTRINSIC_NORETURN(dst, src, order, intrin, ma_atomicType, msvcType) \ + switch (order) \ + { \ + case ma_atomic_memory_order_relaxed: \ + { \ + intrin##_nf((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_consume: \ + case ma_atomic_memory_order_acquire: \ + { \ + intrin##_acq((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_release: \ + { \ + intrin##_rel((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_acq_rel: \ + case ma_atomic_memory_order_seq_cst: \ + default: \ + { \ + intrin((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + } #define MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, intrin, ma_atomicType, msvcType) \ ma_atomicType result; \ switch (order) \ @@ -14138,720 +14303,1501 @@ typedef int ma_atomic_memory_order; } break; \ } \ return result; - #define MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, expected, desired, order, intrin, ma_atomicType, msvcType) \ + typedef 
ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, 1, order, _InterlockedExchange, ma_atomic_flag, long); + } + #else + { + (void)order; + return (ma_atomic_flag)_InterlockedExchange((volatile long*)dst, (long)1); + } + #endif + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_NORETURN(dst, 0, order, _InterlockedExchange, ma_atomic_flag, long); + } + #else + { + (void)order; + _InterlockedExchange((volatile long*)dst, (long)0); + } + #endif + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + (void)order; + return (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, 0, 0); + } +#endif +#if defined(MA_ATOMIC_LEGACY_MSVC_ASM) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result = 0; + (void)order; + __asm { + mov ecx, dst + mov eax, 1 + xchg [ecx], eax + mov result, eax + } + return result; + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov dword ptr [esi], 0 + } + } else { + __asm { + mov esi, dst + mov eax, 0 + xchg [esi], eax + } + } + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, [esi] + mov result, eax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov eax, [esi] + lock add dword ptr [esp], 0 + mov result, eax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov eax, [esi] + mov result, eax + lock add dword ptr [esp], 0 + } + } + return result; + } +#endif +#if defined(MA_ATOMIC_MODERN_GCC) + #define ma_atomic_memory_order_relaxed __ATOMIC_RELAXED + #define ma_atomic_memory_order_consume __ATOMIC_CONSUME + #define ma_atomic_memory_order_acquire __ATOMIC_ACQUIRE + #define ma_atomic_memory_order_release __ATOMIC_RELEASE + #define ma_atomic_memory_order_acq_rel __ATOMIC_ACQ_REL + #define ma_atomic_memory_order_seq_cst __ATOMIC_SEQ_CST + typedef ma_uint32 ma_atomic_flag; + #define ma_atomic_flag_test_and_set_explicit(dst, order) __atomic_exchange_n(dst, 1, order) + #define ma_atomic_flag_clear_explicit(dst, order) __atomic_store_n(dst, 0, order) + #define ma_atomic_flag_load_explicit(dst, order) __atomic_load_n(dst, order) +#endif +#if defined(MA_ATOMIC_LEGACY_GCC) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag 
ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, 1); + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order > ma_atomic_memory_order_release) { + __sync_synchronize(); + } + __sync_lock_release(dst); + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + (void)order; + return __sync_val_compare_and_swap((ma_atomic_flag*)dst, 0, 0); + } +#endif +#if defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + #if defined(MA_X86) + #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addl $0, (%%esp)" ::: "memory") + #elif defined(MA_X64) + #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addq $0, (%%rsp)" ::: "memory") + #else + #error Unsupported architecture. + #endif + #define MA_ATOMIC_XCHG_GCC_X86(instructionSizeSuffix, result, dst, src) \ + __asm__ __volatile__( \ + "xchg"instructionSizeSuffix" %0, %1" \ + : "=r"(result), \ + "=m"(*dst) \ + : "0"(src), \ + "m"(*dst) \ + : "memory" \ + ) + #define MA_ATOMIC_LOAD_RELAXED_GCC_X86(instructionSizeSuffix, result, dst) \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + ) + #define MA_ATOMIC_LOAD_RELEASE_GCC_X86(instructionSizeSuffix, result, dst) \ + ma_atomic_thread_fence(ma_atomic_memory_order_release); \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + : "memory" \ + ) + #define MA_ATOMIC_LOAD_SEQ_CST_GCC_X86(instructionSizeSuffix, result, dst) \ + ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst); \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + : "memory" \ + ); \ + ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst) + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result; + #if defined(MA_X86) || defined(MA_X64) + { + (void)order; + MA_ATOMIC_XCHG_GCC_X86("l", result, dst, 1); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__( + "movl $0, %0" + : "=m"(*dst) + ); + } else if (order == ma_atomic_memory_order_release) { + __asm__ __volatile__( + "movl $0, %0" + : "=m"(*dst) + : + : "memory" + ); + } else { + ma_atomic_flag tmp = 0; + __asm__ __volatile__( + "xchgl %0, %1" + : "=r"(tmp), + "=m"(*dst) + : "0"(tmp), + "m"(*dst) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. 
+ } + #endif + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_atomic_flag result; + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("l", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("l", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("l", result, dst); + } + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } +#endif +#define ma_atomic_flag_test_and_set(dst) ma_atomic_flag_test_and_set_explicit(dst, ma_atomic_memory_order_acquire) +#define ma_atomic_flag_clear(dst) ma_atomic_flag_clear_explicit(dst, ma_atomic_memory_order_release) +typedef ma_atomic_flag ma_atomic_spinlock; +static MA_INLINE void ma_atomic_spinlock_lock(volatile ma_atomic_spinlock* pSpinlock) +{ + for (;;) { + if (ma_atomic_flag_test_and_set_explicit(pSpinlock, ma_atomic_memory_order_acquire) == 0) { + break; + } + while (ma_atomic_flag_load_explicit(pSpinlock, ma_atomic_memory_order_relaxed) == 1) { + } + } +} +static MA_INLINE void ma_atomic_spinlock_unlock(volatile ma_atomic_spinlock* pSpinlock) +{ + ma_atomic_flag_clear_explicit(pSpinlock, ma_atomic_memory_order_release); +} +ma_atomic_spinlock ma_atomic_global_lock; +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC_ASM) || defined(MA_ATOMIC_LEGACY_GCC) || defined(MA_ATOMIC_LEGACY_GCC_ASM) + #if defined(MA_X64) || (defined(MA_X86) && ((defined(__GNUC__) && defined(__i486__)) || (defined(_M_IX86) && _M_IX86 >= 400))) + #if defined(MA_ATOMIC_LEGACY_MSVC) && defined(MA_X64) + #else + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #endif + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_X64) || (defined(MA_X86) && ((defined(__GNUC__) && defined(__i586__)) || (defined(_M_IX86) && _M_IX86 >= 500))) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #else + #endif + #else + #endif + #if defined(MA_ARM32) || defined(MA_ARM64) + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_ARM64) || defined(__ARM_ARCH_7A__) || defined(__ARM_ARCH_7R__) || defined(__ARM_ARCH_6K__) || defined(__ARM_ARCH_6Z__) || defined(__ARM_ARCH_6ZK__) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #endif + #endif + #if defined(MA_ATOMIC_PPC32) || defined(MA_ATOMIC_PPC64) + #if (defined(__GNUC__) && (__GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 7))) && !defined(__clang__) + #else + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #endif + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_ATOMIC_PPC64) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #endif + #endif + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_8(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_16(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_32(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_64(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + return 1; + #else + return 0; + #endif + } +#endif 
+#define MA_ATOMIC_COMPARE_AND_SWAP_LOCK(sizeInBits, dst, expected, replacement) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + if (result == expected) { \ + *dst = replacement; \ + } \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_LOAD_EXPLICIT_LOCK(sizeInBits, ptr, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *ptr; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_STORE_EXPLICIT_LOCK(sizeInBits, dst, src, order) \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + *dst = src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock) +#define MA_ATOMIC_STORE_EXPLICIT_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, src) != oldValue); \ + (void)order +#define MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + *dst = src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, src) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_ADD_LOCK(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + *dst += src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_FETCH_ADD_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = oldValue + src; \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_AND_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue & src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_OR_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue | src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_XOR_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue ^ src); \ + } while 
(ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) + #define MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, expected, replacement, order, intrin, ma_atomicType, msvcType) \ ma_atomicType result; \ switch (order) \ { \ case ma_atomic_memory_order_relaxed: \ { \ - result = (ma_atomicType)intrin##_nf((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_nf((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_consume: \ case ma_atomic_memory_order_acquire: \ { \ - result = (ma_atomicType)intrin##_acq((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_acq((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_release: \ { \ - result = (ma_atomicType)intrin##_rel((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_rel((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_acq_rel: \ case ma_atomic_memory_order_seq_cst: \ default: \ { \ - result = (ma_atomicType)intrin((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ } \ return result; - #define ma_atomic_memory_order_relaxed 0 - #define ma_atomic_memory_order_consume 1 - #define ma_atomic_memory_order_acquire 2 - #define ma_atomic_memory_order_release 3 - #define ma_atomic_memory_order_acq_rel 4 - #define ma_atomic_memory_order_seq_cst 5 - #if _MSC_VER < 1600 && defined(MA_X86) - #define MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY - #endif - #if _MSC_VER < 1600 - #undef MA_ATOMIC_HAS_8 - #undef MA_ATOMIC_HAS_16 - #endif - #if !defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #include - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) - { - ma_uint8 result = 0; - __asm { - mov ecx, dst - mov al, expected - mov dl, desired - lock cmpxchg [ecx], dl - mov result, al - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) - { - ma_uint16 result = 0; - __asm { - mov ecx, dst - mov ax, expected - mov dx, desired - lock cmpxchg [ecx], dx - mov result, ax - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) - { - ma_uint32 result = 0; - __asm { - mov ecx, dst - mov eax, expected - mov edx, desired - lock cmpxchg [ecx], edx - mov result, eax - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) - { - ma_uint32 resultEAX = 0; - ma_uint32 resultEDX = 0; - __asm { - mov esi, dst - mov eax, dword ptr expected - mov edx, dword ptr expected + 4 - mov ebx, dword ptr desired - mov ecx, dword ptr desired + 4 - lock cmpxchg8b qword ptr [esi] - mov resultEAX, eax - mov resultEDX, edx - } - return 
((ma_uint64)resultEDX << 32) | resultEAX; - } - #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + #define ma_atomic_compare_and_swap_8( dst, expected, replacement) (ma_uint8 )_InterlockedCompareExchange8((volatile char*)dst, (char)replacement, (char)expected) #else - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_compare_and_swap_8( dst, expected, desired) (ma_uint8 )_InterlockedCompareExchange8((volatile char*)dst, (char)desired, (char)expected) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_compare_and_swap_16(dst, expected, desired) (ma_uint16)_InterlockedCompareExchange16((volatile short*)dst, (short)desired, (short)expected) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_compare_and_swap_32(dst, expected, desired) (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, (long)desired, (long)expected) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_compare_and_swap_64(dst, expected, desired) (ma_uint64)_InterlockedCompareExchange64((volatile ma_int64*)dst, (ma_int64)desired, (ma_int64)expected) - #endif + static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - ma_uint8 result = 0; - (void)order; - __asm { - mov ecx, dst - mov al, src - lock xchg [ecx], al - mov result, al - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - ma_uint16 result = 0; - (void)order; - __asm { - mov ecx, dst - mov ax, src - lock xchg [ecx], ax - mov result, ax - } - return result; - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 result = 0; - (void)order; - __asm { - mov ecx, dst - mov eax, src - lock xchg [ecx], eax - mov result, eax - } - return result; - } - #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + #define ma_atomic_compare_and_swap_16(dst, expected, replacement) (ma_uint16)_InterlockedCompareExchange16((volatile short*)dst, (short)replacement, (short)expected) #else - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { + static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + #define ma_atomic_compare_and_swap_32(dst, expected, replacement) (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, (long)replacement, (long)expected) + #else + static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + #define ma_atomic_compare_and_swap_64(dst, expected, replacement) (ma_uint64)_InterlockedCompareExchange64((volatile ma_int64*)dst, (ma_int64)replacement, 
(ma_int64)expected) + #else + static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange8, ma_uint8, char); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange8, ma_uint8, char); + } #else + { + (void)order; + return ma_atomic_compare_and_swap_8((volatile ma_uint8*)ptr, 0, 0); + } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, ptr, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange16, ma_uint16, short); + } + #else + { + (void)order; + return ma_atomic_compare_and_swap_16((volatile ma_uint16*)ptr, 0, 0); + } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, ptr, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange, ma_uint32, long); + } + #else + { + (void)order; + return ma_atomic_compare_and_swap_32((volatile ma_uint32*)ptr, 0, 0); + } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, ptr, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange64, ma_uint64, long long); + } + #else + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)ptr, 0, 0); + } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, ptr, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange8, ma_uint8, char); + } + #else + { (void)order; return (ma_uint8)_InterlockedExchange8((volatile char*)dst, (char)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange16, ma_uint16, short); + } #else + { (void)order; return (ma_uint16)_InterlockedExchange16((volatile short*)dst, (short)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } #endif - #if 
defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange, ma_uint32, long); + } #else + { (void)order; return (ma_uint32)_InterlockedExchange((volatile long*)dst, (long)src); - #endif } - #endif - #if defined(MA_ATOMIC_HAS_64) && defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange64, ma_uint64, long long); - #else - (void)order; - return (ma_uint64)_InterlockedExchange64((volatile long long*)dst, (long long)src); #endif - } - #else - #endif - #endif - #if defined(MA_ATOMIC_HAS_64) && !defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - do { - oldValue = *dst; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; } - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - ma_uint8 result = 0; - (void)order; - __asm { - mov ecx, dst - mov al, src - lock xadd [ecx], al - mov result, al - } - return result; - } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + #if defined(MA_32BIT) { - ma_uint16 result = 0; - (void)order; - __asm { - mov ecx, dst - mov ax, src - lock xadd [ecx], ax - mov result, ax - } - return result; + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 result = 0; - (void)order; - __asm { - mov ecx, dst - mov eax, src - lock xadd [ecx], eax - mov result, eax - } - return result; - } - #endif - #else - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd8, ma_uint8, char); #else + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange64, ma_uint64, long long); + } + #else + { + (void)order; + return (ma_uint64)_InterlockedExchange64((volatile long long*)dst, (long long)src); + } + #endif + } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, 
ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd8, ma_uint8, char); + } + #else + { (void)order; return (ma_uint8)_InterlockedExchangeAdd8((volatile char*)dst, (char)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd16, ma_uint16, short); + } #else + { (void)order; return (ma_uint16)_InterlockedExchangeAdd16((volatile short*)dst, (short)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd, ma_uint32, long); + } #else + { (void)order; return (ma_uint32)_InterlockedExchangeAdd((volatile long*)dst, (long)src); - #endif } - #endif - #if defined(MA_ATOMIC_HAS_64) && defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd64, ma_uint64, long long); - #else - (void)order; - return (ma_uint64)_InterlockedExchangeAdd64((volatile long long*)dst, (long long)src); #endif - } + } #else - #endif - #endif - #if defined(MA_ATOMIC_HAS_64) && !defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue + src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + #if defined(MA_32BIT) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } + #else + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd64, ma_uint64, long long); + } + #else + { + (void)order; + return (ma_uint64)_InterlockedExchangeAdd64((volatile long long*)dst, (long long)src); + } + #endif + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_8(dst, (ma_uint8)(-(ma_int8)src), order); + } + static MA_INLINE ma_uint16 __stdcall 
ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_16(dst, (ma_uint16)(-(ma_int16)src), order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_32(dst, (ma_uint32)(-(ma_int32)src), order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_64(dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd8, ma_uint8, char); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd16, ma_uint16, short); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd, ma_uint32, long); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd64, ma_uint64, long long); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr8, ma_uint8, char); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr16, ma_uint16, short); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr, ma_uint32, long); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr64, ma_uint64, long long); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, 
_InterlockedXor8, ma_uint8, char); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor16, ma_uint16, short); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor, ma_uint32, long); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor64, ma_uint64, long long); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + #endif + } + #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) + #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) + #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) + #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) + #if defined(MA_X64) + #define ma_atomic_thread_fence(order) __faststorefence(), (void)order + #elif defined(MA_ARM64) + #define ma_atomic_thread_fence(order) __dmb(_ARM64_BARRIER_ISH), (void)order + #else + static MA_INLINE void ma_atomic_thread_fence(ma_atomic_memory_order order) + { + volatile ma_uint32 barrier = 0; + ma_atomic_fetch_add_explicit_32(&barrier, 0, order); } #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - static MA_INLINE void __stdcall ma_atomic_thread_fence(ma_atomic_memory_order order) + #define ma_atomic_signal_fence(order) _ReadWriteBarrier(), (void)order +#endif +#if defined(MA_ATOMIC_LEGACY_MSVC_ASM) + static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) { + ma_uint8 result = 0; + __asm { + mov ecx, dst + mov al, expected + mov dl, replacement + lock cmpxchg [ecx], dl + mov result, al + } + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + __asm { + mov ecx, dst + mov ax, expected + mov dx, replacement + lock cmpxchg [ecx], dx + mov result, ax + } + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + __asm { + mov ecx, dst + mov eax, expected + mov edx, replacement + lock cmpxchg [ecx], edx + mov result, eax + } + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint64 
__stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + ma_uint32 resultEAX = 0; + ma_uint32 resultEDX = 0; + __asm { + mov esi, dst + mov eax, dword ptr expected + mov edx, dword ptr expected + 4 + mov ebx, dword ptr replacement + mov ecx, dword ptr replacement + 4 + lock cmpxchg8b qword ptr [esi] + mov resultEAX, eax + mov resultEDX, edx + } + return ((ma_uint64)resultEDX << 32) | resultEAX; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov al, [esi] + mov result, al + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov al, [esi] + lock add dword ptr [esp], 0 + mov result, al + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov al, [esi] + mov result, al + lock add dword ptr [esp], 0 + } + } + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, dst, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov ax, [esi] + mov result, ax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov ax, [esi] + lock add dword ptr [esp], 0 + mov result, ax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov ax, [esi] + mov result, ax + lock add dword ptr [esp], 0 + } + } + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, dst, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, [esi] + mov result, eax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov eax, [esi] + lock add dword ptr [esp], 0 + mov result, eax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov eax, [esi] + mov result, eax + lock add dword ptr [esp], 0 + } + } + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, dst, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* dst, ma_atomic_memory_order order) + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, 0, 0); + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov al, src + mov [esi], al + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + __asm { + mov esi, dst + mov al, src + xchg [esi], al + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov ax, src + mov [esi], ax + } 
+ } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + __asm { + mov esi, dst + mov ax, src + xchg [esi], ax + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, src + mov [esi], eax + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + __asm { + mov esi, dst + mov eax, src + xchg [esi], eax + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + MA_ATOMIC_STORE_EXPLICIT_CAS(64, dst, src, order); + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; (void)order; __asm { - lock add [esp], 0 + mov ecx, dst + mov al, src + lock xchg [ecx], al + mov result, al } + return result; } - #else - #if defined(MA_X64) - #define ma_atomic_thread_fence(order) __faststorefence(), (void)order - #elif defined(MA_ARM64) - #define ma_atomic_thread_fence(order) __dmb(_ARM64_BARRIER_ISH), (void)order #else - static MA_INLINE void ma_atomic_thread_fence(ma_atomic_memory_order order) - { - volatile ma_uint32 barrier = 0; - ma_atomic_fetch_add_explicit_32(&barrier, 0, order); + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + (void)order; + __asm { + mov ecx, dst + mov ax, src + lock xchg [ecx], ax + mov result, ax } - #endif - #endif - #define ma_atomic_compiler_fence() ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst) - #define ma_atomic_signal_fence(order) ma_atomic_thread_fence(order) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange8, ma_uint8, char); + return result; + } #else - (void)order; - return ma_atomic_compare_and_swap_8((volatile ma_uint8*)ptr, 0, 0); - #endif - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange16, ma_uint16, short); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + (void)order; + __asm { + mov ecx, dst + mov eax, src + xchg [ecx], eax + mov result, eax + } + return result; + } #else - (void)order; - return ma_atomic_compare_and_swap_16((volatile ma_uint16*)ptr, 0, 0); - #endif - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 
ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange, ma_uint32, long); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); + } #else - (void)order; - return ma_atomic_compare_and_swap_32((volatile ma_uint32*)ptr, 0, 0); - #endif - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange64, ma_uint64, long long); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + (void)order; + __asm { + mov ecx, dst + mov al, src + lock xadd [ecx], al + mov result, al + } + return result; + } #else - (void)order; - return ma_atomic_compare_and_swap_64((volatile ma_uint64*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } #endif - } - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue - src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); + ma_uint16 result = 0; (void)order; - return oldValue; + __asm { + mov ecx, dst + mov ax, src + lock xadd [ecx], ax + mov result, ax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue - src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while 
(ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd8, ma_uint8, char); #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue & src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd16, ma_uint16, short); - #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue & src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd, ma_uint32, long); - #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); + ma_uint32 result = 0; (void)order; - return oldValue; - #endif + __asm { + mov ecx, dst + mov eax, src + lock xadd [ecx], eax + mov result, eax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd64, ma_uint64, long long); - #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor8, ma_uint8, char); - #else - ma_uint8 
oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor16, ma_uint16, short); - #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor, ma_uint32, long); - #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); + ma_uint8 result = 0; (void)order; - return oldValue; - #endif + __asm { + mov ecx, dst + mov al, src + neg al + lock xadd [ecx], al + mov result, al + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor64, ma_uint64, long long); - #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, (ma_uint8)(-(ma_int8)src), order); } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr8, ma_uint8, char); - #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue | src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); + ma_uint16 result = 0; (void)order; - return oldValue; - #endif + __asm { + mov ecx, dst + mov ax, src + neg ax + lock xadd [ecx], ax + mov result, ax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr16, ma_uint16, short); - #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = 
(ma_uint16)(oldValue | src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, (ma_uint16)(-(ma_int16)src), order); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr, ma_uint32, long); - #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); + ma_uint32 result = 0; (void)order; - return oldValue; - #endif + __asm { + mov ecx, dst + mov eax, src + neg eax + lock xadd [ecx], eax + mov result, eax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr64, ma_uint64, long long); - #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, (ma_uint32)(-(ma_int32)src), order); } - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_test_and_set_explicit_8( dst, order) ma_atomic_exchange_explicit_8 (dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_test_and_set_explicit_16(dst, order) ma_atomic_exchange_explicit_16(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_test_and_set_explicit_32(dst, order) ma_atomic_exchange_explicit_32(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_test_and_set_explicit_64(dst, order) ma_atomic_exchange_explicit_64(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_clear_explicit_8( dst, order) ma_atomic_store_explicit_8 (dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_clear_explicit_16(dst, order) ma_atomic_store_explicit_16(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_clear_explicit_32(dst, order) ma_atomic_store_explicit_32(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_clear_explicit_64(dst, order) ma_atomic_store_explicit_64(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_8(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_8(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) - #else - typedef ma_uint32 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_32(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_32(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_32(ptr, order) - #endif -#elif defined(__clang__) || 
(defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7))) + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + static MA_INLINE void __stdcall ma_atomic_thread_fence(ma_atomic_memory_order order) + { + (void)order; + __asm { + lock add dword ptr [esp], 0 + } + } + #define ma_atomic_signal_fence(order) __asm {}; (void)order +#endif +#if defined(MA_ATOMIC_MODERN_GCC) #define MA_ATOMIC_HAS_NATIVE_COMPARE_EXCHANGE - #define MA_ATOMIC_HAS_NATIVE_IS_LOCK_FREE - #define ma_atomic_memory_order_relaxed __ATOMIC_RELAXED - #define ma_atomic_memory_order_consume __ATOMIC_CONSUME - #define ma_atomic_memory_order_acquire __ATOMIC_ACQUIRE - #define ma_atomic_memory_order_release __ATOMIC_RELEASE - #define ma_atomic_memory_order_acq_rel __ATOMIC_ACQ_REL - #define ma_atomic_memory_order_seq_cst __ATOMIC_SEQ_CST - #define ma_atomic_compiler_fence() __asm__ __volatile__("":::"memory") #define ma_atomic_thread_fence(order) __atomic_thread_fence(order) #define ma_atomic_signal_fence(order) 
__atomic_signal_fence(order) #define ma_atomic_is_lock_free_8(ptr) __atomic_is_lock_free(1, ptr) #define ma_atomic_is_lock_free_16(ptr) __atomic_is_lock_free(2, ptr) #define ma_atomic_is_lock_free_32(ptr) __atomic_is_lock_free(4, ptr) #define ma_atomic_is_lock_free_64(ptr) __atomic_is_lock_free(8, ptr) - #define ma_atomic_test_and_set_explicit_8( dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_16(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_32(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_64(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_clear_explicit_8( dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_16(dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_32(dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_64(dst, order) __atomic_store_n(dst, 0, order) #define ma_atomic_store_explicit_8( dst, src, order) __atomic_store_n(dst, src, order) #define ma_atomic_store_explicit_16(dst, src, order) __atomic_store_n(dst, src, order) #define ma_atomic_store_explicit_32(dst, src, order) __atomic_store_n(dst, src, order) @@ -14864,14 +15810,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_explicit_16(dst, src, order) __atomic_exchange_n(dst, src, order) #define ma_atomic_exchange_explicit_32(dst, src, order) __atomic_exchange_n(dst, src, order) #define ma_atomic_exchange_explicit_64(dst, src, order) __atomic_exchange_n(dst, src, order) - #define ma_atomic_compare_exchange_strong_explicit_8( dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_8( dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define 
ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) #define ma_atomic_fetch_add_explicit_8( dst, src, order) __atomic_fetch_add(dst, src, order) #define ma_atomic_fetch_add_explicit_16(dst, src, order) __atomic_fetch_add(dst, src, order) #define ma_atomic_fetch_add_explicit_32(dst, src, order) __atomic_fetch_add(dst, src, order) @@ -14892,19 +15838,19 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_explicit_16(dst, src, order) __atomic_fetch_and(dst, src, order) #define ma_atomic_fetch_and_explicit_32(dst, src, order) __atomic_fetch_and(dst, src, order) #define ma_atomic_fetch_and_explicit_64(dst, src, order) __atomic_fetch_and(dst, src, order) - static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } - static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } - static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } #if defined(__clang__) @@ -14913,636 +15859,1134 @@ typedef int ma_atomic_memory_order; #pragma clang diagnostic ignored "-Watomic-alignment" #endif #endif - static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 
expected, ma_uint64 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } #if defined(__clang__) #pragma clang diagnostic pop #endif - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(dst, order) (ma_bool32)__atomic_test_and_set(dst, order) - #define ma_atomic_flag_clear_explicit(dst, order) __atomic_clear(dst, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) -#else - #define ma_atomic_memory_order_relaxed 1 - #define ma_atomic_memory_order_consume 2 - #define ma_atomic_memory_order_acquire 3 - #define ma_atomic_memory_order_release 4 - #define ma_atomic_memory_order_acq_rel 5 - #define ma_atomic_memory_order_seq_cst 6 - #define ma_atomic_compiler_fence() __asm__ __volatile__("":::"memory") - #if defined(__GNUC__) +#endif +#if defined(MA_ATOMIC_LEGACY_GCC) || defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define ma_atomic_signal_fence(order) __asm__ __volatile__("":::"memory") + #if defined(MA_ATOMIC_LEGACY_GCC) #define ma_atomic_thread_fence(order) __sync_synchronize(), (void)order - static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) { - if (order > ma_atomic_memory_order_acquire) { - __sync_synchronize(); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + return __sync_val_compare_and_swap(dst, expected, replacement); } - return __sync_lock_test_and_set(dst, src); + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) { - ma_uint16 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) { - ma_uint32 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) { - ma_uint64 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + return 
__sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return ma_atomic_compare_and_swap_8((ma_uint8*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, ptr, order); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return ma_atomic_compare_and_swap_16((ma_uint16*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, ptr, order); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return ma_atomic_compare_and_swap_32((ma_uint32*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, ptr, order); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); - } - static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_sub(dst, src); - } - static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_sub(dst, src); - } - static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_sub(dst, src); - } - static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_sub(dst, src); - } - static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_or(dst, src); - } - static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_or(dst, src); - } - static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_or(dst, src); - } - static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_or(dst, src); - } - static MA_INLINE ma_uint8 
ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_xor(dst, src); - } - static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_xor(dst, src); - } - static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_xor(dst, src); - } - static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_xor(dst, src); - } - static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_and(dst, src); - } - static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_and(dst, src); - } - static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_and(dst, src); - } - static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - (void)order; - return __sync_fetch_and_and(dst, src); - } - #define ma_atomic_compare_and_swap_8( dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_16(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_32(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_64(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #else - #if defined(MA_X86) - #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addl $0, (%%esp)" ::: "memory", "cc") - #elif defined(MA_X64) - #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addq $0, (%%rsp)" ::: "memory", "cc") - #else - #error Unsupported architecture. Please submit a feature request. - #endif - static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) - { - ma_uint8 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; - } - static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) - { - ma_uint16 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; - } - static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) - { - ma_uint32 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. 
- #endif - return result; - } - static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) - { - volatile ma_uint64 result; - #if defined(MA_X86) - ma_uint32 resultEAX; - ma_uint32 resultEDX; - __asm__ __volatile__("push %%ebx; xchg %5, %%ebx; lock; cmpxchg8b %0; pop %%ebx" : "+m"(*dst), "=a"(resultEAX), "=d"(resultEDX) : "a"(expected & 0xFFFFFFFF), "d"(expected >> 32), "r"(desired & 0xFFFFFFFF), "c"(desired >> 32) : "cc"); - result = ((ma_uint64)resultEDX << 32) | resultEAX; - #elif defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return ma_atomic_compare_and_swap_64((ma_uint64*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, ptr, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 result = 0; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 result = 0; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 result; - (void)order; - #if defined(MA_X86) - do { - result = *dst; - } while (ma_atomic_compare_and_swap_64(dst, result, src) != result); - #elif defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. 
- #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif } + #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) + #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) + #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) + #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. 
- #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - #if defined(MA_X86) - ma_uint64 oldValue; - ma_uint64 newValue; - (void)order; - do { - oldValue = *dst; - newValue = oldValue + src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - return oldValue; - #elif defined(MA_X64) - ma_uint64 result; - (void)order; - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - return result; - #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue - src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, (ma_uint8)(-(ma_int8)src), order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue - src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, (ma_uint16)(-(ma_int16)src), order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, (ma_uint32)(-(ma_int32)src), order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, (ma_uint64)(-(ma_int64)src), order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue & src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + 
return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue & src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - } - static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue | src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return 
__sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue | src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + #endif } + static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + #endif + } + #elif defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define MA_ATOMIC_CMPXCHG_GCC_X86(instructionSizeSuffix, result, dst, expected, replacement) \ + __asm__ __volatile__( \ + "lock; cmpxchg"instructionSizeSuffix" %2, %1" \ + : "=a"(result), \ + "=m"(*dst) \ + : "r"(replacement), \ + "0"(expected), \ + "m"(*dst) \ + : "cc", "memory") + #define MA_ATOMIC_XADD_GCC_X86(instructionSizeSuffix, result, dst, src) \ + __asm__ __volatile__( \ + "lock; xadd"instructionSizeSuffix" %0, %1" \ + : "=a"(result), \ + "=m"(*dst) \ + : "0"(src), \ + "m"(*dst) \ + : "cc", "memory") + static MA_INLINE ma_uint8 
ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("b", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("w", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("l", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + #if defined(MA_X86) + { + ma_uint32 resultEAX; + ma_uint32 resultEDX; + __asm__ __volatile__( + "pushl %%ebx\n" + "movl %4, %%ebx\n" + "lock cmpxchg8b (%%edi)\n" + "popl %%ebx\n" + : "=a"(resultEAX), + "=d"(resultEDX) + : "a"((ma_uint32)(expected & 0xFFFFFFFF)), + "d"((ma_uint32)(expected >> 32)), + "r"((ma_uint32)(replacement & 0xFFFFFFFF)), + "c"((ma_uint32)(replacement >> 32)), + "D"(dst) + : "memory", "cc"); + result = ((ma_uint64)resultEDX << 32) | resultEAX; + } + #elif defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("q", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("b", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("b", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("b", result, dst); + } + } + #else + { + #error Unsupported architecture. 
+ } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, dst, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("w", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("w", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("w", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, dst, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("l", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("l", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("l", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, dst, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + #if defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("q", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("q", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("q", result, dst); + } + } + #elif defined(MA_X86) + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, 0, 0); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, dst, order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("b", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("w", result, dst, src); + } + #else + { + #error Unsupported architecture. 
+ } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("l", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + (void)order; + #if defined(MA_X86) + { + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); + } + #elif defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("q", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE void ma_atomic_store_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movb %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgb %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE void ma_atomic_store_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movw %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgw %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE void ma_atomic_store_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movl %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgl %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. 
+ } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE void ma_atomic_store_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movq %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgq %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_CAS(64, dst, src, order); + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint8 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("b", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint16 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("w", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint32 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("l", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } + #elif defined(MA_X64) + { + ma_uint64 result; + MA_ATOMIC_XADD_GCC_X86("q", result, dst, src); + (void)order; + return result; + } + #else + { + #error Unsupported architecture. 
+ } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_8(dst, (ma_uint8)(-(ma_int8)src), order); + } + static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_16(dst, (ma_uint16)(-(ma_int16)src), order); + } + static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_32(dst, (ma_uint32)(-(ma_int32)src), order); + } + static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_64(dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + #else + #error Unsupported compiler. 
#endif - #define ma_atomic_signal_fence(order) ma_atomic_thread_fence(order) - static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_8((ma_uint8*)ptr, 0, 0); - } - static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_16((ma_uint16*)ptr, 0, 0); - } - static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_32((ma_uint32*)ptr, 0, 0); - } - static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_64((ma_uint64*)ptr, 0, 0); - } - #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) - #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) - #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) - #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) - #define ma_atomic_test_and_set_explicit_8( dst, order) ma_atomic_exchange_explicit_8 (dst, 1, order) - #define ma_atomic_test_and_set_explicit_16(dst, order) ma_atomic_exchange_explicit_16(dst, 1, order) - #define ma_atomic_test_and_set_explicit_32(dst, order) ma_atomic_exchange_explicit_32(dst, 1, order) - #define ma_atomic_test_and_set_explicit_64(dst, order) ma_atomic_exchange_explicit_64(dst, 1, order) - #define ma_atomic_clear_explicit_8( dst, order) ma_atomic_store_explicit_8 (dst, 0, order) - #define ma_atomic_clear_explicit_16(dst, order) ma_atomic_store_explicit_16(dst, 0, order) - #define ma_atomic_clear_explicit_32(dst, order) ma_atomic_store_explicit_32(dst, 0, order) - #define ma_atomic_clear_explicit_64(dst, order) ma_atomic_store_explicit_64(dst, 0, order) - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_8(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_8(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) #endif #if !defined(MA_ATOMIC_HAS_NATIVE_COMPARE_EXCHANGE) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_8(volatile ma_uint8* dst, ma_uint8* expected, ma_uint8 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint8 expectedValue; - ma_uint8 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_8(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_8(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_8(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_16(volatile ma_uint16* dst, ma_uint16* expected, ma_uint16 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint16 expectedValue; - ma_uint16 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_16(expected, 
ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_16(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_16(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_32(volatile ma_uint32* dst, ma_uint32* expected, ma_uint32 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint32 expectedValue; - ma_uint32 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_32(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_32(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_32(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_64(volatile ma_uint64* dst, volatile ma_uint64* expected, ma_uint64 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint64 expectedValue; - ma_uint64 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_64(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_64(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_64(expected, result, failureOrder); - return 0; - } - } - #endif - #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8 (dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, successOrder, failureOrder) -#endif -#if !defined(MA_ATOMIC_HAS_NATIVE_IS_LOCK_FREE) - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_8(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_8(volatile ma_uint8* dst, ma_uint8* expected, ma_uint8 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - return 1; - } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_16(volatile void* ptr) - { - (void)ptr; - return 1; - } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_32(volatile void* ptr) - { - (void)ptr; - return 1; - } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_64(volatile void* ptr) - { - (void)ptr; - #if defined(MA_64BIT) - return 1; - #else - #if defined(MA_X86) || defined(MA_X64) + ma_uint8 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_8(dst, *expected, replacement); + if (result == *expected) { return 1; - #else + } else { + *expected = result; return 0; - #endif - #endif + } } + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_16(volatile ma_uint16* dst, ma_uint16* expected, ma_uint16 replacement, 
ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + { + ma_uint16 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_16(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } + } + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_32(volatile ma_uint32* dst, ma_uint32* expected, ma_uint32 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + { + ma_uint32 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_32(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } + } + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_64(volatile ma_uint64* dst, volatile ma_uint64* expected, ma_uint64 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + { + ma_uint64 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_64(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } + } + #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8 (dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, successOrder, failureOrder) #endif #if defined(MA_64BIT) static MA_INLINE ma_bool32 ma_atomic_is_lock_free_ptr(volatile void** ptr) @@ -15561,17 +17005,17 @@ typedef int ma_atomic_memory_order; { return (void*)ma_atomic_exchange_explicit_64((volatile ma_uint64*)dst, (ma_uint64)src, order); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, 
(ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder); } - static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* desired) + static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* replacement) { - return (void*)ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, (ma_uint64)expected, (ma_uint64)desired); + return (void*)ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, (ma_uint64)expected, (ma_uint64)replacement); } #elif defined(MA_32BIT) static MA_INLINE ma_bool32 ma_atomic_is_lock_free_ptr(volatile void** ptr) @@ -15590,36 +17034,26 @@ typedef int ma_atomic_memory_order; { return (void*)ma_atomic_exchange_explicit_32((volatile ma_uint32*)dst, (ma_uint32)src, order); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder); } - static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* desired) + static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* replacement) { - return (void*)ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, (ma_uint32)expected, (ma_uint32)desired); + return (void*)ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, (ma_uint32)expected, (ma_uint32)replacement); } #else #error Unsupported architecture. 
#endif -#define ma_atomic_flag_test_and_set(ptr) ma_atomic_flag_test_and_set_explicit(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_flag_clear(ptr) ma_atomic_flag_clear_explicit(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_store_ptr(dst, src) ma_atomic_store_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_load_ptr(ptr) ma_atomic_load_explicit_ptr((volatile void**)ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_exchange_ptr(dst, src) ma_atomic_exchange_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_ptr(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_ptr((volatile void**)dst, (void**)expected, (void*)desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_ptr(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_ptr((volatile void**)dst, (void**)expected, (void*)desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_8( ptr) ma_atomic_test_and_set_explicit_8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_16(ptr) ma_atomic_test_and_set_explicit_16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_32(ptr) ma_atomic_test_and_set_explicit_32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_64(ptr) ma_atomic_test_and_set_explicit_64(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_8( ptr) ma_atomic_clear_explicit_8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_16(ptr) ma_atomic_clear_explicit_16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_32(ptr) ma_atomic_clear_explicit_32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_64(ptr) ma_atomic_clear_explicit_64(ptr, ma_atomic_memory_order_seq_cst) +#define ma_atomic_store_ptr(dst, src) ma_atomic_store_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) +#define ma_atomic_load_ptr(ptr) ma_atomic_load_explicit_ptr((volatile void**)ptr, ma_atomic_memory_order_seq_cst) +#define ma_atomic_exchange_ptr(dst, src) ma_atomic_exchange_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_ptr(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_ptr((volatile void**)dst, (void**)expected, (void*)replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_ptr(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_ptr((volatile void**)dst, (void**)expected, (void*)replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_8( dst, src) ma_atomic_store_explicit_8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_16(dst, src) ma_atomic_store_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_32(dst, src) ma_atomic_store_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15632,14 +17066,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_16(dst, src) ma_atomic_exchange_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_32(dst, src) ma_atomic_exchange_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_64(dst, src) ma_atomic_exchange_explicit_64(dst, src, ma_atomic_memory_order_seq_cst) -#define 
ma_atomic_compare_exchange_strong_8( dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_16(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_8( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_16( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_32( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_64( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_8( dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_16(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_8( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_16( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_32( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_64( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_8( dst, src) ma_atomic_fetch_add_explicit_8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_16(dst, src) ma_atomic_fetch_add_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define 
ma_atomic_fetch_add_32(dst, src) ma_atomic_fetch_add_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15660,14 +17094,6 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_16(dst, src) ma_atomic_fetch_and_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_32(dst, src) ma_atomic_fetch_and_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_64(dst, src) ma_atomic_fetch_and_explicit_64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_explicit_i8( ptr, order) (ma_int8 )ma_atomic_test_and_set_explicit_8( (ma_uint8* )ptr, order) -#define ma_atomic_test_and_set_explicit_i16(ptr, order) (ma_int16)ma_atomic_test_and_set_explicit_16((ma_uint16*)ptr, order) -#define ma_atomic_test_and_set_explicit_i32(ptr, order) (ma_int32)ma_atomic_test_and_set_explicit_32((ma_uint32*)ptr, order) -#define ma_atomic_test_and_set_explicit_i64(ptr, order) (ma_int64)ma_atomic_test_and_set_explicit_64((ma_uint64*)ptr, order) -#define ma_atomic_clear_explicit_i8( ptr, order) ma_atomic_clear_explicit_8( (ma_uint8* )ptr, order) -#define ma_atomic_clear_explicit_i16(ptr, order) ma_atomic_clear_explicit_16((ma_uint16*)ptr, order) -#define ma_atomic_clear_explicit_i32(ptr, order) ma_atomic_clear_explicit_32((ma_uint32*)ptr, order) -#define ma_atomic_clear_explicit_i64(ptr, order) ma_atomic_clear_explicit_64((ma_uint64*)ptr, order) #define ma_atomic_store_explicit_i8( dst, src, order) ma_atomic_store_explicit_8( (ma_uint8* )dst, (ma_uint8 )src, order) #define ma_atomic_store_explicit_i16(dst, src, order) ma_atomic_store_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_store_explicit_i32(dst, src, order) ma_atomic_store_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) @@ -15680,14 +17106,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_explicit_i16(dst, src, order) (ma_int16)ma_atomic_exchange_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_exchange_explicit_i32(dst, src, order) (ma_int32)ma_atomic_exchange_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) #define ma_atomic_exchange_explicit_i64(dst, src, order) (ma_int64)ma_atomic_exchange_explicit_64((ma_uint64*)dst, (ma_uint64)src, order) -#define ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, desired, 
successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder) #define ma_atomic_fetch_add_explicit_i8( dst, src, order) (ma_int8 )ma_atomic_fetch_add_explicit_8( (ma_uint8* )dst, (ma_uint8 )src, order) #define ma_atomic_fetch_add_explicit_i16(dst, src, order) (ma_int16)ma_atomic_fetch_add_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_fetch_add_explicit_i32(dst, src, order) (ma_int32)ma_atomic_fetch_add_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) @@ -15708,14 +17134,6 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_explicit_i16(dst, src, order) (ma_int16)ma_atomic_fetch_and_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_fetch_and_explicit_i32(dst, src, order) (ma_int32)ma_atomic_fetch_and_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) #define ma_atomic_fetch_and_explicit_i64(dst, src, order) (ma_int64)ma_atomic_fetch_and_explicit_64((ma_uint64*)dst, (ma_uint64)src, order) -#define ma_atomic_test_and_set_i8( ptr) 
ma_atomic_test_and_set_explicit_i8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i16(ptr) ma_atomic_test_and_set_explicit_i16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i32(ptr) ma_atomic_test_and_set_explicit_i32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i64(ptr) ma_atomic_test_and_set_explicit_i64(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i8( ptr) ma_atomic_clear_explicit_i8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i16(ptr) ma_atomic_clear_explicit_i16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i32(ptr) ma_atomic_clear_explicit_i32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i64(ptr) ma_atomic_clear_explicit_i64(ptr, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i8( dst, src) ma_atomic_store_explicit_i8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i16(dst, src) ma_atomic_store_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i32(dst, src) ma_atomic_store_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15728,14 +17146,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_i16(dst, src) ma_atomic_exchange_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_i32(dst, src) ma_atomic_exchange_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_i64(dst, src) ma_atomic_exchange_explicit_i64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i8( dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i16(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i8( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i16(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i32(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i64(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i8( dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i16(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, replacement, 
ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i8( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i16(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i32(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i64(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i8( dst, src) ma_atomic_fetch_add_explicit_i8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i16(dst, src) ma_atomic_fetch_add_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i32(dst, src) ma_atomic_fetch_add_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15812,28 +17230,28 @@ static MA_INLINE double ma_atomic_exchange_explicit_f64(volatile double* dst, do r.i = ma_atomic_exchange_explicit_64((volatile ma_uint64*)dst, x.i, order); return r.f; } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f32(volatile float* dst, float* expected, float desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f32(volatile float* dst, float* expected, float replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if32 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f64(volatile double* dst, double* expected, double desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f64(volatile double* dst, double* expected, double replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if64 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f32(volatile float* dst, float* expected, float desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f32(volatile float* dst, float* expected, float replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if32 d; - d.f = desired; + d.f = replacement; 
return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f64(volatile double* dst, double* expected, double desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f64(volatile double* dst, double* expected, double replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if64 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, d.i, successOrder, failureOrder); } static MA_INLINE float ma_atomic_fetch_add_explicit_f32(volatile float* dst, float src, ma_atomic_memory_order order) @@ -15924,10 +17342,10 @@ static MA_INLINE double ma_atomic_fetch_and_explicit_f64(volatile double* dst, d #define ma_atomic_load_f64(ptr) (double)ma_atomic_load_explicit_f64(ptr, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_f32(dst, src) (float )ma_atomic_exchange_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_f64(dst, src) (double)ma_atomic_exchange_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_f32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_f32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_f64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_f64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_f32(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_f32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_f64(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_f64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_f32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_f32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_f64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_f64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_f32(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_f32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_f64(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_f64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_f32(dst, src) ma_atomic_fetch_add_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_f64(dst, src) ma_atomic_fetch_add_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_sub_f32(dst, src) ma_atomic_fetch_sub_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15938,39 +17356,24 @@ static MA_INLINE double ma_atomic_fetch_and_explicit_f64(volatile double* dst, d #define ma_atomic_fetch_xor_f64(dst, src) ma_atomic_fetch_xor_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) 
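A similar hedged sketch for the float wrappers above, showing a lock-free "keep the maximum" update. my_atomic_store_max_f32 is a hypothetical name, and ma_atomic_load_f32 is assumed to be available alongside the f64 variant listed here:

/* Hypothetical helper: publish a new peak value only if it exceeds the current one. */
static void my_atomic_store_max_f32(volatile float* dst, float value)
{
    float expected = ma_atomic_load_f32(dst);
    while (value > expected) {
        if (ma_atomic_compare_exchange_weak_f32(dst, &expected, value)) {
            break; /* The new maximum has been stored. */
        }
        /* CAS failed; expected now holds the latest value of *dst and the condition is re-checked. */
    }
}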
#define ma_atomic_fetch_and_f32(dst, src) ma_atomic_fetch_and_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_f64(dst, src) ma_atomic_fetch_and_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) -static MA_INLINE float ma_atomic_compare_and_swap_f32(volatile float* dst, float expected, float desired) +static MA_INLINE float ma_atomic_compare_and_swap_f32(volatile float* dst, float expected, float replacement) { ma_atomic_if32 r; ma_atomic_if32 e, d; e.f = expected; - d.f = desired; + d.f = replacement; r.i = ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, e.i, d.i); return r.f; } -static MA_INLINE double ma_atomic_compare_and_swap_f64(volatile double* dst, double expected, double desired) +static MA_INLINE double ma_atomic_compare_and_swap_f64(volatile double* dst, double expected, double replacement) { ma_atomic_if64 r; ma_atomic_if64 e, d; e.f = expected; - d.f = desired; + d.f = replacement; r.i = ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, e.i, d.i); return r.f; } -typedef ma_atomic_flag ma_atomic_spinlock; -static MA_INLINE void ma_atomic_spinlock_lock(volatile ma_atomic_spinlock* pSpinlock) -{ - for (;;) { - if (ma_atomic_flag_test_and_set_explicit(pSpinlock, ma_atomic_memory_order_acquire) == 0) { - break; - } - while (ma_atomic_flag_load_explicit(pSpinlock, ma_atomic_memory_order_relaxed) == 1) { - } - } -} -static MA_INLINE void ma_atomic_spinlock_unlock(volatile ma_atomic_spinlock* pSpinlock) -{ - ma_atomic_flag_clear_explicit(pSpinlock, ma_atomic_memory_order_release); -} #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) #pragma GCC diagnostic pop #endif @@ -16176,7 +17579,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority int result; pthread_attr_t* pAttr = NULL; -#if !defined(__EMSCRIPTEN__) && !defined(__3DS__) +#if !defined(MA_EMSCRIPTEN) && !defined(MA_3DS) && !defined(MA_SWITCH) /* Try setting the thread priority. It's not critical if anything fails here. */ pthread_attr_t attr; if (pthread_attr_init(&attr) == 0) { @@ -16208,9 +17611,18 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority } #endif - if (stackSize > 0) { - pthread_attr_setstacksize(&attr, stackSize); + #if defined(_POSIX_THREAD_ATTR_STACKSIZE) && _POSIX_THREAD_ATTR_STACKSIZE >= 0 + { + if (stackSize > 0) { + pthread_attr_setstacksize(&attr, stackSize); + } } + #else + { + (void)stackSize; /* Suppress unused parameter warning. */ + } + #endif + if (scheduler != -1) { int priorityMin = sched_get_priority_min(scheduler); @@ -16218,7 +17630,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority int priorityStep = (priorityMax - priorityMin) / 7; /* 7 = number of priorities supported by miniaudio. */ struct sched_param sched; - if (pthread_attr_getschedparam(&attr, &sched) == 0) { + if (priorityMin != -1 && priorityMax != -1 && pthread_attr_getschedparam(&attr, &sched) == 0) { if (priority == ma_thread_priority_idle) { sched.sched_priority = priorityMin; } else if (priority == ma_thread_priority_realtime) { @@ -16267,6 +17679,21 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority } if (result != 0) { + /* + There have been reports that attempting to create a realtime thread can sometimes fail. In this case, + fall back to a normal priority thread. 
+ + I'm including a compile-time option here to disable this functionality for those who have a hard + requirement on realtime threads and would rather an explicit failure. + */ + #ifndef MA_NO_PTHREAD_REALTIME_PRIORITY_FALLBACK + { + if(result == EPERM && priority == ma_thread_priority_realtime) { + return ma_thread_create__posix(pThread, ma_thread_priority_normal, stackSize, entryProc, pData); + } + } + #endif + return ma_result_from_errno(result); } @@ -16538,7 +17965,7 @@ static ma_result ma_event_signal__win32(ma_event* pEvent) static ma_result ma_semaphore_init__win32(int initialValue, ma_semaphore* pSemaphore) { - *pSemaphore = CreateSemaphoreW(NULL, (LONG)initialValue, LONG_MAX, NULL); + *pSemaphore = CreateSemaphore(NULL, (LONG)initialValue, LONG_MAX, NULL); if (*pSemaphore == NULL) { return ma_result_from_GetLastError(GetLastError()); } @@ -17432,10 +18859,12 @@ static MA_INLINE ma_uint16 ma_job_extract_slot(ma_uint64 toc) return (ma_uint16)(toc & 0x0000FFFF); } +#if 0 /* Currently unused, but might make use of this later. */ static MA_INLINE ma_uint16 ma_job_extract_code(ma_uint64 toc) { return (ma_uint16)((toc & 0xFFFF0000) >> 16); } +#endif static MA_INLINE ma_uint64 ma_job_toc_to_allocation(ma_uint64 toc) { @@ -17900,6 +19329,13 @@ MA_API ma_result ma_job_queue_next(ma_job_queue* pQueue, ma_job* pJob) Dynamic Linking *******************************************************************************/ +/* Disable run-time linking on certain backends and platforms. */ +#ifndef MA_NO_RUNTIME_LINKING + #if defined(MA_EMSCRIPTEN) || defined(MA_ORBIS) || defined(MA_PROSPERO) || defined(MA_SWITCH) || defined(MA_DOS) + #define MA_NO_RUNTIME_LINKING + #endif +#endif + #ifdef MA_POSIX /* No need for dlfcn.h if we're not using runtime linking. */ #ifndef MA_NO_RUNTIME_LINKING @@ -17909,104 +19345,124 @@ Dynamic Linking MA_API ma_handle ma_dlopen(ma_log* pLog, const char* filename) { -#ifndef MA_NO_RUNTIME_LINKING - ma_handle handle; + #ifndef MA_NO_RUNTIME_LINKING + { + ma_handle handle; - ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading library: %s\n", filename); + ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading library: %s\n", filename); - #ifdef MA_WIN32 - /* From MSDN: Desktop applications cannot use LoadPackagedLibrary; if a desktop application calls this function it fails with APPMODEL_ERROR_NO_PACKAGE.*/ - #if !defined(MA_WIN32_UWP) || !(defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) - handle = (ma_handle)LoadLibraryA(filename); + #ifdef MA_WIN32 + /* From MSDN: Desktop applications cannot use LoadPackagedLibrary; if a desktop application calls this function it fails with APPMODEL_ERROR_NO_PACKAGE.*/ + #if !defined(MA_WIN32_UWP) || !(defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) + handle = (ma_handle)LoadLibraryA(filename); + #else + /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */ + WCHAR filenameW[4096]; + if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { + handle = NULL; + } else { + handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); + } + #endif #else - /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... 
*/ - WCHAR filenameW[4096]; - if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { - handle = NULL; - } else { - handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); - } + handle = (ma_handle)dlopen(filename, RTLD_NOW); #endif - #else - handle = (ma_handle)dlopen(filename, RTLD_NOW); - #endif - /* - I'm not considering failure to load a library an error nor a warning because seamlessly falling through to a lower-priority - backend is a deliberate design choice. Instead I'm logging it as an informational message. - */ - if (handle == NULL) { - ma_log_postf(pLog, MA_LOG_LEVEL_INFO, "Failed to load library: %s\n", filename); + /* + I'm not considering failure to load a library an error nor a warning because seamlessly falling through to a lower-priority + backend is a deliberate design choice. Instead I'm logging it as an informational message. + */ + if (handle == NULL) { + ma_log_postf(pLog, MA_LOG_LEVEL_INFO, "Failed to load library: %s\n", filename); + } + + return handle; } - - return handle; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)filename; - return NULL; -#endif + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)filename; + return NULL; + } + #endif } MA_API void ma_dlclose(ma_log* pLog, ma_handle handle) { -#ifndef MA_NO_RUNTIME_LINKING - #ifdef MA_WIN32 - FreeLibrary((HMODULE)handle); - #else - /* Hack for Android bug (see https://github.com/android/ndk/issues/360). Calling dlclose() pre-API 28 may segfault. */ - #if !defined(MA_ANDROID) || (defined(__ANDROID_API__) && __ANDROID_API__ >= 28) + #ifndef MA_NO_RUNTIME_LINKING + { + #ifdef MA_WIN32 { - dlclose((void*)handle); + FreeLibrary((HMODULE)handle); } #else { - (void)handle; + /* Hack for Android bug (see https://github.com/android/ndk/issues/360). Calling dlclose() pre-API 28 may segfault. */ + #if !defined(MA_ANDROID) || (defined(__ANDROID_API__) && __ANDROID_API__ >= 28) + { + dlclose((void*)handle); + } + #else + { + (void)handle; + } + #endif } #endif - #endif - (void)pLog; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)handle; -#endif + (void)pLog; + } + #else + { + /* Runtime linking is disabled. 
*/ + (void)pLog; + (void)handle; + } + #endif } MA_API ma_proc ma_dlsym(ma_log* pLog, ma_handle handle, const char* symbol) { -#ifndef MA_NO_RUNTIME_LINKING - ma_proc proc; + #ifndef MA_NO_RUNTIME_LINKING + { + ma_proc proc; - ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading symbol: %s\n", symbol); + ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading symbol: %s\n", symbol); -#ifdef _WIN32 - proc = (ma_proc)GetProcAddress((HMODULE)handle, symbol); -#else -#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wpedantic" -#endif - proc = (ma_proc)dlsym((void*)handle, symbol); -#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) - #pragma GCC diagnostic pop -#endif -#endif + #ifdef _WIN32 + { + proc = (ma_proc)GetProcAddress((HMODULE)handle, symbol); + } + #else + { + #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" + #endif + proc = (ma_proc)dlsym((void*)handle, symbol); + #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) + #pragma GCC diagnostic pop + #endif + } + #endif - if (proc == NULL) { - ma_log_postf(pLog, MA_LOG_LEVEL_WARNING, "Failed to load symbol: %s\n", symbol); + if (proc == NULL) { + ma_log_postf(pLog, MA_LOG_LEVEL_WARNING, "Failed to load symbol: %s\n", symbol); + } + + (void)pLog; /* It's possible for pContext to be unused. */ + return proc; } - - (void)pLog; /* It's possible for pContext to be unused. */ - return proc; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)handle; - (void)symbol; - return NULL; -#endif + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)handle; + (void)symbol; + return NULL; + } + #endif } @@ -18020,13 +19476,6 @@ DEVICE I/O ************************************************************************************************************************************************************* ************************************************************************************************************************************************************/ -/* Disable run-time linking on certain backends and platforms. */ -#ifndef MA_NO_RUNTIME_LINKING - #if defined(MA_EMSCRIPTEN) || defined(MA_ORBIS) || defined(MA_PROSPERO) - #define MA_NO_RUNTIME_LINKING - #endif -#endif - #ifdef MA_APPLE #include #endif @@ -18039,12 +19488,6 @@ DEVICE I/O #ifdef MA_POSIX #include - #include - - /* No need for dlfcn.h if we're not using runtime linking. */ - #ifndef MA_NO_RUNTIME_LINKING - #include - #endif #endif /* This must be set to at least 26. */ @@ -18299,7 +19742,7 @@ MA_API ma_bool32 ma_is_loopback_supported(ma_backend backend) -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) /* WASAPI error codes. 
*/ #define MA_AUDCLNT_E_NOT_INITIALIZED ((HRESULT)0x88890001) #define MA_AUDCLNT_E_ALREADY_INITIALIZED ((HRESULT)0x88890002) @@ -18514,6 +19957,11 @@ typedef LONG (WINAPI * MA_PFN_RegCloseKey)(HKEY hKey); typedef LONG (WINAPI * MA_PFN_RegQueryValueExA)(HKEY hKey, const char* lpValueName, DWORD* lpReserved, DWORD* lpType, BYTE* lpData, DWORD* lpcbData); #endif /* MA_WIN32_DESKTOP */ +static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_PCM = {0x00000001, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; +static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_IEEE_FLOAT = {0x00000003, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; +/*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_ALAW = {0x00000006, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ +/*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_MULAW = {0x00000007, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ + MA_API size_t ma_strlen_WCHAR(const WCHAR* str) { size_t len = 0; @@ -18577,7 +20025,7 @@ Timing *******************************************************************************/ #if defined(MA_WIN32) && !defined(MA_POSIX) static LARGE_INTEGER g_ma_TimerFrequency; /* <-- Initialized to zero since it's static. */ - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { LARGE_INTEGER counter; @@ -18589,7 +20037,7 @@ Timing pTimer->counter = counter.QuadPart; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { LARGE_INTEGER counter; if (!QueryPerformanceCounter(&counter)) { @@ -18600,7 +20048,7 @@ Timing } #elif defined(MA_APPLE) && (MAC_OS_X_VERSION_MIN_REQUIRED < 101200) static ma_uint64 g_ma_TimerFrequency = 0; - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { mach_timebase_info_data_t baseTime; mach_timebase_info(&baseTime); @@ -18609,7 +20057,7 @@ Timing pTimer->counter = mach_absolute_time(); } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter = mach_absolute_time(); ma_uint64 oldTimeCounter = pTimer->counter; @@ -18634,15 +20082,15 @@ Timing #define MA_CLOCK_ID CLOCK_REALTIME #endif - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); - pTimer->counter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; + pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000000) + newTime.tv_nsec; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter; ma_uint64 oldTimeCounter; @@ -18650,21 +20098,21 @@ Timing struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); - newTimeCounter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; + newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000000) + newTime.tv_nsec; oldTimeCounter = pTimer->counter; return (newTimeCounter - oldTimeCounter) / 1000000000.0; } #else - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { struct timeval newTime; gettimeofday(&newTime, NULL); - pTimer->counter = (newTime.tv_sec * 1000000) + newTime.tv_usec; + pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000) + newTime.tv_usec; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE 
double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter; ma_uint64 oldTimeCounter; @@ -18672,7 +20120,7 @@ Timing struct timeval newTime; gettimeofday(&newTime, NULL); - newTimeCounter = (newTime.tv_sec * 1000000) + newTime.tv_usec; + newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000) + newTime.tv_usec; oldTimeCounter = pTimer->counter; return (newTimeCounter - oldTimeCounter) / 1000000.0; @@ -19248,14 +20696,6 @@ static MA_INLINE void ma_device__set_state(ma_device* pDevice, ma_device_state n } -#if defined(MA_WIN32) - static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_PCM = {0x00000001, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; - static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_IEEE_FLOAT = {0x00000003, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; - /*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_ALAW = {0x00000006, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ - /*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_MULAW = {0x00000007, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ -#endif - - MA_API ma_uint32 ma_get_format_priority_index(ma_format format) /* Lower = better. */ { @@ -19967,7 +21407,7 @@ static ma_result ma_context_init__null(ma_context* pContext, const ma_context_co WIN32 COMMON *******************************************************************************/ -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) #if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) #define ma_CoInitializeEx(pContext, pvReserved, dwCoInit) ((pContext->win32.CoInitializeEx) ? ((MA_PFN_CoInitializeEx)pContext->win32.CoInitializeEx)(pvReserved, dwCoInit) : ((MA_PFN_CoInitialize)pContext->win32.CoInitialize)(pvReserved)) #define ma_CoUninitialize(pContext) ((MA_PFN_CoUninitialize)pContext->win32.CoUninitialize)() @@ -19982,7 +21422,7 @@ WIN32 COMMON #define ma_PropVariantClear(pContext, pvar) PropVariantClear(pvar) #endif -#if !defined(MAXULONG_PTR) && !defined(__WATCOMC__) +#if !defined(MAXULONG_PTR) && !defined(__WATCOMC__) && !defined(MA_XBOX_NXDK) typedef size_t DWORD_PTR; #endif @@ -20409,11 +21849,21 @@ typedef enum MA_AudioCategory_Other = 0 /* <-- miniaudio is only caring about Other. */ } MA_AUDIO_STREAM_CATEGORY; +typedef enum +{ + MA_AUDCLNT_STREAMOPTIONS_NONE, + MA_AUDCLNT_STREAMOPTIONS_RAW, + MA_AUDCLNT_STREAMOPTIONS_MATCH_FORMAT, + MA_AUDCLNT_STREAMOPTIONS_AMBISONICS, + MA_AUDCLNT_STREAMOPTIONS_POST_VOLUME_LOOPBACK +} MA_AUDCLNT_STREAMOPTIONS; + typedef struct { ma_uint32 cbSize; BOOL bIsOffload; MA_AUDIO_STREAM_CATEGORY eCategory; + MA_AUDCLNT_STREAMOPTIONS Options; } ma_AudioClientProperties; /* IUnknown */ @@ -21588,6 +23038,7 @@ static ma_result ma_context_get_MMDevice__wasapi(ma_context* pContext, ma_device { ma_IMMDeviceEnumerator* pDeviceEnumerator; HRESULT hr; + HRESULT CoInitializeResult; MA_ASSERT(pContext != NULL); MA_ASSERT(ppMMDevice != NULL); @@ -21601,12 +23052,17 @@ static ma_result ma_context_get_MMDevice__wasapi(ma_context* pContext, ma_device The community has reported that this seems to fix the crash. There are future plans to move all WASAPI operation over to a single thread to make everything safer, but in the meantime while we wait for that to come online I'm happy enough to use this hack instead. + + CoUninitialize should only be called if we successfully initialized. S_OK and S_FALSE both mean that we need to + call CoUninitialize since the internal ref count was increased. 
RPC_E_CHANGED_MODE means that CoInitializeEx was + called with a different COINIT value, and we don't call CoUninitialize in that case. Other errors are possible, + so we check for S_OK and S_FALSE specifically. */ - ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); + CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); { hr = ma_CoCreateInstance(pContext, &MA_CLSID_MMDeviceEnumerator, NULL, CLSCTX_ALL, &MA_IID_IMMDeviceEnumerator, (void**)&pDeviceEnumerator); - } - ma_CoUninitialize(pContext); + } + if (CoInitializeResult == S_OK || CoInitializeResult == S_FALSE) { ma_CoUninitialize(pContext); } if (FAILED(hr)) { /* <-- This is checking the call above to ma_CoCreateInstance(). */ ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[WASAPI] Failed to create IMMDeviceEnumerator.\n"); @@ -21950,7 +23406,7 @@ static ma_result ma_context_get_IAudioClient__wasapi(ma_context* pContext, ma_de pActivationParams = &activationParams; /* When requesting a specific device ID we need to use a special device ID. */ - MA_COPY_MEMORY(virtualDeviceID.wasapi, MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK, (wcslen(MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK) + 1) * sizeof(wchar_t)); /* +1 for the null terminator. */ + MA_COPY_MEMORY(virtualDeviceID.wasapi, MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK, (ma_wcslen(MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK) + 1) * sizeof(wchar_t)); /* +1 for the null terminator. */ pDeviceID = &virtualDeviceID; } else { pActivationParams = NULL; /* No activation parameters required. */ @@ -26679,6 +28135,9 @@ typedef snd_pcm_channel_area_t ma_snd_pcm_channel_area_t; typedef snd_pcm_chmap_t ma_snd_pcm_chmap_t; typedef snd_pcm_state_t ma_snd_pcm_state_t; +/* snd_pcm_state_t */ +#define MA_SND_PCM_STATE_XRUN SND_PCM_STATE_XRUN + /* snd_pcm_stream_t */ #define MA_SND_PCM_STREAM_PLAYBACK SND_PCM_STREAM_PLAYBACK #define MA_SND_PCM_STREAM_CAPTURE SND_PCM_STREAM_CAPTURE @@ -26874,6 +28333,7 @@ typedef int (* ma_snd_pcm_hw_params_set_channels_minmax_proc) ( typedef int (* ma_snd_pcm_hw_params_set_rate_resample_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int val); typedef int (* ma_snd_pcm_hw_params_set_rate_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int val, int dir); typedef int (* ma_snd_pcm_hw_params_set_rate_near_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *val, int *dir); +typedef int (* ma_snd_pcm_hw_params_set_rate_minmax_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *min, int *mindir, unsigned int *max, int *maxdir); typedef int (* ma_snd_pcm_hw_params_set_buffer_size_near_proc)(ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, ma_snd_pcm_uframes_t *val); typedef int (* ma_snd_pcm_hw_params_set_periods_near_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *val, int *dir); typedef int (* ma_snd_pcm_hw_params_set_access_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, ma_snd_pcm_access_t _access); @@ -28640,8 +30100,9 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co ma_snd_pcm_hw_params_get_format_mask_proc _snd_pcm_hw_params_get_format_mask = snd_pcm_hw_params_get_format_mask; ma_snd_pcm_hw_params_set_channels_proc _snd_pcm_hw_params_set_channels = snd_pcm_hw_params_set_channels; ma_snd_pcm_hw_params_set_channels_near_proc _snd_pcm_hw_params_set_channels_near = snd_pcm_hw_params_set_channels_near; + ma_snd_pcm_hw_params_set_channels_minmax_proc _snd_pcm_hw_params_set_channels_minmax = 
snd_pcm_hw_params_set_channels_minmax; ma_snd_pcm_hw_params_set_rate_resample_proc _snd_pcm_hw_params_set_rate_resample = snd_pcm_hw_params_set_rate_resample; - ma_snd_pcm_hw_params_set_rate_near _snd_pcm_hw_params_set_rate = snd_pcm_hw_params_set_rate; + ma_snd_pcm_hw_params_set_rate_proc _snd_pcm_hw_params_set_rate = snd_pcm_hw_params_set_rate; ma_snd_pcm_hw_params_set_rate_near_proc _snd_pcm_hw_params_set_rate_near = snd_pcm_hw_params_set_rate_near; ma_snd_pcm_hw_params_set_rate_minmax_proc _snd_pcm_hw_params_set_rate_minmax = snd_pcm_hw_params_set_rate_minmax; ma_snd_pcm_hw_params_set_buffer_size_near_proc _snd_pcm_hw_params_set_buffer_size_near = snd_pcm_hw_params_set_buffer_size_near; @@ -28693,9 +30154,9 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co ma_snd_pcm_info_proc _snd_pcm_info = snd_pcm_info; ma_snd_pcm_info_sizeof_proc _snd_pcm_info_sizeof = snd_pcm_info_sizeof; ma_snd_pcm_info_get_name_proc _snd_pcm_info_get_name = snd_pcm_info_get_name; - ma_snd_pcm_poll_descriptors _snd_pcm_poll_descriptors = snd_pcm_poll_descriptors; - ma_snd_pcm_poll_descriptors_count _snd_pcm_poll_descriptors_count = snd_pcm_poll_descriptors_count; - ma_snd_pcm_poll_descriptors_revents _snd_pcm_poll_descriptors_revents = snd_pcm_poll_descriptors_revents; + ma_snd_pcm_poll_descriptors_proc _snd_pcm_poll_descriptors = snd_pcm_poll_descriptors; + ma_snd_pcm_poll_descriptors_count_proc _snd_pcm_poll_descriptors_count = snd_pcm_poll_descriptors_count; + ma_snd_pcm_poll_descriptors_revents_proc _snd_pcm_poll_descriptors_revents = snd_pcm_poll_descriptors_revents; ma_snd_config_update_free_global_proc _snd_config_update_free_global = snd_config_update_free_global; pContext->alsa.snd_pcm_open = (ma_proc)_snd_pcm_open; @@ -28711,6 +30172,7 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co pContext->alsa.snd_pcm_hw_params_set_rate_resample = (ma_proc)_snd_pcm_hw_params_set_rate_resample; pContext->alsa.snd_pcm_hw_params_set_rate = (ma_proc)_snd_pcm_hw_params_set_rate; pContext->alsa.snd_pcm_hw_params_set_rate_near = (ma_proc)_snd_pcm_hw_params_set_rate_near; + pContext->alsa.snd_pcm_hw_params_set_rate_minmax = (ma_proc)_snd_pcm_hw_params_set_rate_minmax; pContext->alsa.snd_pcm_hw_params_set_buffer_size_near = (ma_proc)_snd_pcm_hw_params_set_buffer_size_near; pContext->alsa.snd_pcm_hw_params_set_periods_near = (ma_proc)_snd_pcm_hw_params_set_periods_near; pContext->alsa.snd_pcm_hw_params_set_access = (ma_proc)_snd_pcm_hw_params_set_access; @@ -29436,7 +30898,7 @@ typedef void (* ma_pa_threaded_mainloop_unlock_proc) ( typedef void (* ma_pa_threaded_mainloop_wait_proc) (ma_pa_threaded_mainloop* m); typedef void (* ma_pa_threaded_mainloop_signal_proc) (ma_pa_threaded_mainloop* m, int wait_for_accept); typedef void (* ma_pa_threaded_mainloop_accept_proc) (ma_pa_threaded_mainloop* m); -typedef int (* ma_pa_threaded_mainloop_get_retval_proc) (ma_pa_threaded_mainloop* m); +typedef int (* ma_pa_threaded_mainloop_get_retval_proc) (const ma_pa_threaded_mainloop* m); typedef ma_pa_mainloop_api* (* ma_pa_threaded_mainloop_get_api_proc) (ma_pa_threaded_mainloop* m); typedef int (* ma_pa_threaded_mainloop_in_thread_proc) (ma_pa_threaded_mainloop* m); typedef void (* ma_pa_threaded_mainloop_set_name_proc) (ma_pa_threaded_mainloop* m, const char* name); @@ -29445,13 +30907,13 @@ typedef void (* ma_pa_context_unref_proc) ( typedef int (* ma_pa_context_connect_proc) (ma_pa_context* c, const char* server, ma_pa_context_flags_t flags, const ma_pa_spawn_api* 
api); typedef void (* ma_pa_context_disconnect_proc) (ma_pa_context* c); typedef void (* ma_pa_context_set_state_callback_proc) (ma_pa_context* c, ma_pa_context_notify_cb_t cb, void* userdata); -typedef ma_pa_context_state_t (* ma_pa_context_get_state_proc) (ma_pa_context* c); +typedef ma_pa_context_state_t (* ma_pa_context_get_state_proc) (const ma_pa_context* c); typedef ma_pa_operation* (* ma_pa_context_get_sink_info_list_proc) (ma_pa_context* c, ma_pa_sink_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_source_info_list_proc) (ma_pa_context* c, ma_pa_source_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_sink_info_by_name_proc) (ma_pa_context* c, const char* name, ma_pa_sink_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_source_info_by_name_proc)(ma_pa_context* c, const char* name, ma_pa_source_info_cb_t cb, void* userdata); typedef void (* ma_pa_operation_unref_proc) (ma_pa_operation* o); -typedef ma_pa_operation_state_t (* ma_pa_operation_get_state_proc) (ma_pa_operation* o); +typedef ma_pa_operation_state_t (* ma_pa_operation_get_state_proc) (const ma_pa_operation* o); typedef ma_pa_channel_map* (* ma_pa_channel_map_init_extend_proc) (ma_pa_channel_map* m, unsigned channels, ma_pa_channel_map_def_t def); typedef int (* ma_pa_channel_map_valid_proc) (const ma_pa_channel_map* m); typedef int (* ma_pa_channel_map_compatible_proc) (const ma_pa_channel_map* m, const ma_pa_sample_spec* ss); @@ -29460,12 +30922,12 @@ typedef void (* ma_pa_stream_unref_proc) ( typedef int (* ma_pa_stream_connect_playback_proc) (ma_pa_stream* s, const char* dev, const ma_pa_buffer_attr* attr, ma_pa_stream_flags_t flags, const ma_pa_cvolume* volume, ma_pa_stream* sync_stream); typedef int (* ma_pa_stream_connect_record_proc) (ma_pa_stream* s, const char* dev, const ma_pa_buffer_attr* attr, ma_pa_stream_flags_t flags); typedef int (* ma_pa_stream_disconnect_proc) (ma_pa_stream* s); -typedef ma_pa_stream_state_t (* ma_pa_stream_get_state_proc) (ma_pa_stream* s); +typedef ma_pa_stream_state_t (* ma_pa_stream_get_state_proc) (const ma_pa_stream* s); typedef const ma_pa_sample_spec* (* ma_pa_stream_get_sample_spec_proc) (ma_pa_stream* s); typedef const ma_pa_channel_map* (* ma_pa_stream_get_channel_map_proc) (ma_pa_stream* s); typedef const ma_pa_buffer_attr* (* ma_pa_stream_get_buffer_attr_proc) (ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_set_buffer_attr_proc) (ma_pa_stream* s, const ma_pa_buffer_attr* attr, ma_pa_stream_success_cb_t cb, void* userdata); -typedef const char* (* ma_pa_stream_get_device_name_proc) (ma_pa_stream* s); +typedef const char* (* ma_pa_stream_get_device_name_proc) (const ma_pa_stream* s); typedef void (* ma_pa_stream_set_write_callback_proc) (ma_pa_stream* s, ma_pa_stream_request_cb_t cb, void* userdata); typedef void (* ma_pa_stream_set_read_callback_proc) (ma_pa_stream* s, ma_pa_stream_request_cb_t cb, void* userdata); typedef void (* ma_pa_stream_set_suspended_callback_proc) (ma_pa_stream* s, ma_pa_stream_notify_cb_t cb, void* userdata); @@ -29473,15 +30935,15 @@ typedef void (* ma_pa_stream_set_moved_callback_proc) ( typedef int (* ma_pa_stream_is_suspended_proc) (const ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_flush_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_stream_drain_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); -typedef int (* ma_pa_stream_is_corked_proc) (ma_pa_stream* s); 
+typedef int (* ma_pa_stream_is_corked_proc) (const ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_cork_proc) (ma_pa_stream* s, int b, ma_pa_stream_success_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_stream_trigger_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); typedef int (* ma_pa_stream_begin_write_proc) (ma_pa_stream* s, void** data, size_t* nbytes); typedef int (* ma_pa_stream_write_proc) (ma_pa_stream* s, const void* data, size_t nbytes, ma_pa_free_cb_t free_cb, int64_t offset, ma_pa_seek_mode_t seek); typedef int (* ma_pa_stream_peek_proc) (ma_pa_stream* s, const void** data, size_t* nbytes); typedef int (* ma_pa_stream_drop_proc) (ma_pa_stream* s); -typedef size_t (* ma_pa_stream_writable_size_proc) (ma_pa_stream* s); -typedef size_t (* ma_pa_stream_readable_size_proc) (ma_pa_stream* s); +typedef size_t (* ma_pa_stream_writable_size_proc) (const ma_pa_stream* s); +typedef size_t (* ma_pa_stream_readable_size_proc) (const ma_pa_stream* s); typedef struct { @@ -29777,9 +31239,10 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, } /* Now we need to connect to the context. Everything is asynchronous so we need to wait for it to connect before returning. */ - result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? 0 : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); + result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? MA_PA_CONTEXT_NOFLAGS : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio context."); + ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext)); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); return result; } @@ -29788,6 +31251,7 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, result = ma_wait_for_pa_context_to_connect__pulse(pContext, pMainLoop, pPulseContext); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Waiting for connection failed."); + ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext)); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); return result; } @@ -30510,7 +31974,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi const ma_pa_buffer_attr* pActualAttr = NULL; const ma_pa_channel_map* pActualChannelMap = NULL; ma_uint32 iChannel; - ma_pa_stream_flags_t streamFlags; + int streamFlags; MA_ASSERT(pDevice != NULL); MA_ZERO_OBJECT(&pDevice->pulse); @@ -30568,8 +32032,13 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi ss.channels = pDescriptorCapture->channels; } + /* PulseAudio has a maximum channel count of 32. We'll get a crash if this is exceeded. */ + if (ss.channels > 32) { + ss.channels = 32; + } + /* Use a default channel map. 
*/ - ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, pConfig->pulse.channelMap); + ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, (ma_pa_channel_map_def_t)pConfig->pulse.channelMap); /* Use the requested sample rate if one was specified. */ if (pDescriptorCapture->sampleRate != 0) { @@ -30626,7 +32095,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi streamFlags |= MA_PA_STREAM_DONT_MOVE; } - error = ((ma_pa_stream_connect_record_proc)pDevice->pContext->pulse.pa_stream_connect_record)((ma_pa_stream*)pDevice->pulse.pStreamCapture, devCapture, &attr, streamFlags); + error = ((ma_pa_stream_connect_record_proc)pDevice->pContext->pulse.pa_stream_connect_record)((ma_pa_stream*)pDevice->pulse.pStreamCapture, devCapture, &attr, (ma_pa_stream_flags_t)streamFlags); if (error != MA_PA_OK) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio capture stream."); result = ma_result_from_pulse(error); @@ -30720,8 +32189,13 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi ss.channels = pDescriptorPlayback->channels; } + /* PulseAudio has a maximum channel count of 32. We'll get a crash if this is exceeded. */ + if (ss.channels > 32) { + ss.channels = 32; + } + /* Use a default channel map. */ - ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, pConfig->pulse.channelMap); + ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, (ma_pa_channel_map_def_t)pConfig->pulse.channelMap); /* Use the requested sample rate if one was specified. 
*/ @@ -30783,7 +32257,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi streamFlags |= MA_PA_STREAM_DONT_MOVE; } - error = ((ma_pa_stream_connect_playback_proc)pDevice->pContext->pulse.pa_stream_connect_playback)((ma_pa_stream*)pDevice->pulse.pStreamPlayback, devPlayback, &attr, streamFlags, NULL, NULL); + error = ((ma_pa_stream_connect_playback_proc)pDevice->pContext->pulse.pa_stream_connect_playback)((ma_pa_stream*)pDevice->pulse.pStreamPlayback, devPlayback, &attr, (ma_pa_stream_flags_t)streamFlags, NULL, NULL); if (error != MA_PA_OK) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio playback stream."); result = ma_result_from_pulse(error); @@ -31338,6 +32812,7 @@ typedef JackProcessCallback ma_JackProcessCallback; typedef JackBufferSizeCallback ma_JackBufferSizeCallback; typedef JackShutdownCallback ma_JackShutdownCallback; #define MA_JACK_DEFAULT_AUDIO_TYPE JACK_DEFAULT_AUDIO_TYPE +#define ma_JackNullOption JackNullOption #define ma_JackNoStartServer JackNoStartServer #define ma_JackPortIsInput JackPortIsInput #define ma_JackPortIsOutput JackPortIsOutput @@ -31352,6 +32827,7 @@ typedef int (* ma_JackProcessCallback) (ma_jack_nframes_t nframes, void* arg) typedef int (* ma_JackBufferSizeCallback)(ma_jack_nframes_t nframes, void* arg); typedef void (* ma_JackShutdownCallback) (void* arg); #define MA_JACK_DEFAULT_AUDIO_TYPE "32 bit float mono audio" +#define ma_JackNullOption 0 #define ma_JackNoStartServer 1 #define ma_JackPortIsInput 1 #define ma_JackPortIsOutput 2 @@ -31392,7 +32868,7 @@ static ma_result ma_context_open_client__jack(ma_context* pContext, ma_jack_clie maxClientNameSize = ((ma_jack_client_name_size_proc)pContext->jack.jack_client_name_size)(); /* Includes null terminator. */ ma_strncpy_s(clientName, ma_min(sizeof(clientName), maxClientNameSize), (pContext->jack.pClientName != NULL) ? pContext->jack.pClientName : "miniaudio", (size_t)-1); - pClient = ((ma_jack_client_open_proc)pContext->jack.jack_client_open)(clientName, (pContext->jack.tryStartServer) ? 0 : ma_JackNoStartServer, &status, NULL); + pClient = ((ma_jack_client_open_proc)pContext->jack.jack_client_open)(clientName, (pContext->jack.tryStartServer) ? ma_JackNullOption : ma_JackNoStartServer, &status, NULL); if (pClient == NULL) { return MA_FAILED_TO_OPEN_BACKEND_DEVICE; } @@ -36994,7 +38470,7 @@ OSS Backend #define MA_OSS_DEFAULT_DEVICE_NAME "/dev/dsp" -static int ma_open_temp_device__oss() +static int ma_open_temp_device__oss(void) { /* The OSS sample code uses "/dev/mixer" as the device for getting system properties so I'm going to do the same. */ int fd = open("/dev/mixer", O_RDONLY, 0); @@ -37834,25 +39310,30 @@ static void ma_stream_error_callback__aaudio(ma_AAudioStream* pStream, void* pUs (void)error; ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] ERROR CALLBACK: error=%d, AAudioStream_getState()=%d\n", error, ((MA_PFN_AAudioStream_getState)pDevice->pContext->aaudio.AAudioStream_getState)(pStream)); + /* When we get an error, we'll assume that the stream is in an erroneous state and needs to be restarted. From the documentation, we cannot do this from the error callback. Therefore we are going to use an event thread for the AAudio backend to do this cleanly and safely. 
*/ - job = ma_job_init(MA_JOB_TYPE_DEVICE_AAUDIO_REROUTE); - job.data.device.aaudio.reroute.pDevice = pDevice; - - if (pStream == pDevice->aaudio.pStreamCapture) { - job.data.device.aaudio.reroute.deviceType = ma_device_type_capture; + if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Tearing down device.\n"); } else { - job.data.device.aaudio.reroute.deviceType = ma_device_type_playback; - } - - result = ma_device_job_thread_post(&pDevice->pContext->aaudio.jobThread, &job); - if (result != MA_SUCCESS) { - ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Failed to post job for rerouting.\n"); - return; + job = ma_job_init(MA_JOB_TYPE_DEVICE_AAUDIO_REROUTE); + job.data.device.aaudio.reroute.pDevice = pDevice; + + if (pStream == pDevice->aaudio.pStreamCapture) { + job.data.device.aaudio.reroute.deviceType = ma_device_type_capture; + } else { + job.data.device.aaudio.reroute.deviceType = ma_device_type_playback; + } + + result = ma_device_job_thread_post(&pDevice->pContext->aaudio.jobThread, &job); + if (result != MA_SUCCESS) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Failed to post job for rerouting.\n"); + return; + } } } @@ -38169,7 +39650,7 @@ static ma_result ma_close_streams__aaudio(ma_device* pDevice) { MA_ASSERT(pDevice != NULL); - /* When re-routing, streams may have been closed and never re-opened. Hence the extra checks below. */ + /* When rerouting, streams may have been closed and never re-opened. Hence the extra checks below. */ if (pDevice->type == ma_device_type_capture || pDevice->type == ma_device_type_duplex) { ma_close_stream__aaudio(pDevice->pContext, (ma_AAudioStream*)pDevice->aaudio.pStreamCapture); pDevice->aaudio.pStreamCapture = NULL; @@ -38186,6 +39667,12 @@ static ma_result ma_device_uninit__aaudio(ma_device* pDevice) { MA_ASSERT(pDevice != NULL); + /* + Note: Closing the streams may cause a timeout error, which would then trigger rerouting in our error callback. + We must not schedule a reroute when device is getting destroyed. + */ + ma_atomic_bool32_set(&pDevice->aaudio.isTearingDown, MA_TRUE); + /* Wait for any rerouting to finish before attempting to close the streams. */ ma_mutex_lock(&pDevice->aaudio.rerouteLock); { @@ -38193,7 +39680,7 @@ static ma_result ma_device_uninit__aaudio(ma_device* pDevice) } ma_mutex_unlock(&pDevice->aaudio.rerouteLock); - /* Destroy re-routing lock. */ + /* Destroy rerouting lock. */ ma_mutex_uninit(&pDevice->aaudio.rerouteLock); return MA_SUCCESS; @@ -38429,17 +39916,22 @@ static ma_result ma_device_stop__aaudio(ma_device* pDevice) static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type deviceType) { + const ma_int32 maxAttempts = 4; /* Reasonable retry limit. */ + ma_result result; - int32_t retries = 0; + ma_int32 iAttempt; MA_ASSERT(pDevice != NULL); - /* - TODO: Stop retrying if main thread is about to uninit device. - */ - ma_mutex_lock(&pDevice->aaudio.rerouteLock); - { -error_disconnected: + /* We got disconnected! Retry a few times, until we find a connected device! */ + iAttempt = 0; + while (iAttempt++ < maxAttempts) { + /* Device tearing down? No need to reroute! */ + if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) { + result = MA_SUCCESS; /* Caller should continue as normal. */ + break; + } + /* The first thing to do is close the streams. 
*/ ma_close_streams__aaudio(pDevice); @@ -38495,14 +39987,16 @@ error_disconnected: result = ma_device_init_streams__aaudio(pDevice, &deviceConfig, &descriptorPlayback, &descriptorCapture); if (result != MA_SUCCESS) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_WARNING, "[AAudio] Failed to create stream after route change."); - goto done; + /* Reroute failed! */ + break; } result = ma_device_post_init(pDevice, deviceType, &descriptorPlayback, &descriptorCapture); if (result != MA_SUCCESS) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_WARNING, "[AAudio] Failed to initialize device after route change."); ma_close_streams__aaudio(pDevice); - goto done; + /* Reroute failed! */ + break; } /* We'll only ever do this in response to a reroute. */ @@ -38513,26 +40007,23 @@ error_disconnected: if (pDevice->aaudio.noAutoStartAfterReroute == MA_FALSE) { result = ma_device_start__aaudio(pDevice); if (result != MA_SUCCESS) { - /* We got disconnected! Retry a few times, until we find a connected device! */ - retries += 1; - if (retries <= 3) { - ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, retrying(%d)", retries); - goto error_disconnected; + if (iAttempt < maxAttempts) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, retrying(%d)", iAttempt); + } else { + ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, giving up."); } - ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change."); - goto done; } } else { - ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */ + ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */ } } - - result = MA_SUCCESS; - } -done: - /* Re-routing done */ - ma_mutex_unlock(&pDevice->aaudio.rerouteLock); + if (result == MA_SUCCESS) { + /* Reroute successful! */ + break; + } + } + return result; } @@ -38698,7 +40189,7 @@ static ma_result ma_context_init__aaudio(ma_context* pContext, const ma_context_ static ma_result ma_job_process__device__aaudio_reroute(ma_job* pJob) { - ma_result result; + ma_result result = MA_SUCCESS; ma_device* pDevice; MA_ASSERT(pJob != NULL); @@ -38706,19 +40197,22 @@ static ma_result ma_job_process__device__aaudio_reroute(ma_job* pJob) pDevice = (ma_device*)pJob->data.device.aaudio.reroute.pDevice; MA_ASSERT(pDevice != NULL); - /* Here is where we need to reroute the device. To do this we need to uninitialize the stream and reinitialize it. */ - result = ma_device_reinit__aaudio(pDevice, (ma_device_type)pJob->data.device.aaudio.reroute.deviceType); - if (result != MA_SUCCESS) { - /* - Getting here means we failed to reroute the device. The best thing I can think of here is to - just stop the device. - */ - ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[AAudio] Stopping device due to reroute failure."); - ma_device_stop(pDevice); - return result; + ma_mutex_lock(&pDevice->aaudio.rerouteLock); + { + /* Here is where we need to reroute the device. To do this we need to uninitialize the stream and reinitialize it. */ + result = ma_device_reinit__aaudio(pDevice, (ma_device_type)pJob->data.device.aaudio.reroute.deviceType); + if (result != MA_SUCCESS) { + /* + Getting here means we failed to reroute the device. The best thing I can think of here is to + just stop the device. 
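The rerouting rework above replaces the goto-based retry with a bounded loop, adds an isTearingDown flag so no reroute is attempted while the device is being uninitialized, and moves the reroute lock out of ma_device_reinit__aaudio() and into the job handler so ma_device_uninit__aaudio() can wait on it. A simplified sketch of the loop's shape; try_reinit_streams() is a hypothetical stand-in for the close/re-open/restart sequence:

    static ma_result try_reinit_streams(ma_device* pDevice);   /* Hypothetical helper: close, re-open and restart the streams. */

    /* Bounded retry with an early exit when the device is shutting down. */
    static ma_result reroute_with_retries(ma_device* pDevice)
    {
        const ma_int32 maxAttempts = 4;
        ma_result result = MA_ERROR;
        ma_int32 iAttempt;

        for (iAttempt = 0; iAttempt < maxAttempts; iAttempt += 1) {
            if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) {
                return MA_SUCCESS;   /* Device is being uninitialized; skip the reroute. */
            }

            result = try_reinit_streams(pDevice);
            if (result == MA_SUCCESS) {
                break;
            }
        }

        return result;
    }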
+ */ + ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[AAudio] Stopping device due to reroute failure."); + ma_device_stop(pDevice); + } } + ma_mutex_unlock(&pDevice->aaudio.rerouteLock); - return MA_SUCCESS; + return result; } #else /* Getting here means there is no AAudio backend so we need a no-op job implementation. */ @@ -40269,8 +41763,11 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const frameCount = pDevice->capture.internalPeriodSizeInFrames; } + /* + If this is called by the device has not yet been started we need to return early, making sure we output silence to + the output buffer. + */ if (ma_device_get_state(pDevice) != ma_device_state_started) { - /* Fill the output buffer with zero to avoid a noise sound */ for (int i = 0; i < outputCount; i += 1) { MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); } @@ -40292,7 +41789,9 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const if (outputCount > 0) { /* If it's a capture-only device, we'll need to output silence. */ if (pDevice->type == ma_device_type_capture) { - MA_ZERO_MEMORY(pOutputs[0].data, frameCount * pDevice->playback.internalChannels * sizeof(float)); + for (int i = 0; i < outputCount; i += 1) { + MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); + } } else { ma_device_process_pcm_frames_playback__webaudio(pDevice, frameCount, pDevice->webaudio.pIntermediaryBuffer); @@ -40302,6 +41801,14 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const pOutputs[0].data[frameCount*iChannel + iFrame] = pDevice->webaudio.pIntermediaryBuffer[iFrame*pDevice->playback.internalChannels + iChannel]; } } + + /* + Just above we output data to the first output buffer. Here we just make sure we're putting silence into any + remaining output buffers. + */ + for (int i = 1; i < outputCount; i += 1) { /* <-- Note that the counter starts at 1 instead of 0. */ + MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); + } } } @@ -40782,8 +42289,8 @@ static ma_result ma_context_uninit__webaudio(ma_context* pContext) /* Remove the global miniaudio object from window if there are no more references to it. 
*/ EM_ASM({ if (typeof(window.miniaudio) !== 'undefined') { - miniaudio.unlock_event_types.map(function(event_type) { - document.removeEventListener(event_type, miniaudio.unlock, true); + window.miniaudio.unlock_event_types.map(function(event_type) { + document.removeEventListener(event_type, window.miniaudio.unlock, true); }); window.miniaudio.referenceCount -= 1; @@ -41236,13 +42743,13 @@ MA_API ma_result ma_device_post_init(ma_device* pDevice, ma_device_type deviceTy static ma_thread_result MA_THREADCALL ma_worker_thread(void* pData) { ma_device* pDevice = (ma_device*)pData; -#ifdef MA_WIN32 +#if defined(MA_WIN32) && !defined(MA_XBOX) HRESULT CoInitializeResult; #endif MA_ASSERT(pDevice != NULL); -#ifdef MA_WIN32 +#if defined(MA_WIN32) && !defined(MA_XBOX) CoInitializeResult = ma_CoInitializeEx(pDevice->pContext, NULL, MA_COINIT_VALUE); #endif @@ -41333,8 +42840,8 @@ static ma_thread_result MA_THREADCALL ma_worker_thread(void* pData) ma_event_signal(&pDevice->stopEvent); } -#ifdef MA_WIN32 - if (CoInitializeResult == S_OK) { +#if defined(MA_WIN32) && !defined(MA_XBOX) + if (CoInitializeResult == S_OK || CoInitializeResult == S_FALSE) { ma_CoUninitialize(pDevice->pContext); } #endif @@ -41358,67 +42865,92 @@ static ma_bool32 ma_device__is_initialized(ma_device* pDevice) static ma_result ma_context_uninit_backend_apis__win32(ma_context* pContext) { /* For some reason UWP complains when CoUninitialize() is called. I'm just not going to call it on UWP. */ -#if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) - if (pContext->win32.CoInitializeResult == S_OK) { - ma_CoUninitialize(pContext); + #if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) + { + /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + #if !defined(MA_XBOX) + { + if (pContext->win32.CoInitializeResult == S_OK || pContext->win32.CoInitializeResult == S_FALSE) { + ma_CoUninitialize(pContext); /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + } + } + #endif + + #if defined(MA_WIN32_DESKTOP) + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hUser32DLL); + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL); + #endif + + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hOle32DLL); + } + #else + { + (void)pContext; } - - #if defined(MA_WIN32_DESKTOP) - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hUser32DLL); - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL); #endif - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hOle32DLL); -#else - (void)pContext; -#endif - return MA_SUCCESS; } static ma_result ma_context_init_backend_apis__win32(ma_context* pContext) { -#if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) - #if defined(MA_WIN32_DESKTOP) - /* User32.dll */ - pContext->win32.hUser32DLL = ma_dlopen(ma_context_get_log(pContext), "user32.dll"); - if (pContext->win32.hUser32DLL == NULL) { + /* + TODO: Reassess all of this stuff and move everything to the relevant backends. For example, I think + GetForegroundWindow() and GetDesktopWindow() are only used by the DirectSound backend. 
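The worker-thread and context-uninit hunks above now also call CoUninitialize() when CoInitializeEx() returned S_FALSE. S_FALSE means COM was already initialized on the calling thread, but the call must still be balanced with a matching CoUninitialize(). A minimal sketch of that rule, assuming direct linking against ole32 rather than miniaudio's runtime symbol loading:

    #include <windows.h>

    void com_balanced_init_example(void)
    {
        HRESULT hr = CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
        if (hr == S_OK || hr == S_FALSE) {
            /* S_FALSE: already initialized on this thread, but still must be balanced. */
            /* ... use COM ... */
            CoUninitialize();
        }
        /* RPC_E_CHANGED_MODE (a conflicting apartment model) must NOT be balanced. */
    }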
+ */ + #if (defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK)) && !defined(MA_XBOX) + { + #if defined(MA_WIN32_DESKTOP) + { + /* User32.dll */ + pContext->win32.hUser32DLL = ma_dlopen(ma_context_get_log(pContext), "user32.dll"); + if (pContext->win32.hUser32DLL == NULL) { + return MA_FAILED_TO_INIT_BACKEND; + } + + pContext->win32.GetForegroundWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetForegroundWindow"); + pContext->win32.GetDesktopWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetDesktopWindow"); + + + /* Advapi32.dll */ + pContext->win32.hAdvapi32DLL = ma_dlopen(ma_context_get_log(pContext), "advapi32.dll"); + if (pContext->win32.hAdvapi32DLL == NULL) { + return MA_FAILED_TO_INIT_BACKEND; + } + + pContext->win32.RegOpenKeyExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegOpenKeyExA"); + pContext->win32.RegCloseKey = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegCloseKey"); + pContext->win32.RegQueryValueExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegQueryValueExA"); + } + #endif + + /* Ole32.dll */ + pContext->win32.hOle32DLL = ma_dlopen(ma_context_get_log(pContext), "ole32.dll"); + if (pContext->win32.hOle32DLL == NULL) { return MA_FAILED_TO_INIT_BACKEND; } - pContext->win32.GetForegroundWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetForegroundWindow"); - pContext->win32.GetDesktopWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetDesktopWindow"); - - - /* Advapi32.dll */ - pContext->win32.hAdvapi32DLL = ma_dlopen(ma_context_get_log(pContext), "advapi32.dll"); - if (pContext->win32.hAdvapi32DLL == NULL) { - return MA_FAILED_TO_INIT_BACKEND; - } - - pContext->win32.RegOpenKeyExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegOpenKeyExA"); - pContext->win32.RegCloseKey = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegCloseKey"); - pContext->win32.RegQueryValueExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegQueryValueExA"); + pContext->win32.CoInitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitialize"); + pContext->win32.CoInitializeEx = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitializeEx"); + pContext->win32.CoUninitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoUninitialize"); + pContext->win32.CoCreateInstance = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoCreateInstance"); + pContext->win32.CoTaskMemFree = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoTaskMemFree"); + pContext->win32.PropVariantClear = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "PropVariantClear"); + pContext->win32.StringFromGUID2 = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "StringFromGUID2"); + } + #else + { + (void)pContext; /* Unused. */ + } #endif - /* Ole32.dll */ - pContext->win32.hOle32DLL = ma_dlopen(ma_context_get_log(pContext), "ole32.dll"); - if (pContext->win32.hOle32DLL == NULL) { - return MA_FAILED_TO_INIT_BACKEND; + /* TODO: Remove this once the new single threaded backend system is in place in 0.12. 
*/ + #if !defined(MA_XBOX) + { + pContext->win32.CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); } + #endif - pContext->win32.CoInitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitialize"); - pContext->win32.CoInitializeEx = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitializeEx"); - pContext->win32.CoUninitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoUninitialize"); - pContext->win32.CoCreateInstance = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoCreateInstance"); - pContext->win32.CoTaskMemFree = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoTaskMemFree"); - pContext->win32.PropVariantClear = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "PropVariantClear"); - pContext->win32.StringFromGUID2 = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "StringFromGUID2"); -#else - (void)pContext; /* Unused. */ -#endif - - pContext->win32.CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); return MA_SUCCESS; } #else @@ -44016,7 +45548,7 @@ static MA_INLINE void ma_pcm_s16_to_s32__reference(void* dst, const void* src, m ma_uint64 i; for (i = 0; i < count; i += 1) { - dst_s32[i] = src_s16[i] << 16; + dst_s32[i] = (ma_int32)src_s16[i] << 16; } (void)ditherMode; @@ -49347,15 +50879,15 @@ static /*__attribute__((noinline))*/ ma_result ma_gainer_process_pcm_frames_inte a += d; } } + + pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float)); + pFramesIn = ma_offset_ptr(pFramesIn, interpolatedFrameCount * sizeof(float)); } + frameCount -= interpolatedFrameCount; + /* Make sure the timer is updated. */ pGainer->t = (ma_uint32)ma_min(pGainer->t + interpolatedFrameCount, pGainer->config.smoothTimeInFrames); - - /* Adjust our arguments so the next part can work normally. */ - frameCount -= interpolatedFrameCount; - pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float)); - pFramesIn = ma_offset_ptr(pFramesIn, interpolatedFrameCount * sizeof(float)); } /* All we need to do here is apply the new gains using an optimized path. */ @@ -50783,13 +52315,16 @@ static float ma_calculate_angular_gain(ma_vec3f dirA, ma_vec3f dirB, float coneI MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, ma_spatializer_listener* pListener, void* pFramesOut, const void* pFramesIn, ma_uint64 frameCount) { - ma_channel* pChannelMapIn = pSpatializer->pChannelMapIn; - ma_channel* pChannelMapOut = pListener->config.pChannelMapOut; + ma_channel* pChannelMapIn; + ma_channel* pChannelMapOut; - if (pSpatializer == NULL) { + if (pSpatializer == NULL || pListener == NULL) { return MA_INVALID_ARGS; } + pChannelMapIn = pSpatializer->pChannelMapIn; + pChannelMapOut = pListener->config.pChannelMapOut; + /* If we're not spatializing we need to run an optimized path. */ if (ma_atomic_load_i32(&pSpatializer->attenuationModel) == ma_attenuation_model_none) { if (ma_spatializer_listener_is_enabled(pListener)) { @@ -50834,23 +52369,17 @@ MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, We'll need the listener velocity for doppler pitch calculations. The speed of sound is defined by the listener, so we'll grab that here too. 
*/ - if (pListener != NULL) { - listenerVel = ma_spatializer_listener_get_velocity(pListener); - speedOfSound = pListener->config.speedOfSound; - } else { - listenerVel = ma_vec3f_init_3f(0, 0, 0); - speedOfSound = MA_DEFAULT_SPEED_OF_SOUND; - } + listenerVel = ma_spatializer_listener_get_velocity(pListener); + speedOfSound = pListener->config.speedOfSound; - if (pListener == NULL || ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) { - /* There's no listener or we're using relative positioning. */ + if (ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) { relativePos = ma_spatializer_get_position(pSpatializer); relativeDir = ma_spatializer_get_direction(pSpatializer); } else { /* - We've found a listener and we're using absolute positioning. We need to transform the - sound's position and direction so that it's relative to listener. Later on we'll use - this for determining the factors to apply to each channel to apply the panning effect. + We're using absolute positioning. We need to transform the sound's position and + direction so that it's relative to listener. Later on we'll use this for determining + the factors to apply to each channel to apply the panning effect. */ ma_spatializer_get_relative_position_and_direction(pSpatializer, pListener, &relativePos, &relativeDir); } @@ -52885,7 +54414,7 @@ static ma_bool32 ma_is_spatial_channel_position(ma_channel channelPosition) return MA_FALSE; } - if (channelPosition >= MA_CHANNEL_AUX_0 && channelPosition <= MA_CHANNEL_AUX_31) { + if (channelPosition >= MA_CHANNEL_AUX_0) { return MA_FALSE; } @@ -56408,8 +57937,12 @@ MA_API size_t ma_channel_map_to_string(const ma_channel* pChannelMap, ma_uint32 } /* Null terminate. Don't increment the length here. */ - if (pBufferOut != NULL && bufferCap > len + 1) { - pBufferOut[len] = '\0'; + if (pBufferOut != NULL) { + if (bufferCap > len) { + pBufferOut[len] = '\0'; + } else if (bufferCap > 0) { + pBufferOut[bufferCap - 1] = '\0'; + } } return len; @@ -56620,7 +58153,7 @@ MA_API ma_result ma_rb_init_ex(size_t subbufferSizeInBytes, size_t subbufferCoun Here is where we allocate our own buffer. We always want to align this to MA_SIMD_ALIGNMENT for future SIMD optimization opportunity. To do this we need to make sure the stride is a multiple of MA_SIMD_ALIGNMENT. 
*/ - pRB->subbufferStrideInBytes = (pRB->subbufferSizeInBytes + (MA_SIMD_ALIGNMENT-1)) & ~MA_SIMD_ALIGNMENT; + pRB->subbufferStrideInBytes = ma_align(pRB->subbufferSizeInBytes, MA_SIMD_ALIGNMENT); bufferSizeInBytes = (size_t)pRB->subbufferCount*pRB->subbufferStrideInBytes; pRB->pBuffer = ma_aligned_malloc(bufferSizeInBytes, MA_SIMD_ALIGNMENT, &pRB->allocationCallbacks); @@ -59515,7 +61048,7 @@ MA_API ma_result ma_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo } -#if !defined(MA_USE_WIN32_FILEIO) && (defined(MA_WIN32) && defined(MA_WIN32_DESKTOP) && !defined(MA_NO_WIN32_FILEIO) && !defined(MA_POSIX)) +#if !defined(MA_USE_WIN32_FILEIO) && (defined(MA_WIN32) && (defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_NXDK)) && !defined(MA_NO_WIN32_FILEIO) && !defined(MA_POSIX)) #define MA_USE_WIN32_FILEIO #endif @@ -59592,25 +61125,34 @@ static ma_result ma_default_vfs_open__win32(ma_vfs* pVFS, const char* pFilePath, static ma_result ma_default_vfs_open_w__win32(ma_vfs* pVFS, const wchar_t* pFilePath, ma_uint32 openMode, ma_vfs_file* pFile) { - HANDLE hFile; - DWORD dwDesiredAccess; - DWORD dwShareMode; - DWORD dwCreationDisposition; + #if !defined(MA_XBOX_NXDK) + { + HANDLE hFile; + DWORD dwDesiredAccess; + DWORD dwShareMode; + DWORD dwCreationDisposition; - (void)pVFS; + (void)pVFS; - /* Load some Win32 symbols dynamically so we can dynamically check for the existence of SetFilePointerEx. */ - ma_win32_fileio_init(); + /* Load some Win32 symbols dynamically so we can dynamically check for the existence of SetFilePointerEx. */ + ma_win32_fileio_init(); - ma_default_vfs__get_open_settings_win32(openMode, &dwDesiredAccess, &dwShareMode, &dwCreationDisposition); + ma_default_vfs__get_open_settings_win32(openMode, &dwDesiredAccess, &dwShareMode, &dwCreationDisposition); - hFile = CreateFileW(pFilePath, dwDesiredAccess, dwShareMode, NULL, dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, NULL); - if (hFile == INVALID_HANDLE_VALUE) { - return ma_result_from_GetLastError(GetLastError()); + hFile = CreateFileW(pFilePath, dwDesiredAccess, dwShareMode, NULL, dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, NULL); + if (hFile == INVALID_HANDLE_VALUE) { + return ma_result_from_GetLastError(GetLastError()); + } + + *pFile = hFile; + return MA_SUCCESS; } - - *pFile = hFile; - return MA_SUCCESS; + #else + { + /* No CreateFileW() available. */ + return MA_NOT_IMPLEMENTED; + } + #endif } static ma_result ma_default_vfs_close__win32(ma_vfs* pVFS, ma_vfs_file file) @@ -59781,19 +61323,28 @@ static ma_result ma_default_vfs_tell__win32(ma_vfs* pVFS, ma_vfs_file file, ma_i static ma_result ma_default_vfs_info__win32(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo) { - BY_HANDLE_FILE_INFORMATION fi; - BOOL result; - (void)pVFS; - result = GetFileInformationByHandle((HANDLE)file, &fi); - if (result == 0) { - return ma_result_from_GetLastError(GetLastError()); + #if !defined(MA_XBOX_NXDK) + { + BY_HANDLE_FILE_INFORMATION fi; + BOOL result; + + result = GetFileInformationByHandle((HANDLE)file, &fi); + if (result == 0) { + return ma_result_from_GetLastError(GetLastError()); + } + + pInfo->sizeInBytes = ((ma_uint64)fi.nFileSizeHigh << 32) | ((ma_uint64)fi.nFileSizeLow); + + return MA_SUCCESS; } - - pInfo->sizeInBytes = ((ma_uint64)fi.nFileSizeHigh << 32) | ((ma_uint64)fi.nFileSizeLow); - - return MA_SUCCESS; + #else + { + /* GetFileInformationByHandle() is unavailable. 
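The ring-buffer hunk above replaces the hand-rolled stride rounding with ma_align() because the original mask was off by one: rounding up to a power-of-two alignment requires & ~(alignment - 1), not & ~alignment. A small worked example, assuming an alignment of 64:

    #include <assert.h>
    #include <stddef.h>

    /* Round size up to the next multiple of a power-of-two alignment. */
    static size_t align_up(size_t size, size_t alignment)
    {
        return (size + (alignment - 1)) & ~(alignment - 1);
    }

    int main(void)
    {
        assert(align_up(100, 64) == 128);
        /* The old mask cleared only bit 6:
           (100 + 63) & ~64 == 163, which is not a multiple of 64, and
           (64 + 63)  & ~64 == 63, which shrinks a size that was already aligned. */
        return 0;
    }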
*/ + return MA_NOT_IMPLEMENTED; + } + #endif } #else static ma_result ma_default_vfs_open__stdio(ma_vfs* pVFS, const char* pFilePath, ma_uint32 openMode, ma_vfs_file* pFile) @@ -60131,6 +61682,8 @@ static ma_result ma_default_vfs_tell(ma_vfs* pVFS, ma_vfs_file file, ma_int64* p static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo) { + ma_result result; + if (pInfo == NULL) { return MA_INVALID_ARGS; } @@ -60142,10 +61695,42 @@ static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_inf } #if defined(MA_USE_WIN32_FILEIO) - return ma_default_vfs_info__win32(pVFS, file, pInfo); + result = ma_default_vfs_info__win32(pVFS, file, pInfo); #else - return ma_default_vfs_info__stdio(pVFS, file, pInfo); + result = ma_default_vfs_info__stdio(pVFS, file, pInfo); #endif + + if (result == MA_NOT_IMPLEMENTED) { + /* Not implemented. Fall back to seek/tell/seek. */ + ma_int64 cursor; + ma_int64 sizeInBytes; + + result = ma_default_vfs_tell(pVFS, file, &cursor); + if (result != MA_SUCCESS) { + return result; + } + + result = ma_default_vfs_seek(pVFS, file, 0, ma_seek_origin_end); + if (result != MA_SUCCESS) { + return result; + } + + result = ma_default_vfs_tell(pVFS, file, &sizeInBytes); + if (result != MA_SUCCESS) { + return result; + } + + pInfo->sizeInBytes = sizeInBytes; + + result = ma_default_vfs_seek(pVFS, file, cursor, ma_seek_origin_start); + if (result != MA_SUCCESS) { + return result; + } + + MA_ASSERT(result == MA_SUCCESS); + } + + return result; } @@ -60324,6 +61909,8 @@ Decoding and Encoding Headers. These are auto-generated from a tool. **************************************************************************************************************************************************************/ #if !defined(MA_NO_WAV) && (!defined(MA_NO_DECODING) || !defined(MA_NO_ENCODING)) +#define MA_HAS_WAV + /* dr_wav_h begin */ #ifndef ma_dr_wav_h #define ma_dr_wav_h @@ -60333,8 +61920,8 @@ extern "C" { #define MA_DR_WAV_STRINGIFY(x) #x #define MA_DR_WAV_XSTRINGIFY(x) MA_DR_WAV_STRINGIFY(x) #define MA_DR_WAV_VERSION_MAJOR 0 -#define MA_DR_WAV_VERSION_MINOR 13 -#define MA_DR_WAV_VERSION_REVISION 18 +#define MA_DR_WAV_VERSION_MINOR 14 +#define MA_DR_WAV_VERSION_REVISION 4 #define MA_DR_WAV_VERSION_STRING MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MAJOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MINOR) "." 
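The ma_default_vfs_info() change above adds a portable fallback for backends that report MA_NOT_IMPLEMENTED: remember the cursor, seek to the end, read the cursor again to get the size, then restore the original position. The same idea with plain stdio, for illustration only (the real code goes through the VFS seek/tell callbacks):

    #include <stdio.h>

    /* Query the size of an open file by seeking, then restore the original position. */
    static long file_size_via_seek_tell(FILE* f)
    {
        long cursor = ftell(f);              /* Remember where we were. */
        long size = -1;

        if (cursor >= 0 && fseek(f, 0, SEEK_END) == 0) {
            size = ftell(f);                 /* Cursor at the end == size in bytes. */
            fseek(f, cursor, SEEK_SET);      /* Put the cursor back. */
        }

        return size;
    }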
MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_REVISION) #include #define MA_DR_WAVE_FORMAT_PCM 0x1 @@ -60350,8 +61937,9 @@ MA_API void ma_dr_wav_version(ma_uint32* pMajor, ma_uint32* pMinor, ma_uint32* p MA_API const char* ma_dr_wav_version_string(void); typedef enum { - ma_dr_wav_seek_origin_start, - ma_dr_wav_seek_origin_current + MA_DR_WAV_SEEK_SET, + MA_DR_WAV_SEEK_CUR, + MA_DR_WAV_SEEK_END } ma_dr_wav_seek_origin; typedef enum { @@ -60388,6 +61976,7 @@ MA_API ma_uint16 ma_dr_wav_fmt_get_format(const ma_dr_wav_fmt* pFMT); typedef size_t (* ma_dr_wav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef size_t (* ma_dr_wav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite); typedef ma_bool32 (* ma_dr_wav_seek_proc)(void* pUserData, int offset, ma_dr_wav_seek_origin origin); +typedef ma_bool32 (* ma_dr_wav_tell_proc)(void* pUserData, ma_int64* pCursor); typedef ma_uint64 (* ma_dr_wav_chunk_proc)(void* pChunkUserData, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pReadSeekUserData, const ma_dr_wav_chunk_header* pChunkHeader, ma_dr_wav_container container, const ma_dr_wav_fmt* pFMT); typedef struct { @@ -60432,6 +62021,11 @@ typedef enum ma_dr_wav_metadata_type_list_info_genre = 1 << 15, ma_dr_wav_metadata_type_list_info_album = 1 << 16, ma_dr_wav_metadata_type_list_info_tracknumber = 1 << 17, + ma_dr_wav_metadata_type_list_info_location = 1 << 18, + ma_dr_wav_metadata_type_list_info_organization = 1 << 19, + ma_dr_wav_metadata_type_list_info_keywords = 1 << 20, + ma_dr_wav_metadata_type_list_info_medium = 1 << 21, + ma_dr_wav_metadata_type_list_info_description = 1 << 22, ma_dr_wav_metadata_type_list_all_info_strings = ma_dr_wav_metadata_type_list_info_software | ma_dr_wav_metadata_type_list_info_copyright | ma_dr_wav_metadata_type_list_info_title @@ -60440,7 +62034,12 @@ typedef enum | ma_dr_wav_metadata_type_list_info_date | ma_dr_wav_metadata_type_list_info_genre | ma_dr_wav_metadata_type_list_info_album - | ma_dr_wav_metadata_type_list_info_tracknumber, + | ma_dr_wav_metadata_type_list_info_tracknumber + | ma_dr_wav_metadata_type_list_info_location + | ma_dr_wav_metadata_type_list_info_organization + | ma_dr_wav_metadata_type_list_info_keywords + | ma_dr_wav_metadata_type_list_info_medium + | ma_dr_wav_metadata_type_list_info_description, ma_dr_wav_metadata_type_list_all_adtl = ma_dr_wav_metadata_type_list_label | ma_dr_wav_metadata_type_list_note | ma_dr_wav_metadata_type_list_labelled_cue_region, @@ -60457,8 +62056,8 @@ typedef struct { ma_uint32 cuePointId; ma_uint32 type; - ma_uint32 firstSampleByteOffset; - ma_uint32 lastSampleByteOffset; + ma_uint32 firstSampleOffset; + ma_uint32 lastSampleOffset; ma_uint32 sampleFraction; ma_uint32 playCount; } ma_dr_wav_smpl_loop; @@ -60493,7 +62092,7 @@ typedef struct ma_uint8 dataChunkId[4]; ma_uint32 chunkStart; ma_uint32 blockStart; - ma_uint32 sampleByteOffset; + ma_uint32 sampleOffset; } ma_dr_wav_cue_point; typedef struct { @@ -60595,6 +62194,7 @@ typedef struct ma_dr_wav_read_proc onRead; ma_dr_wav_write_proc onWrite; ma_dr_wav_seek_proc onSeek; + ma_dr_wav_tell_proc onTell; void* pUserData; ma_allocation_callbacks allocationCallbacks; ma_dr_wav_container container; @@ -60637,9 +62237,9 @@ typedef struct ma_bool8 isUnsigned; } aiff; } ma_dr_wav; -MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, 
ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, ma_dr_wav_chunk_proc onChunk, void* pReadSeekTellUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_dr_wav_write_proc onWrite, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write_sequential(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_uint64 totalSampleCount, ma_dr_wav_write_proc onWrite, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write_sequential_pcm_frames(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_uint64 totalPCMFrameCount, ma_dr_wav_write_proc onWrite, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); @@ -60711,9 +62311,9 @@ MA_API ma_bool32 ma_dr_wav_init_memory_write(ma_dr_wav* pWav, void** ppData, siz MA_API ma_bool32 ma_dr_wav_init_memory_write_sequential(ma_dr_wav* pWav, void** ppData, size_t* pDataSize, const ma_dr_wav_data_format* pFormat, ma_uint64 totalSampleCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_memory_write_sequential_pcm_frames(ma_dr_wav* pWav, void** ppData, size_t* pDataSize, const ma_dr_wav_data_format* pFormat, ma_uint64 totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_WAV_NO_CONVERSION_API -MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* 
pAllocationCallbacks); +MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_WAV_NO_STDIO MA_API ma_int16* ma_dr_wav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); MA_API float* ma_dr_wav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); @@ -60744,6 +62344,8 @@ MA_API ma_bool32 ma_dr_wav_fourcc_equal(const ma_uint8* a, const char* b); #endif /* MA_NO_WAV */ #if !defined(MA_NO_FLAC) && !defined(MA_NO_DECODING) +#define MA_HAS_FLAC + /* dr_flac_h begin */ #ifndef ma_dr_flac_h #define ma_dr_flac_h @@ -60753,8 +62355,8 @@ extern "C" { #define MA_DR_FLAC_STRINGIFY(x) #x #define MA_DR_FLAC_XSTRINGIFY(x) MA_DR_FLAC_STRINGIFY(x) #define MA_DR_FLAC_VERSION_MAJOR 0 -#define MA_DR_FLAC_VERSION_MINOR 12 -#define MA_DR_FLAC_VERSION_REVISION 43 +#define MA_DR_FLAC_VERSION_MINOR 13 +#define MA_DR_FLAC_VERSION_REVISION 3 #define MA_DR_FLAC_VERSION_STRING MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MAJOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MINOR) "." 
MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_REVISION) #include #if defined(_MSC_VER) && _MSC_VER >= 1700 @@ -60817,8 +62419,9 @@ typedef enum } ma_dr_flac_container; typedef enum { - ma_dr_flac_seek_origin_start, - ma_dr_flac_seek_origin_current + MA_DR_FLAC_SEEK_SET, + MA_DR_FLAC_SEEK_CUR, + MA_DR_FLAC_SEEK_END } ma_dr_flac_seek_origin; typedef struct { @@ -60841,8 +62444,9 @@ typedef struct typedef struct { ma_uint32 type; - const void* pRawData; ma_uint32 rawDataSize; + ma_uint64 rawDataOffset; + const void* pRawData; union { ma_dr_flac_streaminfo streaminfo; @@ -60888,12 +62492,14 @@ typedef struct ma_uint32 colorDepth; ma_uint32 indexColorCount; ma_uint32 pictureDataSize; + ma_uint64 pictureDataOffset; const ma_uint8* pPictureData; } picture; } data; } ma_dr_flac_metadata; typedef size_t (* ma_dr_flac_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef ma_bool32 (* ma_dr_flac_seek_proc)(void* pUserData, int offset, ma_dr_flac_seek_origin origin); +typedef ma_bool32 (* ma_dr_flac_tell_proc)(void* pUserData, ma_int64* pCursor); typedef void (* ma_dr_flac_meta_proc)(void* pUserData, ma_dr_flac_metadata* pMetadata); typedef struct { @@ -60905,6 +62511,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; void* pUserData; size_t unalignedByteCount; ma_dr_flac_cache_t unalignedCache; @@ -60964,10 +62571,10 @@ typedef struct ma_dr_flac_bs bs; ma_uint8 pExtraData[1]; } ma_dr_flac; -MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API void ma_dr_flac_close(ma_dr_flac* pFlac); MA_API ma_uint64 ma_dr_flac_read_pcm_frames_s32(ma_dr_flac* pFlac, ma_uint64 framesToRead, ma_int32* pBufferOut); MA_API ma_uint64 ma_dr_flac_read_pcm_frames_s16(ma_dr_flac* pFlac, ma_uint64 framesToRead, ma_int16* pBufferOut); @@ -60981,9 +62588,9 @@ MA_API 
ma_dr_flac* ma_dr_flac_open_file_with_metadata_w(const wchar_t* pFileName #endif MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_t dataSize, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_FLAC_NO_STDIO MA_API ma_int32* ma_dr_flac_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_int16* ma_dr_flac_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); @@ -61031,6 +62638,14 @@ MA_API ma_bool32 ma_dr_flac_next_cuesheet_track(ma_dr_flac_cuesheet_track_iterat #endif /* MA_NO_FLAC */ #if !defined(MA_NO_MP3) && !defined(MA_NO_DECODING) +#define MA_HAS_MP3 + +#ifndef MA_DR_MP3_NO_SIMD + #if (defined(MA_NO_NEON) && defined(MA_ARM)) || (defined(MA_NO_SSE2) && (defined(MA_X86) || defined(MA_X64))) + #define MA_DR_MP3_NO_SIMD + #endif +#endif + /* dr_mp3_h begin */ #ifndef ma_dr_mp3_h #define ma_dr_mp3_h @@ -61040,31 +62655,57 @@ extern "C" { #define MA_DR_MP3_STRINGIFY(x) #x #define MA_DR_MP3_XSTRINGIFY(x) MA_DR_MP3_STRINGIFY(x) #define MA_DR_MP3_VERSION_MAJOR 0 -#define MA_DR_MP3_VERSION_MINOR 6 -#define MA_DR_MP3_VERSION_REVISION 40 +#define MA_DR_MP3_VERSION_MINOR 7 +#define MA_DR_MP3_VERSION_REVISION 3 #define MA_DR_MP3_VERSION_STRING MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MAJOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MINOR) "." 
MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_REVISION) #include #define MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME 1152 #define MA_DR_MP3_MAX_SAMPLES_PER_FRAME (MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME*2) MA_API void ma_dr_mp3_version(ma_uint32* pMajor, ma_uint32* pMinor, ma_uint32* pRevision); MA_API const char* ma_dr_mp3_version_string(void); +#define MA_DR_MP3_MAX_BITRESERVOIR_BYTES 511 +#define MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE 2304 +#define MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE typedef struct { - int frame_bytes, channels, hz, layer, bitrate_kbps; + int frame_bytes, channels, sample_rate, layer, bitrate_kbps; } ma_dr_mp3dec_frame_info; typedef struct +{ + const ma_uint8 *buf; + int pos, limit; +} ma_dr_mp3_bs; +typedef struct +{ + const ma_uint8 *sfbtab; + ma_uint16 part_23_length, big_values, scalefac_compress; + ma_uint8 global_gain, block_type, mixed_block_flag, n_long_sfb, n_short_sfb; + ma_uint8 table_select[3], region_count[3], subblock_gain[3]; + ma_uint8 preflag, scalefac_scale, count1_table, scfsi; +} ma_dr_mp3_L3_gr_info; +typedef struct +{ + ma_dr_mp3_bs bs; + ma_uint8 maindata[MA_DR_MP3_MAX_BITRESERVOIR_BYTES + MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES]; + ma_dr_mp3_L3_gr_info gr_info[4]; + float grbuf[2][576], scf[40], syn[18 + 15][2*32]; + ma_uint8 ist_pos[2][39]; +} ma_dr_mp3dec_scratch; +typedef struct { float mdct_overlap[2][9*32], qmf_state[15*2*32]; int reserv, free_format_bytes; ma_uint8 header[4], reserv_buf[511]; + ma_dr_mp3dec_scratch scratch; } ma_dr_mp3dec; MA_API void ma_dr_mp3dec_init(ma_dr_mp3dec *dec); MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int mp3_bytes, void *pcm, ma_dr_mp3dec_frame_info *info); MA_API void ma_dr_mp3dec_f32_to_s16(const float *in, ma_int16 *out, size_t num_samples); typedef enum { - ma_dr_mp3_seek_origin_start, - ma_dr_mp3_seek_origin_current + MA_DR_MP3_SEEK_SET, + MA_DR_MP3_SEEK_CUR, + MA_DR_MP3_SEEK_END } ma_dr_mp3_seek_origin; typedef struct { @@ -61073,8 +62714,24 @@ typedef struct ma_uint16 mp3FramesToDiscard; ma_uint16 pcmFramesToDiscard; } ma_dr_mp3_seek_point; +typedef enum +{ + MA_DR_MP3_METADATA_TYPE_ID3V1, + MA_DR_MP3_METADATA_TYPE_ID3V2, + MA_DR_MP3_METADATA_TYPE_APE, + MA_DR_MP3_METADATA_TYPE_XING, + MA_DR_MP3_METADATA_TYPE_VBRI +} ma_dr_mp3_metadata_type; +typedef struct +{ + ma_dr_mp3_metadata_type type; + const void* pRawData; + size_t rawDataSize; +} ma_dr_mp3_metadata; typedef size_t (* ma_dr_mp3_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef ma_bool32 (* ma_dr_mp3_seek_proc)(void* pUserData, int offset, ma_dr_mp3_seek_origin origin); +typedef ma_bool32 (* ma_dr_mp3_tell_proc)(void* pUserData, ma_int64* pCursor); +typedef void (* ma_dr_mp3_meta_proc)(void* pUserData, const ma_dr_mp3_metadata* pMetadata); typedef struct { ma_uint32 channels; @@ -61087,7 +62744,9 @@ typedef struct ma_uint32 sampleRate; ma_dr_mp3_read_proc onRead; ma_dr_mp3_seek_proc onSeek; + ma_dr_mp3_meta_proc onMeta; void* pUserData; + void* pUserDataMeta; ma_allocation_callbacks allocationCallbacks; ma_uint32 mp3FrameChannels; ma_uint32 mp3FrameSampleRate; @@ -61096,13 +62755,20 @@ typedef struct ma_uint8 pcmFrames[sizeof(float)*MA_DR_MP3_MAX_SAMPLES_PER_FRAME]; ma_uint64 currentPCMFrame; ma_uint64 streamCursor; + ma_uint64 streamLength; + ma_uint64 streamStartOffset; ma_dr_mp3_seek_point* pSeekPoints; ma_uint32 seekPointCount; + ma_uint32 delayInPCMFrames; + ma_uint32 paddingInPCMFrames; + ma_uint64 totalPCMFrameCount; + ma_bool32 isVBR; + ma_bool32 isCBR; size_t 
dataSize; size_t dataCapacity; size_t dataConsumed; ma_uint8* pData; - ma_bool32 atEnd : 1; + ma_bool32 atEnd; struct { const ma_uint8* pData; @@ -61110,9 +62776,12 @@ typedef struct size_t currentReadPos; } memory; } ma_dr_mp3; -MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init_memory_with_metadata(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_MP3_NO_STDIO +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata(ma_dr_mp3* pMP3, const char* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks); #endif @@ -61125,8 +62794,8 @@ MA_API ma_uint64 ma_dr_mp3_get_mp3_frame_count(ma_dr_mp3* pMP3); MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint64* pMP3FrameCount, ma_uint64* pPCMFrameCount); MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSeekPointCount, ma_dr_mp3_seek_point* pSeekPoints); MA_API ma_bool32 ma_dr_mp3_bind_seek_table(ma_dr_mp3* pMP3, ma_uint32 seekPointCount, ma_dr_mp3_seek_point* pSeekPoints); -MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API float* ma_dr_mp3_open_memory_and_read_pcm_frames_f32(const void* pData, size_t dataSize, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_int16* ma_dr_mp3_open_memory_and_read_pcm_frames_s16(const void* pData, size_t dataSize, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const 
ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_MP3_NO_STDIO @@ -61591,7 +63260,6 @@ static ma_result ma_decoder_init_custom_from_memory__internal(const void* pData, /* WAV */ #ifdef ma_dr_wav_h -#define MA_HAS_WAV typedef struct { @@ -61679,8 +63347,10 @@ static ma_bool32 ma_wav_dr_callback__seek(void* pUserData, int offset, ma_dr_wav MA_ASSERT(pWav != NULL); maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_wav_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_WAV_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_WAV_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; } result = pWav->onSeek(pWav->pReadSeekTellUserData, offset, maSeekOrigin); @@ -61690,6 +63360,26 @@ static ma_bool32 ma_wav_dr_callback__seek(void* pUserData, int offset, ma_dr_wav return MA_TRUE; } + +static ma_bool32 ma_wav_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_wav* pWav = (ma_wav*)pUserData; + ma_result result; + + MA_ASSERT(pWav != NULL); + MA_ASSERT(pCursor != NULL); + + if (pWav->onTell == NULL) { + return MA_FALSE; /* Not implemented. */ + } + + result = pWav->onTell(pWav->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; /* Failed to tell. */ + } + + return MA_TRUE; +} #endif static ma_result ma_wav_init_internal(const ma_decoding_backend_config* pConfig, ma_wav* pWav) @@ -61784,7 +63474,7 @@ MA_API ma_result ma_wav_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_p { ma_bool32 wavResult; - wavResult = ma_dr_wav_init(&pWav->dr, ma_wav_dr_callback__read, ma_wav_dr_callback__seek, pWav, pAllocationCallbacks); + wavResult = ma_dr_wav_init(&pWav->dr, ma_wav_dr_callback__read, ma_wav_dr_callback__seek, ma_wav_dr_callback__tell, pWav, pAllocationCallbacks); if (wavResult != MA_TRUE) { return MA_INVALID_FILE; } @@ -62275,7 +63965,6 @@ static ma_result ma_decoder_init_wav_from_memory__internal(const void* pData, si /* FLAC */ #ifdef ma_dr_flac_h -#define MA_HAS_FLAC typedef struct { @@ -62363,8 +64052,10 @@ static ma_bool32 ma_flac_dr_callback__seek(void* pUserData, int offset, ma_dr_fl MA_ASSERT(pFlac != NULL); maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_flac_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_FLAC_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_FLAC_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; } result = pFlac->onSeek(pFlac->pReadSeekTellUserData, offset, maSeekOrigin); @@ -62374,6 +64065,26 @@ static ma_bool32 ma_flac_dr_callback__seek(void* pUserData, int offset, ma_dr_fl return MA_TRUE; } + +static ma_bool32 ma_flac_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_flac* pFlac = (ma_flac*)pUserData; + ma_result result; + + MA_ASSERT(pFlac != NULL); + MA_ASSERT(pCursor != NULL); + + if (pFlac->onTell == NULL) { + return MA_FALSE; /* Not implemented. */ + } + + result = pFlac->onTell(pFlac->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; /* Failed to tell. 
*/ + } + + return MA_TRUE; +} #endif static ma_result ma_flac_init_internal(const ma_decoding_backend_config* pConfig, ma_flac* pFlac) @@ -62425,7 +64136,7 @@ MA_API ma_result ma_flac_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_ #if !defined(MA_NO_FLAC) { - pFlac->dr = ma_dr_flac_open(ma_flac_dr_callback__read, ma_flac_dr_callback__seek, pFlac, pAllocationCallbacks); + pFlac->dr = ma_dr_flac_open(ma_flac_dr_callback__read, ma_flac_dr_callback__seek, ma_flac_dr_callback__tell, pFlac, pAllocationCallbacks); if (pFlac->dr == NULL) { return MA_INVALID_FILE; } @@ -62897,7 +64608,6 @@ static ma_result ma_decoder_init_flac_from_memory__internal(const void* pData, s /* MP3 */ #ifdef ma_dr_mp3_h -#define MA_HAS_MP3 typedef struct { @@ -62986,9 +64696,12 @@ static ma_bool32 ma_mp3_dr_callback__seek(void* pUserData, int offset, ma_dr_mp3 MA_ASSERT(pMP3 != NULL); - maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_mp3_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_MP3_SEEK_SET) { + maSeekOrigin = ma_seek_origin_start; + } else if (origin == MA_DR_MP3_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; + } else { + maSeekOrigin = ma_seek_origin_current; } result = pMP3->onSeek(pMP3->pReadSeekTellUserData, offset, maSeekOrigin); @@ -62998,6 +64711,21 @@ static ma_bool32 ma_mp3_dr_callback__seek(void* pUserData, int offset, ma_dr_mp3 return MA_TRUE; } + +static ma_bool32 ma_mp3_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_mp3* pMP3 = (ma_mp3*)pUserData; + ma_result result; + + MA_ASSERT(pMP3 != NULL); + + result = pMP3->onTell(pMP3->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; + } + + return MA_TRUE; +} #endif static ma_result ma_mp3_init_internal(const ma_decoding_backend_config* pConfig, ma_mp3* pMP3) @@ -63098,7 +64826,7 @@ MA_API ma_result ma_mp3_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_p { ma_bool32 mp3Result; - mp3Result = ma_dr_mp3_init(&pMP3->dr, ma_mp3_dr_callback__read, ma_mp3_dr_callback__seek, pMP3, pAllocationCallbacks); + mp3Result = ma_dr_mp3_init(&pMP3->dr, ma_mp3_dr_callback__read, ma_mp3_dr_callback__seek, ma_mp3_dr_callback__tell, NULL, pMP3, pAllocationCallbacks); if (mp3Result != MA_TRUE) { return MA_INVALID_FILE; } @@ -64557,11 +66285,9 @@ static ma_result ma_decoder_init__internal(ma_decoder_read_proc onRead, ma_decod We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(pConfig, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(pConfig, pDecoder); - if (result != MA_SUCCESS) { - onSeek(pDecoder, 0, ma_seek_origin_start); - } + onSeek(pDecoder, 0, ma_seek_origin_start); } /* @@ -64825,14 +66551,6 @@ MA_API ma_result ma_decoder_init_memory(const void* pData, size_t dataSize, cons /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. 
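dr_mp3 now accepts an optional onMeta callback; the decoder wrapper above passes NULL for it. A hypothetical handler matching the new ma_dr_mp3_meta_proc signature, just to show the shape of what gets reported (types and constants are the ones declared in the header section above):

    #include <stdio.h>

    static void on_mp3_metadata(void* pUserData, const ma_dr_mp3_metadata* pMetadata)
    {
        (void)pUserData;

        if (pMetadata->type == MA_DR_MP3_METADATA_TYPE_ID3V2) {
            printf("ID3v2 tag: %u bytes\n", (unsigned int)pMetadata->rawDataSize);
        } else if (pMetadata->type == MA_DR_MP3_METADATA_TYPE_XING) {
            printf("Xing/Info header present\n");
        }
    }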
- */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -64997,14 +66715,16 @@ static ma_bool32 ma_path_extension_equal_w(const wchar_t* path, const wchar_t* e ext1 = extension; ext2 = ma_path_extension_w(path); -#if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__) - return _wcsicmp(ext1, ext2) == 0; -#else - /* - I'm not aware of a wide character version of strcasecmp(). I'm therefore converting the extensions to multibyte strings and comparing those. This - isn't the most efficient way to do it, but it should work OK. - */ + #if (defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__)) && !defined(MA_XBOX_NXDK) { + return _wcsicmp(ext1, ext2) == 0; + } + #elif !defined(MA_XBOX_NXDK) && !defined(MA_DOS) + { + /* + I'm not aware of a wide character version of strcasecmp(). I'm therefore converting the extensions to multibyte strings and comparing those. This + isn't the most efficient way to do it, but it should work OK. + */ char ext1MB[4096]; char ext2MB[4096]; const wchar_t* pext1 = ext1; @@ -65024,7 +66744,13 @@ static ma_bool32 ma_path_extension_equal_w(const wchar_t* path, const wchar_t* e return strcasecmp(ext1MB, ext2MB) == 0; } -#endif + #else + { + /* Getting here means we don't have a way to do a case-sensitive comparison for wide strings. Fall back to a simple case-sensitive comparison. */ + /* TODO: Implement our own wchar_t-to-char conversion routine and then use the char* version for comparing. */ + return ma_wcscmp(ext1, ext2) == 0; + } + #endif } #endif /* MA_HAS_PATH_API */ @@ -65125,11 +66851,9 @@ MA_API ma_result ma_decoder_init_vfs(ma_vfs* pVFS, const char* pFilePath, const We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(&config, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(&config, pDecoder); - if (result != MA_SUCCESS) { - ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); - } + ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); } /* @@ -65258,11 +66982,9 @@ MA_API ma_result ma_decoder_init_vfs_w(ma_vfs* pVFS, const wchar_t* pFilePath, c We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(&config, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(&config, pDecoder); - if (result != MA_SUCCESS) { - ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); - } + ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); } /* @@ -65444,14 +67166,6 @@ MA_API ma_result ma_decoder_init_file(const char* pFilePath, const ma_decoder_co /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. 
- */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -65594,14 +67308,6 @@ MA_API ma_result ma_decoder_init_file_w(const wchar_t* pFilePath, const ma_decod /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -66119,10 +67825,18 @@ static ma_bool32 ma_encoder__internal_on_seek_wav(void* pUserData, int offset, m { ma_encoder* pEncoder = (ma_encoder*)pUserData; ma_result result; + ma_seek_origin maSeekOrigin; MA_ASSERT(pEncoder != NULL); - result = pEncoder->onSeek(pEncoder, offset, (origin == ma_dr_wav_seek_origin_start) ? ma_seek_origin_start : ma_seek_origin_current); + maSeekOrigin = ma_seek_origin_start; + if (origin == MA_DR_WAV_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_WAV_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; + } + + result = pEncoder->onSeek(pEncoder, offset, maSeekOrigin); if (result != MA_SUCCESS) { return MA_FALSE; } else { @@ -67644,7 +69358,7 @@ static MA_INLINE ma_uint32 ma_hash_getblock(const ma_uint32* blocks, int i) ma_uint32 block; /* Try silencing a sanitization warning about unaligned access by doing a memcpy() instead of assignment. */ - MA_COPY_MEMORY(&block, ma_offset_ptr(blocks, i * sizeof(block)), sizeof(block)); + MA_COPY_MEMORY(&block, ma_offset_ptr(blocks, i * (int) sizeof(block)), sizeof(block)); if (ma_is_little_endian()) { return block; @@ -67720,7 +69434,7 @@ static ma_uint32 ma_hash_string_32(const char* str) static ma_uint32 ma_hash_string_w_32(const wchar_t* str) { - return ma_hash_32(str, (int)wcslen(str) * sizeof(*str), MA_DEFAULT_HASH_SEED); + return ma_hash_32(str, (int)ma_wcslen(str) * sizeof(*str), MA_DEFAULT_HASH_SEED); } @@ -67880,6 +69594,7 @@ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_ return ma_resource_manager_data_buffer_node_find_min(pDataBufferNode->pChildHi); } +#if 0 /* Currently unused, but might make use of this later. */ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_buffer_node_find_inorder_predecessor(ma_resource_manager_data_buffer_node* pDataBufferNode) { MA_ASSERT(pDataBufferNode != NULL); @@ -67887,6 +69602,7 @@ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_ return ma_resource_manager_data_buffer_node_find_max(pDataBufferNode->pChildLo); } +#endif static ma_result ma_resource_manager_data_buffer_node_remove(ma_resource_manager* pResourceManager, ma_resource_manager_data_buffer_node* pDataBufferNode) { @@ -68237,6 +69953,7 @@ MA_API ma_resource_manager_config ma_resource_manager_config_init(void) config.decodedSampleRate = 0; config.jobThreadCount = 1; /* A single miniaudio-managed job thread by default. 
*/ config.jobQueueCapacity = MA_JOB_TYPE_RESOURCE_MANAGER_QUEUE_CAPACITY; + config.resampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); /* Format/channels/rate doesn't matter here. */ /* Flags. */ config.flags = 0; @@ -68490,6 +70207,7 @@ static ma_decoder_config ma_resource_manager__init_decoder_config(ma_resource_ma config.ppCustomBackendVTables = pResourceManager->config.ppCustomDecodingBackendVTables; config.customBackendCount = pResourceManager->config.customDecodingBackendCount; config.pCustomBackendUserData = pResourceManager->config.pCustomDecodingBackendUserData; + config.resampling = pResourceManager->config.resampling; return config; } @@ -69009,16 +70727,19 @@ static ma_result ma_resource_manager_data_buffer_node_acquire_critical_section(m /* Failed to post job. Probably ran out of memory. */ ma_log_postf(ma_resource_manager_get_log(pResourceManager), MA_LOG_LEVEL_ERROR, "Failed to post MA_JOB_TYPE_RESOURCE_MANAGER_LOAD_DATA_BUFFER_NODE job. %s.\n", ma_result_description(result)); - /* - Fences were acquired before posting the job, but since the job was not able to - be posted, we need to make sure we release them so nothing gets stuck waiting. - */ - if (pInitFence != NULL) { ma_fence_release(pInitFence); } - if (pDoneFence != NULL) { ma_fence_release(pDoneFence); } - if ((flags & MA_RESOURCE_MANAGER_DATA_SOURCE_FLAG_WAIT_INIT) != 0) { ma_resource_manager_inline_notification_uninit(pInitNotification); } else { + /* + Fences were acquired before posting the job, but since the job was not able to + be posted, we need to make sure we release them so nothing gets stuck waiting. + + In the WAIT_INIT case, these will have already been released in ma_job_process() + so we should only release fences in this branch. + */ + if (pInitFence != NULL) { ma_fence_release(pInitFence); } + if (pDoneFence != NULL) { ma_fence_release(pDoneFence); } + /* These will have been freed by the job thread, but with WAIT_INIT they will already have happened since the job has already been handled. */ ma_free(pFilePathCopy, &pResourceManager->config.allocationCallbacks); ma_free(pFilePathWCopy, &pResourceManager->config.allocationCallbacks); @@ -69812,13 +71533,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_data_format(ma_resource_man MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pCursor) { - /* We cannot be using the data source after it's been uninitialized. */ - MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); - if (pDataBuffer == NULL || pCursor == NULL) { return MA_INVALID_ARGS; } + /* We cannot be using the data source after it's been uninitialized. */ + MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); + *pCursor = 0; switch (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode)) @@ -69852,13 +71573,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_res MA_API ma_result ma_resource_manager_data_buffer_get_length_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pLength) { - /* We cannot be using the data source after it's been uninitialized. */ - MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); - if (pDataBuffer == NULL || pLength == NULL) { return MA_INVALID_ARGS; } + /* We cannot be using the data source after it's been uninitialized. 
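The resource manager config above gains a resampling member which is copied into every decoder it creates. A rough sketch of overriding it when setting up a resource manager (the resampling member is as introduced in this patch; everything else is the existing API):

    ma_resource_manager_config rmConfig = ma_resource_manager_config_init();
    rmConfig.decodedFormat     = ma_format_f32;
    rmConfig.decodedSampleRate = 48000;

    /* New in this patch: pick the resampler used when decoding at a different rate.
       Format/channels/rates are filled in later by the decoder, so they can stay zero here. */
    rmConfig.resampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear);
    rmConfig.resampling.linear.lpfOrder = 4;   /* example value */

    ma_resource_manager rm;
    if (ma_resource_manager_init(&rmConfig, &rm) != MA_SUCCESS) {
        /* handle error */
    }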
*/ + MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); + if (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode) == ma_resource_manager_data_supply_type_unknown) { return MA_BUSY; /* Still loading. */ } @@ -71213,8 +72934,6 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job* return ma_resource_manager_post_job(pResourceManager, pJob); /* Out of order. */ } - ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode); - /* The event needs to be signalled last. */ if (pJob->data.resourceManager.freeDataBufferNode.pDoneNotification != NULL) { ma_async_notification_signal(pJob->data.resourceManager.freeDataBufferNode.pDoneNotification); @@ -71225,6 +72944,9 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job* } ma_atomic_fetch_add_32(&pDataBufferNode->executionPointer, 1); + + ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode); + return MA_SUCCESS; } @@ -72097,6 +73819,15 @@ MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 glo return ma_node_set_time(&pNodeGraph->endpoint, globalTime); /* Global time is just the local time of the endpoint. */ } +MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph) +{ + if (pNodeGraph == NULL) { + return 0; + } + + return pNodeGraph->processingSizeInFrames; +} + #define MA_NODE_OUTPUT_BUS_FLAG_HAS_READ 0x01 /* Whether or not this bus ready to read more data. Only used on nodes with multiple output buses. */ @@ -73256,12 +74987,12 @@ MA_API ma_node_state ma_node_get_state_by_time_range(const ma_node* pNode, ma_ui its start time not having been reached yet. Also, the stop time may have also been reached in which case it'll be considered stopped. */ - if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeBeg) { - return ma_node_state_stopped; /* Start time has not yet been reached. */ + if (ma_node_get_state_time(pNode, ma_node_state_stopped) < globalTimeBeg) { + return ma_node_state_stopped; /* End time is before the start of the range. */ } - if (ma_node_get_state_time(pNode, ma_node_state_stopped) <= globalTimeEnd) { - return ma_node_state_stopped; /* Stop time has been reached. */ + if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeEnd) { + return ma_node_state_stopped; /* Start time is after the end of the range. */ } /* Getting here means the node is marked as started and is within its start/stop times. */ @@ -73341,14 +75072,14 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde return MA_INVALID_ARGS; /* Invalid output bus index. */ } + globalTimeBeg = globalTime; + globalTimeEnd = globalTime + frameCount; + /* Don't do anything if we're in a stopped state. */ - if (ma_node_get_state_by_time_range(pNode, globalTime, globalTime + frameCount) != ma_node_state_started) { + if (ma_node_get_state_by_time_range(pNode, globalTimeBeg, globalTimeEnd) != ma_node_state_started) { return MA_SUCCESS; /* We're in a stopped state. This is not an error - we just need to not read anything. */ } - - globalTimeBeg = globalTime; - globalTimeEnd = globalTime + frameCount; startTime = ma_node_get_state_time(pNode, ma_node_state_started); stopTime = ma_node_get_state_time(pNode, ma_node_state_stopped); @@ -73361,11 +75092,16 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde therefore need to offset it by a number of frames to accommodate. 
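ma_node_graph_get_processing_size_in_frames() above is a small new accessor that later hunks use to size the per-sound processing cache. A usage sketch, assuming an already-initialized ma_engine named engine:

    ma_uint32 processingSizeInFrames = ma_node_graph_get_processing_size_in_frames(ma_engine_get_node_graph(&engine));
    if (processingSizeInFrames == 0) {
        processingSizeInFrames = 512;   /* same fallback the sound-initialization hunk below uses when no fixed processing size is configured */
    }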
The same thing applies for the stop time. */ - timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(globalTimeEnd - startTime) : 0; + timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(startTime - globalTimeBeg) : 0; timeOffsetEnd = (globalTimeEnd > stopTime) ? (ma_uint32)(globalTimeEnd - stopTime) : 0; /* Trim based on the start offset. We need to silence the start of the buffer. */ if (timeOffsetBeg > 0) { + MA_ASSERT(timeOffsetBeg <= frameCount); + if (timeOffsetBeg > frameCount) { + timeOffsetBeg = frameCount; + } + ma_silence_pcm_frames(pFramesOut, timeOffsetBeg, ma_format_f32, ma_node_get_output_channels(pNode, outputBusIndex)); pFramesOut += timeOffsetBeg * ma_node_get_output_channels(pNode, outputBusIndex); frameCount -= timeOffsetBeg; @@ -73373,6 +75109,11 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde /* Trim based on the end offset. We don't need to silence the tail section because we'll just have a reduced value written to pFramesRead. */ if (timeOffsetEnd > 0) { + MA_ASSERT(timeOffsetEnd <= frameCount); + if (timeOffsetEnd > frameCount) { + timeOffsetEnd = frameCount; + } + frameCount -= timeOffsetEnd; } @@ -74787,12 +76528,20 @@ static void ma_sound_set_at_end(ma_sound* pSound, ma_bool32 atEnd) MA_ASSERT(pSound != NULL); ma_atomic_exchange_32(&pSound->atEnd, atEnd); + /* + When this function is called the state of the sound will not yet be in a stopped state. This makes it confusing + because an end callback will intuitively expect ma_sound_is_playing() to return false from inside the callback. + I'm therefore no longer firing the callback here and will instead fire it manually in the *next* processing step + when the state should be set to stopped as expected. + */ + #if 0 /* Fire any callbacks or events. */ if (atEnd) { if (pSound->endCallback != NULL) { pSound->endCallback(pSound->pEndCallbackUserData, pSound); } } + #endif } static ma_bool32 ma_sound_get_at_end(const ma_sound* pSound) @@ -74812,6 +76561,7 @@ MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_e config.isPitchDisabled = (flags & MA_SOUND_FLAG_NO_PITCH) != 0; config.isSpatializationDisabled = (flags & MA_SOUND_FLAG_NO_SPATIALIZATION) != 0; config.monoExpansionMode = pEngine->monoExpansionMode; + config.resampling = pEngine->pitchResamplingConfig; return config; } @@ -74838,7 +76588,7 @@ static void ma_engine_node_update_pitch_if_required(ma_engine_node* pEngineNode) if (isUpdateRequired) { float basePitch = (float)pEngineNode->sampleRate / ma_engine_get_sample_rate(pEngineNode->pEngine); - ma_linear_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch); + ma_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch); } } @@ -74857,22 +76607,6 @@ static ma_bool32 ma_engine_node_is_spatialization_enabled(const ma_engine_node* return !ma_atomic_load_explicit_32(&pEngineNode->isSpatializationDisabled, ma_atomic_memory_order_acquire); } -static ma_uint64 ma_engine_node_get_required_input_frame_count(const ma_engine_node* pEngineNode, ma_uint64 outputFrameCount) -{ - ma_uint64 inputFrameCount = 0; - - if (ma_engine_node_is_pitching_enabled(pEngineNode)) { - ma_result result = ma_linear_resampler_get_required_input_frame_count(&pEngineNode->resampler, outputFrameCount, &inputFrameCount); - if (result != MA_SUCCESS) { - inputFrameCount = 0; - } - } else { - inputFrameCount = outputFrameCount; /* No resampling, so 1:1. 
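To make the start-time trimming fix above concrete: with globalTime = 1000, frameCount = 256 and a node start time of 1100, the corrected expression silences startTime - globalTimeBeg = 100 frames at the head of the output and processes the remaining 156, whereas the old expression (globalTimeEnd - startTime = 156) computed the wrong amount. The clamps added above simply guarantee that neither offset can ever exceed frameCount.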
*/ - } - - return inputFrameCount; -} - static ma_result ma_engine_node_set_volume(ma_engine_node* pEngineNode, float volume) { if (pEngineNode == NULL) { @@ -75014,7 +76748,7 @@ static void ma_engine_node_process_pcm_frames__general(ma_engine_node* pEngineNo ma_uint64 resampleFrameCountIn = framesAvailableIn; ma_uint64 resampleFrameCountOut = framesAvailableOut; - ma_linear_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut); + ma_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut); isWorkingBufferValid = MA_TRUE; framesJustProcessedIn = (ma_uint32)resampleFrameCountIn; @@ -75138,6 +76872,11 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float /* If we're marked at the end we need to stop the sound and do nothing. */ if (ma_sound_at_end(pSound)) { ma_sound_stop(pSound); + + if (pSound->endCallback != NULL) { + pSound->endCallback(pSound->pEndCallbackUserData, pSound); + } + *pFrameCountOut = 0; return; } @@ -75175,55 +76914,74 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float /* Keep reading until we've read as much as was requested or we reach the end of the data source. */ while (totalFramesRead < frameCount) { ma_uint32 framesRemaining = frameCount - totalFramesRead; - ma_uint32 framesToRead; ma_uint64 framesJustRead; ma_uint32 frameCountIn; ma_uint32 frameCountOut; const float* pRunningFramesIn; float* pRunningFramesOut; - /* - The first thing we need to do is read into the temporary buffer. We can calculate exactly - how many input frames we'll need after resampling. - */ - framesToRead = (ma_uint32)ma_engine_node_get_required_input_frame_count(&pSound->engineNode, framesRemaining); - if (framesToRead > tempCapInFrames) { - framesToRead = tempCapInFrames; - } + /* If there's any input frames sitting in the cache get those processed first. */ + if (pSound->processingCacheFramesRemaining > 0) { + pRunningFramesIn = pSound->pProcessingCache; + frameCountIn = pSound->processingCacheFramesRemaining; - result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToRead, &framesJustRead); + pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0)); + frameCountOut = framesRemaining; - /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */ - if (result == MA_AT_END) { - ma_sound_set_at_end(pSound, MA_TRUE); /* This will be set to false in ma_sound_start(). */ - } - - pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0)); - - frameCountIn = (ma_uint32)framesJustRead; - frameCountOut = framesRemaining; - - /* Convert if necessary. */ - if (dataSourceFormat == ma_format_f32) { - /* Fast path. No data conversion necessary. */ - pRunningFramesIn = (float*)temp; ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); + + MA_ASSERT(frameCountIn <= pSound->processingCacheFramesRemaining); + pSound->processingCacheFramesRemaining -= frameCountIn; + + /* Move any remaining data in the cache down. 
*/ + if (pSound->processingCacheFramesRemaining > 0) { + MA_MOVE_MEMORY(pSound->pProcessingCache, ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, frameCountIn, dataSourceChannels), pSound->processingCacheFramesRemaining * ma_get_bytes_per_frame(ma_format_f32, dataSourceChannels)); + } + + totalFramesRead += (ma_uint32)frameCountOut; /* Safe cast. */ + + if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { + break; /* Might have reached the end. */ + } } else { - /* Slow path. Need to do sample format conversion to f32. If we give the f32 buffer the same count as the first temp buffer, we're guaranteed it'll be large enough. */ - float tempf32[MA_DATA_CONVERTER_STACK_BUFFER_SIZE]; /* Do not do `MA_DATA_CONVERTER_STACK_BUFFER_SIZE/sizeof(float)` here like we've done in other places. */ - ma_convert_pcm_frames_format(tempf32, ma_format_f32, temp, dataSourceFormat, framesJustRead, dataSourceChannels, ma_dither_mode_none); + /* Getting here means there's nothing in the cache. Read more data from the data source. */ + if (dataSourceFormat == ma_format_f32) { + /* Fast path. No conversion to f32 necessary. */ + result = ma_data_source_read_pcm_frames(pSound->pDataSource, pSound->pProcessingCache, pSound->processingCacheCap, &framesJustRead); + } else { + /* Slow path. Need to convert to f32. */ + ma_uint64 totalFramesConverted = 0; - /* Now that we have our samples in f32 format we can process like normal. */ - pRunningFramesIn = tempf32; - ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); - } + while (totalFramesConverted < pSound->processingCacheCap) { + ma_uint64 framesConverted; + ma_uint32 framesToConvertThisIteration = pSound->processingCacheCap - (ma_uint32)totalFramesConverted; + if (framesToConvertThisIteration > tempCapInFrames) { + framesToConvertThisIteration = tempCapInFrames; + } - /* We should have processed all of our input frames since we calculated the required number of input frames at the top. */ - MA_ASSERT(frameCountIn == framesJustRead); - totalFramesRead += (ma_uint32)frameCountOut; /* Safe cast. */ + result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToConvertThisIteration, &framesConverted); + if (result != MA_SUCCESS) { + break; + } - if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { - break; /* Might have reached the end. */ + ma_convert_pcm_frames_format(ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, totalFramesConverted, dataSourceChannels), ma_format_f32, temp, dataSourceFormat, framesConverted, dataSourceChannels, ma_dither_mode_none); + totalFramesConverted += framesConverted; + } + + framesJustRead = totalFramesConverted; + } + + MA_ASSERT(framesJustRead <= pSound->processingCacheCap); + pSound->processingCacheFramesRemaining = (ma_uint32)framesJustRead; + + /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */ + if (result == MA_AT_END) { + ma_sound_set_at_end(pSound, MA_TRUE); /* This will be set to false in ma_sound_start(). 
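Because the at-end flag is now only acted on during the next processing step (see the ma_sound_set_at_end comment above), the end callback fires after ma_sound_stop() has run, so the sound reads as stopped from inside the callback. A sketch of registering one, assuming the existing ma_sound_set_end_callback() API (not shown in this patch):

    static void on_sound_end(void* pUserData, ma_sound* pSound)
    {
        (void)pUserData;

        /* With this change the callback runs after the sound has been stopped, so
           ma_sound_is_playing() reports MA_FALSE here, as a caller would expect. */
        if (ma_sound_is_playing(pSound) == MA_FALSE) {
            /* queue the next track, release resources, etc. */
        }
    }

    /* Elsewhere, after the sound has been initialized: */
    ma_sound_set_end_callback(&sound, on_sound_end, NULL);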
*/ + } + + if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { + break; + } } } } @@ -75246,25 +77004,6 @@ static void ma_engine_node_process_pcm_frames__group(ma_node* pNode, const float ma_engine_node_process_pcm_frames__general((ma_engine_node*)pNode, ppFramesIn, pFrameCountIn, ppFramesOut, pFrameCountOut); } -static ma_result ma_engine_node_get_required_input_frame_count__group(ma_node* pNode, ma_uint32 outputFrameCount, ma_uint32* pInputFrameCount) -{ - ma_uint64 inputFrameCount; - - MA_ASSERT(pInputFrameCount != NULL); - - /* Our pitch will affect this calculation. We need to update it. */ - ma_engine_node_update_pitch_if_required((ma_engine_node*)pNode); - - inputFrameCount = ma_engine_node_get_required_input_frame_count((ma_engine_node*)pNode, outputFrameCount); - if (inputFrameCount > 0xFFFFFFFF) { - inputFrameCount = 0xFFFFFFFF; /* Will never happen because miniaudio will only ever process in relatively small chunks. */ - } - - *pInputFrameCount = (ma_uint32)inputFrameCount; - - return MA_SUCCESS; -} - static ma_node_vtable g_ma_engine_node_vtable__sound = { @@ -75278,7 +77017,7 @@ static ma_node_vtable g_ma_engine_node_vtable__sound = static ma_node_vtable g_ma_engine_node_vtable__group = { ma_engine_node_process_pcm_frames__group, - ma_engine_node_get_required_input_frame_count__group, + NULL, /* onGetRequiredInputFrameCount */ 1, /* Groups have one input bus. */ 1, /* Groups have one output bus. */ MA_NODE_FLAG_DIFFERENT_PROCESSING_RATES /* The engine node does resampling so should let miniaudio know about it. */ @@ -75324,9 +77063,10 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo ma_result result; size_t tempHeapSize; ma_node_config baseNodeConfig; - ma_linear_resampler_config resamplerConfig; + ma_resampler_config resamplerConfig; ma_spatializer_config spatializerConfig; ma_gainer_config gainerConfig; + ma_uint32 sampleRate; ma_uint32 channelsIn; ma_uint32 channelsOut; ma_channel defaultStereoChannelMap[2] = {MA_CHANNEL_SIDE_LEFT, MA_CHANNEL_SIDE_RIGHT}; /* <-- Consistent with the default channel map of a stereo listener. Means channel conversion can run on a fast path. */ @@ -75345,6 +77085,7 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo pHeapLayout->sizeInBytes = 0; + sampleRate = (pConfig->sampleRate > 0) ? pConfig->sampleRate : ma_engine_get_sample_rate(pConfig->pEngine); channelsIn = (pConfig->channelsIn != 0) ? pConfig->channelsIn : ma_engine_get_channels(pConfig->pEngine); channelsOut = (pConfig->channelsOut != 0) ? pConfig->channelsOut : ma_engine_get_channels(pConfig->pEngine); @@ -75364,10 +77105,13 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo /* Resmapler. */ - resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, channelsIn, 1, 1); /* Input and output sample rates don't affect the calculation of the heap size. */ - resamplerConfig.lpfOrder = 0; + resamplerConfig = pConfig->resampling; + resamplerConfig.format = ma_format_f32; + resamplerConfig.channels = channelsIn; + resamplerConfig.sampleRateIn = sampleRate; + resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pConfig->pEngine); - result = ma_linear_resampler_get_heap_size(&resamplerConfig, &tempHeapSize); + result = ma_resampler_get_heap_size(&resamplerConfig, &tempHeapSize); if (result != MA_SUCCESS) { return result; /* Failed to retrieve the size of the heap for the resampler. 
*/ } @@ -75435,7 +77179,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p ma_result result; ma_engine_node_heap_layout heapLayout; ma_node_config baseNodeConfig; - ma_linear_resampler_config resamplerConfig; + ma_resampler_config resamplerConfig; ma_fader_config faderConfig; ma_spatializer_config spatializerConfig; ma_panner_config pannerConfig; @@ -75510,10 +77254,13 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p */ /* We'll always do resampling first. */ - resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, baseNodeConfig.pInputChannels[0], pEngineNode->sampleRate, ma_engine_get_sample_rate(pEngineNode->pEngine)); - resamplerConfig.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ + resamplerConfig = pConfig->resampling; + resamplerConfig.format = ma_format_f32; + resamplerConfig.channels = baseNodeConfig.pInputChannels[0]; + resamplerConfig.sampleRateIn = pEngineNode->sampleRate; + resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pEngineNode->pEngine); - result = ma_linear_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler); + result = ma_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler); if (result != MA_SUCCESS) { goto error1; } @@ -75572,7 +77319,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p /* No need for allocation callbacks here because we use a preallocated heap. */ error3: ma_spatializer_uninit(&pEngineNode->spatializer, NULL); -error2: ma_linear_resampler_uninit(&pEngineNode->resampler, NULL); +error2: ma_resampler_uninit(&pEngineNode->resampler, NULL); error1: ma_node_uninit(&pEngineNode->baseNode, NULL); error0: return result; } @@ -75621,7 +77368,7 @@ MA_API void ma_engine_node_uninit(ma_engine_node* pEngineNode, const ma_allocati } ma_spatializer_uninit(&pEngineNode->spatializer, pAllocationCallbacks); - ma_linear_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks); + ma_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks); /* Free the heap last. */ if (pEngineNode->_ownsHeap) { @@ -75643,8 +77390,12 @@ MA_API ma_sound_config ma_sound_config_init_2(ma_engine* pEngine) if (pEngine != NULL) { config.monoExpansionMode = pEngine->monoExpansionMode; + config.pitchResampling = pEngine->pitchResamplingConfig; } else { config.monoExpansionMode = ma_mono_expansion_mode_default; + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. 
*/ } config.rangeEndInPCMFrames = ~((ma_uint64)0); @@ -75666,8 +77417,12 @@ MA_API ma_sound_group_config ma_sound_group_config_init_2(ma_engine* pEngine) if (pEngine != NULL) { config.monoExpansionMode = pEngine->monoExpansionMode; + config.pitchResampling = pEngine->pitchResamplingConfig; } else { config.monoExpansionMode = ma_mono_expansion_mode_default; + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ } return config; @@ -75679,8 +77434,12 @@ MA_API ma_engine_config ma_engine_config_init(void) ma_engine_config config; MA_ZERO_OBJECT(&config); - config.listenerCount = 1; /* Always want at least one listener. */ - config.monoExpansionMode = ma_mono_expansion_mode_default; + config.listenerCount = 1; /* Always want at least one listener. */ + config.monoExpansionMode = ma_mono_expansion_mode_default; + config.resourceManagerResampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ return config; } @@ -75761,6 +77520,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng pEngine->defaultVolumeSmoothTimeInPCMFrames = engineConfig.defaultVolumeSmoothTimeInPCMFrames; pEngine->onProcess = engineConfig.onProcess; pEngine->pProcessUserData = engineConfig.pProcessUserData; + pEngine->pitchResamplingConfig = engineConfig.pitchResampling; ma_allocation_callbacks_init_copy(&pEngine->allocationCallbacks, &engineConfig.allocationCallbacks); #if !defined(MA_NO_RESOURCE_MANAGER) @@ -75943,6 +77703,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng resourceManagerConfig.decodedSampleRate = ma_engine_get_sample_rate(pEngine); ma_allocation_callbacks_init_copy(&resourceManagerConfig.allocationCallbacks, &pEngine->allocationCallbacks); resourceManagerConfig.pVFS = engineConfig.pResourceManagerVFS; + resourceManagerConfig.resampling = engineConfig.resourceManagerResampling; /* The Emscripten build cannot use threads unless it's targeting pthreads. */ #if defined(MA_EMSCRIPTEN) && !defined(__EMSCRIPTEN_PTHREADS__) @@ -76668,13 +78429,32 @@ static ma_result ma_sound_init_from_data_source_internal(ma_engine* pEngine, con } + /* + When pulling data from a data source we need a processing cache to hold onto unprocessed input data from the data source + after doing resampling. 
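ma_engine_config_init() now seeds two resampler configs: resourceManagerResampling, forwarded to the resource manager, and pitchResampling, used by every engine node in place of the previously hard-wired linear resampler. A minimal sketch of overriding them (members as introduced in this patch):

    ma_engine_config engineConfig = ma_engine_config_init();

    /* Resampling used when the resource manager decodes at a rate different from the engine's. */
    engineConfig.resourceManagerResampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear);
    engineConfig.resourceManagerResampling.linear.lpfOrder = 8;   /* example value */

    /* Resampling used for pitch shifting. lpfOrder stays 0 by default for stability, per the comments above. */
    engineConfig.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear);
    engineConfig.pitchResampling.linear.lpfOrder = 0;

    ma_engine engine;
    if (ma_engine_init(&engineConfig, &engine) != MA_SUCCESS) {
        /* handle error */
    }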
+ */ + if (pSound->pDataSource != NULL) { + pSound->processingCacheFramesRemaining = 0; + pSound->processingCacheCap = ma_node_graph_get_processing_size_in_frames(&pEngine->nodeGraph); + if (pSound->processingCacheCap == 0) { + pSound->processingCacheCap = 512; + } + + pSound->pProcessingCache = (float*)ma_calloc(pSound->processingCacheCap * ma_get_bytes_per_frame(ma_format_f32, engineNodeConfig.channelsIn), &pEngine->allocationCallbacks); + if (pSound->pProcessingCache == NULL) { + ma_engine_node_uninit(&pSound->engineNode, &pEngine->allocationCallbacks); + return MA_OUT_OF_MEMORY; + } + } + + /* Apply initial range and looping state to the data source if applicable. */ if (pConfig->rangeBegInPCMFrames != 0 || pConfig->rangeEndInPCMFrames != ~((ma_uint64)0)) { ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->rangeBegInPCMFrames, pConfig->rangeEndInPCMFrames); } if (pConfig->loopPointBegInPCMFrames != 0 || pConfig->loopPointEndInPCMFrames != ~((ma_uint64)0)) { - ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->loopPointBegInPCMFrames, pConfig->loopPointEndInPCMFrames); + ma_data_source_set_loop_point_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->loopPointBegInPCMFrames, pConfig->loopPointEndInPCMFrames); } ma_sound_set_looping(pSound, pConfig->isLooping || ((pConfig->flags & MA_SOUND_FLAG_LOOPING) != 0)); @@ -76736,6 +78516,7 @@ MA_API ma_result ma_sound_init_from_file_internal(ma_engine* pEngine, const ma_s result = ma_resource_manager_data_source_init_ex(pEngine->pResourceManager, &resourceManagerDataSourceConfig, pSound->pResourceManagerDataSource); if (result != MA_SUCCESS) { + ma_free(pSound->pResourceManagerDataSource, &pEngine->allocationCallbacks); goto done; } @@ -76904,6 +78685,11 @@ MA_API void ma_sound_uninit(ma_sound* pSound) */ ma_engine_node_uninit(&pSound->engineNode, &pSound->engineNode.pEngine->allocationCallbacks); + if (pSound->pProcessingCache != NULL) { + ma_free(pSound->pProcessingCache, &pSound->engineNode.pEngine->allocationCallbacks); + pSound->pProcessingCache = NULL; + } + /* Once the sound is detached from the group we can guarantee that it won't be referenced by the mixer thread which means it's safe for us to destroy the data source. 
*/ #ifndef MA_NO_RESOURCE_MANAGER if (pSound->ownsDataSource) { @@ -76999,6 +78785,27 @@ MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_ui return ma_sound_stop_with_fade_in_pcm_frames(pSound, (fadeLengthInMilliseconds * sampleRate) / 1000); } +MA_API void ma_sound_reset_start_time(ma_sound* pSound) +{ + ma_sound_set_start_time_in_pcm_frames(pSound, 0); +} + +MA_API void ma_sound_reset_stop_time(ma_sound* pSound) +{ + ma_sound_set_stop_time_in_pcm_frames(pSound, ~(ma_uint64)0); +} + +MA_API void ma_sound_reset_fade(ma_sound* pSound) +{ + ma_sound_set_fade_in_pcm_frames(pSound, 0, 1, 0); +} + +MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound) +{ + ma_sound_reset_stop_time(pSound); + ma_sound_reset_fade(pSound); +} + MA_API void ma_sound_set_volume(ma_sound* pSound, float volume) { if (pSound == NULL) { @@ -77541,7 +79348,12 @@ MA_API ma_uint64 ma_sound_get_time_in_pcm_frames(const ma_sound* pSound) MA_API ma_uint64 ma_sound_get_time_in_milliseconds(const ma_sound* pSound) { - return ma_sound_get_time_in_pcm_frames(pSound) * 1000 / ma_engine_get_sample_rate(ma_sound_get_engine(pSound)); + ma_uint32 sampleRate = ma_engine_get_sample_rate(ma_sound_get_engine(pSound)); + if (sampleRate == 0) { + return 0; /* Prevent a division by zero. */ + } + + return ma_sound_get_time_in_pcm_frames(pSound) * 1000 / sampleRate; } MA_API void ma_sound_set_looping(ma_sound* pSound, ma_bool32 isLooping) @@ -77625,7 +79437,7 @@ MA_API ma_result ma_sound_seek_to_second(ma_sound* pSound, float seekPointInSeco return ma_sound_seek_to_pcm_frame(pSound, frameIndex); } -MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap) +MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -77645,7 +79457,7 @@ MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, } if (pSampleRate != NULL) { - *pSampleRate = pSound->engineNode.resampler.config.sampleRateIn; + *pSampleRate = pSound->engineNode.resampler.sampleRateIn; } if (pChannelMap != NULL) { @@ -77658,7 +79470,7 @@ MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, } } -MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* pCursor) +MA_API ma_result ma_sound_get_cursor_in_pcm_frames(const ma_sound* pSound, ma_uint64* pCursor) { ma_uint64 seekTarget; @@ -77680,7 +79492,7 @@ MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* } } -MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* pLength) +MA_API ma_result ma_sound_get_length_in_pcm_frames(const ma_sound* pSound, ma_uint64* pLength) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -77694,7 +79506,7 @@ MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* return ma_data_source_get_length_in_pcm_frames(pSound->pDataSource, pLength); } -MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor) +MA_API ma_result ma_sound_get_cursor_in_seconds(const ma_sound* pSound, float* pCursor) { ma_result result; ma_uint64 cursorInPCMFrames; @@ -77720,7 +79532,7 @@ MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor return MA_SUCCESS; } -MA_API ma_result ma_sound_get_length_in_seconds(ma_sound* 
pSound, float* pLength) +MA_API ma_result ma_sound_get_length_in_seconds(const ma_sound* pSound, float* pLength) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -78539,12 +80351,12 @@ MA_PRIVATE ma_bool32 ma_dr_wav__seek_forward(ma_dr_wav_seek_proc onSeek, ma_uint ma_uint64 bytesRemainingToSeek = offset; while (bytesRemainingToSeek > 0) { if (bytesRemainingToSeek > 0x7FFFFFFF) { - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } bytesRemainingToSeek -= 0x7FFFFFFF; } else { - if (!onSeek(pUserData, (int)bytesRemainingToSeek, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, (int)bytesRemainingToSeek, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } bytesRemainingToSeek = 0; @@ -78555,17 +80367,17 @@ MA_PRIVATE ma_bool32 ma_dr_wav__seek_forward(ma_dr_wav_seek_proc onSeek, ma_uint MA_PRIVATE ma_bool32 ma_dr_wav__seek_from_start(ma_dr_wav_seek_proc onSeek, ma_uint64 offset, void* pUserData) { if (offset <= 0x7FFFFFFF) { - return onSeek(pUserData, (int)offset, ma_dr_wav_seek_origin_start); + return onSeek(pUserData, (int)offset, MA_DR_WAV_SEEK_SET); } - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_start)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_SET)) { return MA_FALSE; } offset -= 0x7FFFFFFF; for (;;) { if (offset <= 0x7FFFFFFF) { - return onSeek(pUserData, (int)offset, ma_dr_wav_seek_origin_current); + return onSeek(pUserData, (int)offset, MA_DR_WAV_SEEK_CUR); } - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } offset -= 0x7FFFFFFF; @@ -78588,7 +80400,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek(ma_dr_wav_seek_proc onSeek, void* pUserD if (!onSeek(pUserData, offset, origin)) { return MA_FALSE; } - if (origin == ma_dr_wav_seek_origin_start) { + if (origin == MA_DR_WAV_SEEK_SET) { *pCursor = offset; } else { *pCursor += offset; @@ -78707,12 +80519,12 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa ma_uint8 smplLoopData[MA_DR_WAV_SMPL_LOOP_BYTES]; bytesJustRead = ma_dr_wav__metadata_parser_read(pParser, smplLoopData, sizeof(smplLoopData), &totalBytesRead); if (bytesJustRead == sizeof(smplLoopData)) { - pMetadata->data.smpl.pLoops[iSampleLoop].cuePointId = ma_dr_wav_bytes_to_u32(smplLoopData + 0); - pMetadata->data.smpl.pLoops[iSampleLoop].type = ma_dr_wav_bytes_to_u32(smplLoopData + 4); - pMetadata->data.smpl.pLoops[iSampleLoop].firstSampleByteOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 8); - pMetadata->data.smpl.pLoops[iSampleLoop].lastSampleByteOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 12); - pMetadata->data.smpl.pLoops[iSampleLoop].sampleFraction = ma_dr_wav_bytes_to_u32(smplLoopData + 16); - pMetadata->data.smpl.pLoops[iSampleLoop].playCount = ma_dr_wav_bytes_to_u32(smplLoopData + 20); + pMetadata->data.smpl.pLoops[iSampleLoop].cuePointId = ma_dr_wav_bytes_to_u32(smplLoopData + 0); + pMetadata->data.smpl.pLoops[iSampleLoop].type = ma_dr_wav_bytes_to_u32(smplLoopData + 4); + pMetadata->data.smpl.pLoops[iSampleLoop].firstSampleOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 8); + pMetadata->data.smpl.pLoops[iSampleLoop].lastSampleOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 12); + pMetadata->data.smpl.pLoops[iSampleLoop].sampleFraction = ma_dr_wav_bytes_to_u32(smplLoopData + 16); + pMetadata->data.smpl.pLoops[iSampleLoop].playCount = ma_dr_wav_bytes_to_u32(smplLoopData + 20); } else { break; } @@ -78756,7 +80568,7 @@ 
MA_PRIVATE ma_uint64 ma_dr_wav__read_cue_to_metadata_obj(ma_dr_wav__metadata_par pMetadata->data.cue.pCuePoints[iCuePoint].dataChunkId[3] = cuePointData[11]; pMetadata->data.cue.pCuePoints[iCuePoint].chunkStart = ma_dr_wav_bytes_to_u32(cuePointData + 12); pMetadata->data.cue.pCuePoints[iCuePoint].blockStart = ma_dr_wav_bytes_to_u32(cuePointData + 16); - pMetadata->data.cue.pCuePoints[iCuePoint].sampleByteOffset = ma_dr_wav_bytes_to_u32(cuePointData + 20); + pMetadata->data.cue.pCuePoints[iCuePoint].sampleOffset = ma_dr_wav_bytes_to_u32(cuePointData + 20); } else { break; } @@ -79096,7 +80908,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse if (pParser->stage == ma_dr_wav__metadata_parser_stage_count) { ma_uint8 buffer[4]; size_t bytesJustRead; - if (!pParser->onSeek(pParser->pReadSeekUserData, 28, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, 28, MA_DR_WAV_SEEK_CUR)) { return bytesRead; } bytesRead += 28; @@ -79191,7 +81003,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse return bytesRead; } allocSizeNeeded += ma_dr_wav__strlen(buffer) + 1; - allocSizeNeeded += (size_t)pChunkHeader->sizeInBytes - MA_DR_WAV_BEXT_BYTES; + allocSizeNeeded += (size_t)pChunkHeader->sizeInBytes - MA_DR_WAV_BEXT_BYTES + 1; ma_dr_wav__metadata_request_extra_memory_for_stage_2(pParser, allocSizeNeeded, 1); pParser->metadataCount += 1; } else { @@ -79274,6 +81086,16 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_album); } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_tracknumber, "ITRK")) { subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_tracknumber); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_location, "IARL")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_location); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_organization, "ICMS")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_organization); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_keywords, "IKEY")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_keywords); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_medium, "IMED")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_medium); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_description, "ISBJ")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_description); } else if ((allowedMetadataTypes & ma_dr_wav_metadata_type_unknown) != 0) { subchunkBytesRead = ma_dr_wav__metadata_process_unknown_chunk(pParser, subchunkId, subchunkDataSize, listType); } @@ -79281,13 +81103,13 @@ MA_PRIVATE ma_uint64 
ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse MA_DR_WAV_ASSERT(subchunkBytesRead <= subchunkDataSize); if (subchunkBytesRead < subchunkDataSize) { ma_uint64 bytesToSeek = subchunkDataSize - subchunkBytesRead; - if (!pParser->onSeek(pParser->pReadSeekUserData, (int)bytesToSeek, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, (int)bytesToSeek, MA_DR_WAV_SEEK_CUR)) { break; } bytesRead += bytesToSeek; } if ((subchunkDataSize % 2) == 1) { - if (!pParser->onSeek(pParser->pReadSeekUserData, 1, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, 1, MA_DR_WAV_SEEK_CUR)) { break; } bytesRead += 1; @@ -79324,7 +81146,7 @@ MA_API ma_uint16 ma_dr_wav_fmt_get_format(const ma_dr_wav_fmt* pFMT) return ma_dr_wav_bytes_to_u16(pFMT->subFormat); } } -MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pReadSeekUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pReadSeekTellUserData, const ma_allocation_callbacks* pAllocationCallbacks) { if (pWav == NULL || onRead == NULL || onSeek == NULL) { return MA_FALSE; @@ -79332,7 +81154,8 @@ MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRe MA_DR_WAV_ZERO_MEMORY(pWav, sizeof(*pWav)); pWav->onRead = onRead; pWav->onSeek = onSeek; - pWav->pUserData = pReadSeekUserData; + pWav->onTell = onTell; + pWav->pUserData = pReadSeekTellUserData; pWav->allocationCallbacks = ma_dr_wav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { return MA_FALSE; @@ -79546,14 +81369,14 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p fmt.channelMask = ma_dr_wav_bytes_to_u32_ex(fmtext + 2, pWav->container); ma_dr_wav_bytes_to_guid(fmtext + 6, fmt.subFormat); } else { - if (pWav->onSeek(pWav->pUserData, fmt.extendedSize, ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, fmt.extendedSize, MA_DR_WAV_SEEK_CUR) == MA_FALSE) { return MA_FALSE; } } cursor += fmt.extendedSize; bytesReadSoFar += fmt.extendedSize; } - if (pWav->onSeek(pWav->pUserData, (int)(header.sizeInBytes - bytesReadSoFar), ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, (int)(header.sizeInBytes - bytesReadSoFar), MA_DR_WAV_SEEK_CUR) == MA_FALSE) { return MA_FALSE; } cursor += (header.sizeInBytes - bytesReadSoFar); @@ -79704,15 +81527,26 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p return MA_FALSE; } offset = ma_dr_wav_bytes_to_u32_ex(offsetAndBlockSizeData + 0, pWav->container); - if (ma_dr_wav__seek_forward(pWav->onSeek, offset, pWav->pUserData) == MA_FALSE) { - return MA_FALSE; - } - cursor += offset; - pWav->dataChunkDataPos = cursor; + pWav->dataChunkDataPos = cursor + offset; dataChunkSize = chunkSize; - if (sequential || !isProcessingMetadata) { - break; + if (dataChunkSize > offset) { + dataChunkSize -= offset; } else { + dataChunkSize = 0; + } + if (sequential) { + if (foundChunk_fmt) { + if (ma_dr_wav__seek_forward(pWav->onSeek, offset, pWav->pUserData) == MA_FALSE) { + return MA_FALSE; + } + cursor += offset; + break; + } else { + return MA_FALSE; + } + } else { + chunkSize += header.paddingSize; + 
chunkSize -= sizeof(offsetAndBlockSizeData); if (ma_dr_wav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData) == MA_FALSE) { break; } @@ -79776,6 +81610,17 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p pWav->pMetadata = metadataParser.pMetadata; pWav->metadataCount = metadataParser.metadataCount; } + if (pWav->onTell != NULL && pWav->onSeek != NULL) { + if (pWav->onSeek(pWav->pUserData, 0, MA_DR_WAV_SEEK_END) == MA_TRUE) { + ma_int64 fileSize; + if (pWav->onTell(pWav->pUserData, &fileSize)) { + if (dataChunkSize + pWav->dataChunkDataPos > (ma_uint64)fileSize) { + dataChunkSize = (ma_uint64)fileSize - pWav->dataChunkDataPos; + } + } + } else { + } + } if (dataChunkSize == 0xFFFFFFFF && (pWav->container == ma_dr_wav_container_riff || pWav->container == ma_dr_wav_container_rifx) && pWav->isSequentialWrite == MA_FALSE) { dataChunkSize = 0; for (;;) { @@ -79795,8 +81640,14 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p pWav->sampleRate = fmt.sampleRate; pWav->channels = fmt.channels; pWav->bitsPerSample = fmt.bitsPerSample; - pWav->bytesRemaining = dataChunkSize; pWav->translatedFormatTag = translatedFormatTag; + if (!ma_dr_wav__is_compressed_format_tag(translatedFormatTag)) { + ma_uint32 bytesPerFrame = ma_dr_wav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame > 0) { + dataChunkSize -= (dataChunkSize % bytesPerFrame); + } + } + pWav->bytesRemaining = dataChunkSize; pWav->dataChunkDataSize = dataChunkSize; if (sampleCountFromFactChunk != 0) { pWav->totalPCMFrameCount = sampleCountFromFactChunk; @@ -79851,20 +81702,20 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p #endif return MA_TRUE; } -MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_wav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks); + return ma_dr_wav_init_ex(pWav, onRead, onSeek, onTell, NULL, pUserData, NULL, 0, pAllocationCallbacks); } -MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, ma_dr_wav_chunk_proc onChunk, void* pReadSeekTellUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { - if (!ma_dr_wav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, onRead, onSeek, onTell, pReadSeekTellUserData, pAllocationCallbacks)) { return MA_FALSE; } return ma_dr_wav_init__internal(pWav, onChunk, pChunkUserData, flags); } -MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, ma_uint32 flags, const 
ma_allocation_callbacks* pAllocationCallbacks) { - if (!ma_dr_wav_preinit(pWav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return MA_FALSE; } return ma_dr_wav_init__internal(pWav, NULL, NULL, flags | MA_DR_WAV_WITH_METADATA); @@ -80026,8 +81877,8 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ for (iLoop = 0; iLoop < pMetadata->data.smpl.sampleLoopCount; ++iLoop) { bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].cuePointId); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].type); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].firstSampleByteOffset); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].lastSampleByteOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].firstSampleOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].lastSampleOffset); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].sampleFraction); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].playCount); } @@ -80061,7 +81912,7 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ bytesWritten += ma_dr_wav__write_or_count(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].dataChunkId, 4); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].chunkStart); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].blockStart); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].sampleByteOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].sampleOffset); } } break; case ma_dr_wav_metadata_type_acid: @@ -80147,15 +81998,20 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ if (pMetadata->type & ma_dr_wav_metadata_type_list_all_info_strings) { const char* pID = NULL; switch (pMetadata->type) { - case ma_dr_wav_metadata_type_list_info_software: pID = "ISFT"; break; - case ma_dr_wav_metadata_type_list_info_copyright: pID = "ICOP"; break; - case ma_dr_wav_metadata_type_list_info_title: pID = "INAM"; break; - case ma_dr_wav_metadata_type_list_info_artist: pID = "IART"; break; - case ma_dr_wav_metadata_type_list_info_comment: pID = "ICMT"; break; - case ma_dr_wav_metadata_type_list_info_date: pID = "ICRD"; break; - case ma_dr_wav_metadata_type_list_info_genre: pID = "IGNR"; break; - case ma_dr_wav_metadata_type_list_info_album: pID = "IPRD"; break; - case ma_dr_wav_metadata_type_list_info_tracknumber: pID = "ITRK"; break; + case ma_dr_wav_metadata_type_list_info_software: pID = "ISFT"; break; + case ma_dr_wav_metadata_type_list_info_copyright: pID = "ICOP"; break; + case ma_dr_wav_metadata_type_list_info_title: pID = "INAM"; break; + case ma_dr_wav_metadata_type_list_info_artist: pID = "IART"; break; + case ma_dr_wav_metadata_type_list_info_comment: pID = "ICMT"; break; + case ma_dr_wav_metadata_type_list_info_date: pID = "ICRD"; break; + case ma_dr_wav_metadata_type_list_info_genre: pID = "IGNR"; break; + case ma_dr_wav_metadata_type_list_info_album: pID = 
"IPRD"; break; + case ma_dr_wav_metadata_type_list_info_tracknumber: pID = "ITRK"; break; + case ma_dr_wav_metadata_type_list_info_location: pID = "IARL"; break; + case ma_dr_wav_metadata_type_list_info_organization: pID = "ICMS"; break; + case ma_dr_wav_metadata_type_list_info_keywords: pID = "IKEY"; break; + case ma_dr_wav_metadata_type_list_info_medium: pID = "IMED"; break; + case ma_dr_wav_metadata_type_list_info_description: pID = "ISBJ"; break; default: break; } MA_DR_WAV_ASSERT(pID != NULL); @@ -80370,7 +82226,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init_write__internal(ma_dr_wav* pWav, const ma_dr } pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize; if (pFormat->container == ma_dr_wav_container_riff) { - ma_uint32 chunkSizeRIFF = 28 + (ma_uint32)initialDataChunkSize; + ma_uint32 chunkSizeRIFF = 36 + (ma_uint32)initialDataChunkSize; runningPos += ma_dr_wav__write(pWav, "RIFF", 4); runningPos += ma_dr_wav__write_u32ne_to_le(pWav, chunkSizeRIFF); runningPos += ma_dr_wav__write(pWav, "WAVE", 4); @@ -80493,7 +82349,31 @@ MA_PRIVATE size_t ma_dr_wav__on_write_stdio(void* pUserData, const void* pData, } MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_stdio(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { - return fseek((FILE*)pUserData, offset, (origin == ma_dr_wav_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_WAV_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_WAV_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; +} +MA_PRIVATE ma_bool32 ma_dr_wav__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_WAV_ASSERT(pFileStdio != NULL); + MA_DR_WAV_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; } MA_API ma_bool32 ma_dr_wav_init_file(ma_dr_wav* pWav, const char* filename, const ma_allocation_callbacks* pAllocationCallbacks) { @@ -80502,7 +82382,7 @@ MA_API ma_bool32 ma_dr_wav_init_file(ma_dr_wav* pWav, const char* filename, cons MA_PRIVATE ma_bool32 ma_dr_wav_init_file__internal_FILE(ma_dr_wav* pWav, FILE* pFile, ma_dr_wav_chunk_proc onChunk, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; - result = ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_stdio, ma_dr_wav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_stdio, ma_dr_wav__on_seek_stdio, ma_dr_wav__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; @@ -80639,25 +82519,26 @@ MA_PRIVATE size_t ma_dr_wav__on_read_memory(void* pUserData, void* pBufferOut, s MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - if (origin == ma_dr_wav_seek_origin_current) { - if (offset > 0) { - if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) { - return MA_FALSE; - } - } else { - if (pWav->memoryStream.currentReadPos < (size_t)-offset) { - return MA_FALSE; - } - } - pWav->memoryStream.currentReadPos += offset; + if (origin == MA_DR_WAV_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_WAV_SEEK_CUR) { + 
newCursor = (ma_int64)pWav->memoryStream.currentReadPos; + } else if (origin == MA_DR_WAV_SEEK_END) { + newCursor = (ma_int64)pWav->memoryStream.dataSize; } else { - if ((ma_uint32)offset <= pWav->memoryStream.dataSize) { - pWav->memoryStream.currentReadPos = offset; - } else { - return MA_FALSE; - } + MA_DR_WAV_ASSERT(!"Invalid seek origin"); + return MA_FALSE; } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pWav->memoryStream.dataSize) { + return MA_FALSE; + } + pWav->memoryStream.currentReadPos = (size_t)newCursor; return MA_TRUE; } MA_PRIVATE size_t ma_dr_wav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite) @@ -80691,25 +82572,34 @@ MA_PRIVATE size_t ma_dr_wav__on_write_memory(void* pUserData, const void* pDataI MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory_write(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - if (origin == ma_dr_wav_seek_origin_current) { - if (offset > 0) { - if (pWav->memoryStreamWrite.currentWritePos + offset > pWav->memoryStreamWrite.dataSize) { - offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); - } - } else { - if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) { - offset = -(int)pWav->memoryStreamWrite.currentWritePos; - } - } - pWav->memoryStreamWrite.currentWritePos += offset; + if (origin == MA_DR_WAV_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_WAV_SEEK_CUR) { + newCursor = (ma_int64)pWav->memoryStreamWrite.currentWritePos; + } else if (origin == MA_DR_WAV_SEEK_END) { + newCursor = (ma_int64)pWav->memoryStreamWrite.dataSize; } else { - if ((ma_uint32)offset <= pWav->memoryStreamWrite.dataSize) { - pWav->memoryStreamWrite.currentWritePos = offset; - } else { - pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; - } + MA_DR_WAV_ASSERT(!"Invalid seek origin"); + return MA_FALSE; } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pWav->memoryStreamWrite.dataSize) { + return MA_FALSE; + } + pWav->memoryStreamWrite.currentWritePos = (size_t)newCursor; + return MA_TRUE; +} +MA_PRIVATE ma_bool32 ma_dr_wav__on_tell_memory(void* pUserData, ma_int64* pCursor) +{ + ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + MA_DR_WAV_ASSERT(pWav != NULL); + MA_DR_WAV_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)pWav->memoryStream.currentReadPos; return MA_TRUE; } MA_API ma_bool32 ma_dr_wav_init_memory(ma_dr_wav* pWav, const void* data, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) @@ -80721,7 +82611,7 @@ MA_API ma_bool32 ma_dr_wav_init_memory_ex(ma_dr_wav* pWav, const void* data, siz if (data == NULL || dataSize == 0) { return MA_FALSE; } - if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, pWav, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, ma_dr_wav__on_tell_memory, pWav, pAllocationCallbacks)) { return MA_FALSE; } pWav->memoryStream.data = (const ma_uint8*)data; @@ -80734,7 +82624,7 @@ MA_API ma_bool32 ma_dr_wav_init_memory_with_metadata(ma_dr_wav* pWav, const void if (data == NULL || dataSize == 0) { return MA_FALSE; } - if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, pWav, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, ma_dr_wav__on_tell_memory, 
pWav, pAllocationCallbacks)) { return MA_FALSE; } pWav->memoryStream.data = (const ma_uint8*)data; @@ -80793,30 +82683,30 @@ MA_API ma_result ma_dr_wav_uninit(ma_dr_wav* pWav) } if (pWav->onSeek && !pWav->isSequentialWrite) { if (pWav->container == ma_dr_wav_container_riff) { - if (pWav->onSeek(pWav->pUserData, 4, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, 4, MA_DR_WAV_SEEK_SET)) { ma_uint32 riffChunkSize = ma_dr_wav__riff_chunk_size_riff(pWav->dataChunkDataSize, pWav->pMetadata, pWav->metadataCount); ma_dr_wav__write_u32ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 4, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 4, MA_DR_WAV_SEEK_SET)) { ma_uint32 dataChunkSize = ma_dr_wav__data_chunk_size_riff(pWav->dataChunkDataSize); ma_dr_wav__write_u32ne_to_le(pWav, dataChunkSize); } } else if (pWav->container == ma_dr_wav_container_w64) { - if (pWav->onSeek(pWav->pUserData, 16, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, 16, MA_DR_WAV_SEEK_SET)) { ma_uint64 riffChunkSize = ma_dr_wav__riff_chunk_size_w64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 8, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 8, MA_DR_WAV_SEEK_SET)) { ma_uint64 dataChunkSize = ma_dr_wav__data_chunk_size_w64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, dataChunkSize); } } else if (pWav->container == ma_dr_wav_container_rf64) { int ds64BodyPos = 12 + 8; - if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, MA_DR_WAV_SEEK_SET)) { ma_uint64 riffChunkSize = ma_dr_wav__riff_chunk_size_rf64(pWav->dataChunkDataSize, pWav->pMetadata, pWav->metadataCount); ma_dr_wav__write_u64ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, MA_DR_WAV_SEEK_SET)) { ma_uint64 dataChunkSize = ma_dr_wav__data_chunk_size_rf64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, dataChunkSize); } @@ -80863,7 +82753,7 @@ MA_API size_t ma_dr_wav_read_raw(ma_dr_wav* pWav, size_t bytesToRead, void* pBuf if (bytesToSeek > 0x7FFFFFFF) { bytesToSeek = 0x7FFFFFFF; } - if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, MA_DR_WAV_SEEK_CUR) == MA_FALSE) { break; } bytesRead += bytesToSeek; @@ -80962,7 +82852,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav_seek_to_first_pcm_frame(ma_dr_wav* pWav) if (pWav->onWrite != NULL) { return MA_FALSE; } - if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, ma_dr_wav_seek_origin_start)) { + if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, MA_DR_WAV_SEEK_SET)) { return MA_FALSE; } if (ma_dr_wav__is_compressed_format_tag(pWav->translatedFormatTag)) { @@ -81043,7 +82933,7 @@ MA_API ma_bool32 ma_dr_wav_seek_to_pcm_frame(ma_dr_wav* pWav, ma_uint64 targetFr } while (offset > 0) { int offset32 = ((offset > INT_MAX) ? 
INT_MAX : (int)offset); - if (!pWav->onSeek(pWav->pUserData, offset32, ma_dr_wav_seek_origin_current)) { + if (!pWav->onSeek(pWav->pUserData, offset32, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } pWav->readCursorInPCMFrames += offset32 / bytesPerFrame; @@ -81169,12 +83059,12 @@ MA_API ma_uint64 ma_dr_wav_write_pcm_frames(ma_dr_wav* pWav, ma_uint64 framesToW MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_uint64 framesToRead, ma_int16* pBufferOut) { ma_uint64 totalFramesRead = 0; - static ma_int32 adaptationTable[] = { + static const ma_int32 adaptationTable[] = { 230, 230, 230, 230, 307, 409, 512, 614, 768, 614, 512, 409, 307, 230, 230, 230 }; - static ma_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; - static ma_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; + static const ma_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; + static const ma_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; MA_DR_WAV_ASSERT(pWav != NULL); MA_DR_WAV_ASSERT(framesToRead > 0); while (pWav->readCursorInPCMFrames < pWav->totalPCMFrameCount) { @@ -81193,7 +83083,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0]; pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.cachedFrameCount = 2; - if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table)) { + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { return totalFramesRead; } } else { @@ -81215,7 +83105,8 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1]; pWav->msadpcm.cachedFrameCount = 2; - if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table) || + pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { return totalFramesRead; } } @@ -81252,6 +83143,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ if (pWav->channels == 1) { ma_int32 newSample0; ma_int32 newSample1; + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); @@ -81276,6 +83170,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ } else { ma_int32 newSample0; ma_int32 newSample1; + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); @@ -81285,6 +83182,9 @@ 
MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ } pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.prevFrames[0][1] = newSample0; + if (pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; newSample1 += nibble1 * pWav->msadpcm.delta[1]; newSample1 = ma_dr_wav_clamp(newSample1, -32768, 32767); @@ -81307,11 +83207,11 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint { ma_uint64 totalFramesRead = 0; ma_uint32 iChannel; - static ma_int32 indexTable[16] = { + static const ma_int32 indexTable[16] = { -1, -1, -1, -1, 2, 4, 6, 8, -1, -1, -1, -1, 2, 4, 6, 8 }; - static ma_int32 stepTable[89] = { + static const ma_int32 stepTable[89] = { 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 19, 21, 23, 25, 28, 31, 34, 37, 41, 45, 50, 55, 60, 66, 73, 80, 88, 97, 107, 118, @@ -81334,7 +83234,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint } pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); if (header[2] >= ma_dr_wav_countof(stepTable)) { - pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, ma_dr_wav_seek_origin_current); + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, MA_DR_WAV_SEEK_CUR); pWav->ima.bytesRemainingInBlock = 0; return totalFramesRead; } @@ -81349,7 +83249,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint } pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); if (header[2] >= ma_dr_wav_countof(stepTable) || header[6] >= ma_dr_wav_countof(stepTable)) { - pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, ma_dr_wav_seek_origin_current); + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, MA_DR_WAV_SEEK_CUR); pWav->ima.bytesRemainingInBlock = 0; return totalFramesRead; } @@ -81424,7 +83324,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint return totalFramesRead; } #ifndef MA_DR_WAV_NO_CONVERSION_API -static unsigned short g_ma_dr_wavAlawTable[256] = { +static const unsigned short ma_dr_wav_gAlawTable[256] = { 0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580, 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0, 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600, @@ -81442,7 +83342,7 @@ static unsigned short g_ma_dr_wavAlawTable[256] = { 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0, 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350 }; -static unsigned short g_ma_dr_wavMulawTable[256] = { +static const unsigned short ma_dr_wav_gMulawTable[256] = { 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84, 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84, 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 
0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004, @@ -81462,11 +83362,11 @@ static unsigned short g_ma_dr_wavMulawTable[256] = { }; static MA_INLINE ma_int16 ma_dr_wav__alaw_to_s16(ma_uint8 sampleIn) { - return (short)g_ma_dr_wavAlawTable[sampleIn]; + return (short)ma_dr_wav_gAlawTable[sampleIn]; } static MA_INLINE ma_int16 ma_dr_wav__mulaw_to_s16(ma_uint8 sampleIn) { - return (short)g_ma_dr_wavMulawTable[sampleIn]; + return (short)ma_dr_wav_gMulawTable[sampleIn]; } MA_PRIVATE void ma_dr_wav__pcm_to_s16(ma_int16* pOut, const ma_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) { @@ -82529,6 +84429,10 @@ MA_PRIVATE ma_int16* ma_dr_wav__read_pcm_frames_and_close_s16(ma_dr_wav* pWav, u ma_int16* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int16)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int16); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82563,6 +84467,10 @@ MA_PRIVATE float* ma_dr_wav__read_pcm_frames_and_close_f32(ma_dr_wav* pWav, unsi float* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(float)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82597,6 +84505,10 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u ma_int32* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int32)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int32); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82625,7 +84537,7 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u } return pSampleData; } -MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82637,12 +84549,12 @@ MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRe if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); } -MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* 
ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82654,12 +84566,12 @@ MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); } -MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82671,7 +84583,7 @@ MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRe if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); @@ -83979,7 +85891,7 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x) { ma_uint64 r; __asm__ __volatile__ ( - "lzcnt{ %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" + "rep; bsr{q %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" ); return (ma_uint32)r; } @@ -83987,11 +85899,11 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x) { ma_uint32 r; __asm__ __volatile__ ( - "lzcnt{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" + "rep; bsr{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" ); return r; } - #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !defined(MA_64BIT) + #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !(defined(__thumb__) && !defined(__thumb2__)) && !defined(MA_64BIT) { unsigned int r; __asm__ __volatile__ ( @@ -84106,23 +86018,23 @@ static ma_bool32 ma_dr_flac__seek_to_byte(ma_dr_flac_bs* bs, ma_uint64 offsetFro MA_DR_FLAC_ASSERT(offsetFromStart > 0); if (offsetFromStart > 0x7FFFFFFF) { ma_uint64 bytesRemaining = offsetFromStart; - if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_start)) { + if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } bytesRemaining -= 0x7FFFFFFF; while (bytesRemaining > 0x7FFFFFFF) { - if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_current)) { + if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } bytesRemaining -= 0x7FFFFFFF; } if (bytesRemaining > 0) { - if (!bs->onSeek(bs->pUserData, (int)bytesRemaining, ma_dr_flac_seek_origin_current)) { + if (!bs->onSeek(bs->pUserData, (int)bytesRemaining, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { 
- if (!bs->onSeek(bs->pUserData, (int)offsetFromStart, ma_dr_flac_seek_origin_start)) { + if (!bs->onSeek(bs->pUserData, (int)offsetFromStart, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } } @@ -86600,6 +88512,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; ma_dr_flac_meta_proc onMeta; ma_dr_flac_container container; void* pUserData; @@ -86728,11 +88641,12 @@ static void ma_dr_flac__free_from_callbacks(void* p, const ma_allocation_callbac pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); } } -static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, void* pUserDataMD, ma_uint64* pFirstFramePos, ma_uint64* pSeektablePos, ma_uint32* pSeekpointCount, ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, void* pUserDataMD, ma_uint64* pFirstFramePos, ma_uint64* pSeektablePos, ma_uint32* pSeekpointCount, ma_allocation_callbacks* pAllocationCallbacks) { ma_uint64 runningFilePos = 42; ma_uint64 seektablePos = 0; ma_uint32 seektableSize = 0; + (void)onTell; for (;;) { ma_dr_flac_metadata metadata; ma_uint8 isLastBlock = 0; @@ -86743,8 +88657,9 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } runningFilePos += 4; metadata.type = blockType; - metadata.pRawData = NULL; metadata.rawDataSize = 0; + metadata.rawDataOffset = runningFilePos; + metadata.pRawData = NULL; switch (blockType) { case MA_DR_FLAC_METADATA_BLOCK_TYPE_APPLICATION: @@ -86944,53 +88859,124 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea return MA_FALSE; } if (onMeta) { - void* pRawData; - const char* pRunningData; - const char* pRunningDataEnd; - pRawData = ma_dr_flac__malloc_from_callbacks(blockSize, pAllocationCallbacks); - if (pRawData == NULL) { + ma_bool32 result = MA_TRUE; + ma_uint32 blockSizeRemaining = blockSize; + char* pMime = NULL; + char* pDescription = NULL; + void* pPictureData = NULL; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.type, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.type = ma_dr_flac__be2host_32(metadata.data.picture.type); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.mimeLength, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.mimeLength = ma_dr_flac__be2host_32(metadata.data.picture.mimeLength); + pMime = (char*)ma_dr_flac__malloc_from_callbacks(metadata.data.picture.mimeLength + 1, pAllocationCallbacks); + if (pMime == NULL) { + result = MA_FALSE; + goto done_flac; + } + if (blockSizeRemaining < metadata.data.picture.mimeLength || onRead(pUserData, pMime, metadata.data.picture.mimeLength) != metadata.data.picture.mimeLength) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= metadata.data.picture.mimeLength; + pMime[metadata.data.picture.mimeLength] = '\0'; + metadata.data.picture.mime = (const char*)pMime; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.descriptionLength, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.descriptionLength = ma_dr_flac__be2host_32(metadata.data.picture.descriptionLength); + 
pDescription = (char*)ma_dr_flac__malloc_from_callbacks(metadata.data.picture.descriptionLength + 1, pAllocationCallbacks); + if (pDescription == NULL) { + result = MA_FALSE; + goto done_flac; + } + if (blockSizeRemaining < metadata.data.picture.descriptionLength || onRead(pUserData, pDescription, metadata.data.picture.descriptionLength) != metadata.data.picture.descriptionLength) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= metadata.data.picture.descriptionLength; + pDescription[metadata.data.picture.descriptionLength] = '\0'; + metadata.data.picture.description = (const char*)pDescription; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.width, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.width = ma_dr_flac__be2host_32(metadata.data.picture.width); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.height, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.height = ma_dr_flac__be2host_32(metadata.data.picture.height); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.colorDepth, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.colorDepth = ma_dr_flac__be2host_32(metadata.data.picture.colorDepth); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.indexColorCount, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.indexColorCount = ma_dr_flac__be2host_32(metadata.data.picture.indexColorCount); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.pictureDataSize, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.pictureDataSize = ma_dr_flac__be2host_32(metadata.data.picture.pictureDataSize); + if (blockSizeRemaining < metadata.data.picture.pictureDataSize) { + result = MA_FALSE; + goto done_flac; + } + metadata.data.picture.pictureDataOffset = runningFilePos + (blockSize - blockSizeRemaining); + #ifndef MA_DR_FLAC_NO_PICTURE_METADATA_MALLOC + pPictureData = ma_dr_flac__malloc_from_callbacks(metadata.data.picture.pictureDataSize, pAllocationCallbacks); + if (pPictureData != NULL) { + if (onRead(pUserData, pPictureData, metadata.data.picture.pictureDataSize) != metadata.data.picture.pictureDataSize) { + result = MA_FALSE; + goto done_flac; + } + } else + #endif + { + if (!onSeek(pUserData, metadata.data.picture.pictureDataSize, MA_DR_FLAC_SEEK_CUR)) { + result = MA_FALSE; + goto done_flac; + } + } + blockSizeRemaining -= metadata.data.picture.pictureDataSize; + (void)blockSizeRemaining; + metadata.data.picture.pPictureData = (const ma_uint8*)pPictureData; + if (metadata.data.picture.pictureDataOffset != 0 || metadata.data.picture.pPictureData != NULL) { + onMeta(pUserDataMD, &metadata); + } else { + } + done_flac: + ma_dr_flac__free_from_callbacks(pMime, pAllocationCallbacks); + ma_dr_flac__free_from_callbacks(pDescription, pAllocationCallbacks); + ma_dr_flac__free_from_callbacks(pPictureData, pAllocationCallbacks); + if (result != MA_TRUE) { return MA_FALSE; } - if (onRead(pUserData, pRawData, blockSize) != blockSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; - } - metadata.pRawData = pRawData; - metadata.rawDataSize = blockSize; - pRunningData = (const char*)pRawData; - pRunningDataEnd = (const char*)pRawData + blockSize; - 
metadata.data.picture.type = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.mimeLength = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - if ((pRunningDataEnd - pRunningData) - 24 < (ma_int64)metadata.data.picture.mimeLength) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; - } - metadata.data.picture.mime = pRunningData; pRunningData += metadata.data.picture.mimeLength; - metadata.data.picture.descriptionLength = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - if ((pRunningDataEnd - pRunningData) - 20 < (ma_int64)metadata.data.picture.descriptionLength) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; - } - metadata.data.picture.description = pRunningData; pRunningData += metadata.data.picture.descriptionLength; - metadata.data.picture.width = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.height = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.colorDepth = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.indexColorCount = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.pictureDataSize = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.pPictureData = (const ma_uint8*)pRunningData; - if (pRunningDataEnd - pRunningData < (ma_int64)metadata.data.picture.pictureDataSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; - } - onMeta(pUserDataMD, &metadata); - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); } } break; case MA_DR_FLAC_METADATA_BLOCK_TYPE_PADDING: { if (onMeta) { metadata.data.padding.unused = 0; - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } else { onMeta(pUserDataMD, &metadata); @@ -87000,7 +88986,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea case MA_DR_FLAC_METADATA_BLOCK_TYPE_INVALID: { if (onMeta) { - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } } @@ -87009,12 +88995,15 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea { if (onMeta) { void* pRawData = ma_dr_flac__malloc_from_callbacks(blockSize, pAllocationCallbacks); - if (pRawData == NULL) { - return MA_FALSE; - } - if (onRead(pUserData, pRawData, blockSize) != blockSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + if (pRawData != NULL) { + if (onRead(pUserData, pRawData, blockSize) != blockSize) { + ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); + return MA_FALSE; + } + } else { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { + return MA_FALSE; + } } metadata.pRawData = pRawData; metadata.rawDataSize = blockSize; @@ -87024,7 +89013,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } break; } if (onMeta == NULL && blockSize > 0) { - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } } @@ -87288,6 +89277,7 @@ typedef struct { ma_dr_flac_read_proc onRead; 
ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; void* pUserData; ma_uint64 currentBytePos; ma_uint64 firstBytePos; @@ -87306,29 +89296,29 @@ static size_t ma_dr_flac_oggbs__read_physical(ma_dr_flac_oggbs* oggbs, void* buf } static ma_bool32 ma_dr_flac_oggbs__seek_physical(ma_dr_flac_oggbs* oggbs, ma_uint64 offset, ma_dr_flac_seek_origin origin) { - if (origin == ma_dr_flac_seek_origin_start) { + if (origin == MA_DR_FLAC_SEEK_SET) { if (offset <= 0x7FFFFFFF) { - if (!oggbs->onSeek(oggbs->pUserData, (int)offset, ma_dr_flac_seek_origin_start)) { + if (!oggbs->onSeek(oggbs->pUserData, (int)offset, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } oggbs->currentBytePos = offset; return MA_TRUE; } else { - if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_start)) { + if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } oggbs->currentBytePos = offset; - return ma_dr_flac_oggbs__seek_physical(oggbs, offset - 0x7FFFFFFF, ma_dr_flac_seek_origin_current); + return ma_dr_flac_oggbs__seek_physical(oggbs, offset - 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR); } } else { while (offset > 0x7FFFFFFF) { - if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_current)) { + if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } oggbs->currentBytePos += 0x7FFFFFFF; offset -= 0x7FFFFFFF; } - if (!oggbs->onSeek(oggbs->pUserData, (int)offset, ma_dr_flac_seek_origin_current)) { + if (!oggbs->onSeek(oggbs->pUserData, (int)offset, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } oggbs->currentBytePos += offset; @@ -87354,7 +89344,7 @@ static ma_bool32 ma_dr_flac_oggbs__goto_next_page(ma_dr_flac_oggbs* oggbs, ma_dr continue; } if (header.serialNumber != oggbs->serialNumber) { - if (pageBodySize > 0 && !ma_dr_flac_oggbs__seek_physical(oggbs, pageBodySize, ma_dr_flac_seek_origin_current)) { + if (pageBodySize > 0 && !ma_dr_flac_oggbs__seek_physical(oggbs, pageBodySize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } continue; @@ -87416,7 +89406,7 @@ static ma_bool32 ma_dr_flac_oggbs__seek_to_next_packet(ma_dr_flac_oggbs* oggbs) } bytesToEndOfPacketOrPage += segmentSize; } - ma_dr_flac_oggbs__seek_physical(oggbs, bytesToEndOfPacketOrPage, ma_dr_flac_seek_origin_current); + ma_dr_flac_oggbs__seek_physical(oggbs, bytesToEndOfPacketOrPage, MA_DR_FLAC_SEEK_CUR); oggbs->bytesRemainingInPage -= bytesToEndOfPacketOrPage; if (atEndOfPage) { if (!ma_dr_flac_oggbs__goto_next_page(oggbs)) { @@ -87469,36 +89459,44 @@ static ma_bool32 ma_dr_flac__on_seek_ogg(void* pUserData, int offset, ma_dr_flac int bytesSeeked = 0; MA_DR_FLAC_ASSERT(oggbs != NULL); MA_DR_FLAC_ASSERT(offset >= 0); - if (origin == ma_dr_flac_seek_origin_start) { - if (!ma_dr_flac_oggbs__seek_physical(oggbs, (int)oggbs->firstBytePos, ma_dr_flac_seek_origin_start)) { + if (origin == MA_DR_FLAC_SEEK_SET) { + if (!ma_dr_flac_oggbs__seek_physical(oggbs, (int)oggbs->firstBytePos, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { return MA_FALSE; } - return ma_dr_flac__on_seek_ogg(pUserData, offset, ma_dr_flac_seek_origin_current); - } - MA_DR_FLAC_ASSERT(origin == ma_dr_flac_seek_origin_current); - while (bytesSeeked < offset) { - int bytesRemainingToSeek = offset - bytesSeeked; - MA_DR_FLAC_ASSERT(bytesRemainingToSeek >= 0); - if (oggbs->bytesRemainingInPage >= (size_t)bytesRemainingToSeek) { - bytesSeeked += bytesRemainingToSeek; - (void)bytesSeeked; - oggbs->bytesRemainingInPage -= 
bytesRemainingToSeek; - break; - } - if (oggbs->bytesRemainingInPage > 0) { - bytesSeeked += (int)oggbs->bytesRemainingInPage; - oggbs->bytesRemainingInPage = 0; - } - MA_DR_FLAC_ASSERT(bytesRemainingToSeek > 0); - if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { - return MA_FALSE; + return ma_dr_flac__on_seek_ogg(pUserData, offset, MA_DR_FLAC_SEEK_CUR); + } else if (origin == MA_DR_FLAC_SEEK_CUR) { + while (bytesSeeked < offset) { + int bytesRemainingToSeek = offset - bytesSeeked; + MA_DR_FLAC_ASSERT(bytesRemainingToSeek >= 0); + if (oggbs->bytesRemainingInPage >= (size_t)bytesRemainingToSeek) { + bytesSeeked += bytesRemainingToSeek; + (void)bytesSeeked; + oggbs->bytesRemainingInPage -= bytesRemainingToSeek; + break; + } + if (oggbs->bytesRemainingInPage > 0) { + bytesSeeked += (int)oggbs->bytesRemainingInPage; + oggbs->bytesRemainingInPage = 0; + } + MA_DR_FLAC_ASSERT(bytesRemainingToSeek > 0); + if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { + return MA_FALSE; + } } + } else if (origin == MA_DR_FLAC_SEEK_END) { + return MA_FALSE; } return MA_TRUE; } +static ma_bool32 ma_dr_flac__on_tell_ogg(void* pUserData, ma_int64* pCursor) +{ + (void)pUserData; + (void)pCursor; + return MA_FALSE; +} static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 pcmFrameIndex) { ma_dr_flac_oggbs* oggbs = (ma_dr_flac_oggbs*)pFlac->_oggbs; @@ -87515,7 +89513,7 @@ static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 runningGranulePosition = 0; for (;;) { if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_recover_on_crc_mismatch)) { - ma_dr_flac_oggbs__seek_physical(oggbs, originalBytePos, ma_dr_flac_seek_origin_start); + ma_dr_flac_oggbs__seek_physical(oggbs, originalBytePos, MA_DR_FLAC_SEEK_SET); return MA_FALSE; } runningFrameBytePos = oggbs->currentBytePos - ma_dr_flac_ogg__get_page_header_size(&oggbs->currentPageHeader) - oggbs->pageDataSize; @@ -87534,7 +89532,7 @@ static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 } } } - if (!ma_dr_flac_oggbs__seek_physical(oggbs, runningFrameBytePos, ma_dr_flac_seek_origin_start)) { + if (!ma_dr_flac_oggbs__seek_physical(oggbs, runningFrameBytePos, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_recover_on_crc_mismatch)) { @@ -87629,7 +89627,7 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d if (mappingVersion[0] != 1) { return MA_FALSE; } - if (!onSeek(pUserData, 2, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, 2, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } if (onRead(pUserData, sig, 4) != 4) { @@ -87674,17 +89672,17 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d return MA_FALSE; } } else { - if (!onSeek(pUserData, bytesRemainingInPage, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, bytesRemainingInPage, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!onSeek(pUserData, bytesRemainingInPage, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, bytesRemainingInPage, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!onSeek(pUserData, pageBodySize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, pageBodySize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } @@ -87698,7 +89696,7 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d return MA_TRUE; } #endif -static ma_bool32 
ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD) +static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD) { ma_bool32 relaxed; ma_uint8 id[4]; @@ -87708,12 +89706,14 @@ static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_fla MA_DR_FLAC_ZERO_MEMORY(pInit, sizeof(*pInit)); pInit->onRead = onRead; pInit->onSeek = onSeek; + pInit->onTell = onTell; pInit->onMeta = onMeta; pInit->container = container; pInit->pUserData = pUserData; pInit->pUserDataMD = pUserDataMD; pInit->bs.onRead = onRead; pInit->bs.onSeek = onSeek; + pInit->bs.onTell = onTell; pInit->bs.pUserData = pUserData; ma_dr_flac__reset_cache(&pInit->bs); relaxed = container != ma_dr_flac_container_unknown; @@ -87736,7 +89736,7 @@ static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_fla if (flags & 0x10) { headerSize += 10; } - if (!onSeek(pUserData, headerSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, headerSize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } pInit->runningFilePos += headerSize; @@ -87779,7 +89779,7 @@ static void ma_dr_flac__init_from_info(ma_dr_flac* pFlac, const ma_dr_flac_init_ pFlac->totalPCMFrameCount = pInit->totalPCMFrameCount; pFlac->container = pInit->container; } -static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac_init_info init; ma_uint32 allocationSize; @@ -87794,7 +89794,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on ma_allocation_callbacks allocationCallbacks; ma_dr_flac* pFlac; ma_dr_flac__init_cpu_caps(); - if (!ma_dr_flac__init_private(&init, onRead, onSeek, onMeta, container, pUserData, pUserDataMD)) { + if (!ma_dr_flac__init_private(&init, onRead, onSeek, onTell, onMeta, container, pUserData, pUserDataMD)) { return NULL; } if (pAllocationCallbacks != NULL) { @@ -87827,6 +89827,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on MA_DR_FLAC_ZERO_MEMORY(pOggbs, sizeof(*pOggbs)); pOggbs->onRead = onRead; pOggbs->onSeek = onSeek; + pOggbs->onTell = onTell; pOggbs->pUserData = pUserData; pOggbs->currentBytePos = init.oggFirstBytePos; pOggbs->firstBytePos = init.oggFirstBytePos; @@ -87841,15 +89842,17 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on if (init.hasMetadataBlocks) { ma_dr_flac_read_proc onReadOverride = onRead; ma_dr_flac_seek_proc onSeekOverride = onSeek; + ma_dr_flac_tell_proc onTellOverride = onTell; void* pUserDataOverride = pUserData; #ifndef MA_DR_FLAC_NO_OGG if (init.container == ma_dr_flac_container_ogg) { onReadOverride = ma_dr_flac__on_read_ogg; onSeekOverride = ma_dr_flac__on_seek_ogg; + onTellOverride = ma_dr_flac__on_tell_ogg; pUserDataOverride = 
(void*)pOggbs; } #endif - if (!ma_dr_flac__read_and_decode_metadata(onReadOverride, onSeekOverride, onMeta, pUserDataOverride, pUserDataMD, &firstFramePos, &seektablePos, &seekpointCount, &allocationCallbacks)) { + if (!ma_dr_flac__read_and_decode_metadata(onReadOverride, onSeekOverride, onTellOverride, onMeta, pUserDataOverride, pUserDataMD, &firstFramePos, &seektablePos, &seekpointCount, &allocationCallbacks)) { #ifndef MA_DR_FLAC_NO_OGG ma_dr_flac__free_from_callbacks(pOggbs, &allocationCallbacks); #endif @@ -87875,6 +89878,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on pOggbs = NULL; pFlac->bs.onRead = ma_dr_flac__on_read_ogg; pFlac->bs.onSeek = ma_dr_flac__on_seek_ogg; + pFlac->bs.onTell = ma_dr_flac__on_tell_ogg; pFlac->bs.pUserData = (void*)pInternalOggbs; pFlac->_oggbs = (void*)pInternalOggbs; } @@ -87894,7 +89898,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on pFlac->pSeekpoints = (ma_dr_flac_seekpoint*)((ma_uint8*)pFlac->pDecodedSamples + decodedSamplesAllocationSize); MA_DR_FLAC_ASSERT(pFlac->bs.onSeek != NULL); MA_DR_FLAC_ASSERT(pFlac->bs.onRead != NULL); - if (pFlac->bs.onSeek(pFlac->bs.pUserData, (int)seektablePos, ma_dr_flac_seek_origin_start)) { + if (pFlac->bs.onSeek(pFlac->bs.pUserData, (int)seektablePos, MA_DR_FLAC_SEEK_SET)) { ma_uint32 iSeekpoint; for (iSeekpoint = 0; iSeekpoint < seekpointCount; iSeekpoint += 1) { if (pFlac->bs.onRead(pFlac->bs.pUserData, pFlac->pSeekpoints + iSeekpoint, MA_DR_FLAC_SEEKPOINT_SIZE_IN_BYTES) == MA_DR_FLAC_SEEKPOINT_SIZE_IN_BYTES) { @@ -87907,7 +89911,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on break; } } - if (!pFlac->bs.onSeek(pFlac->bs.pUserData, (int)pFlac->firstFLACFramePosInBytes, ma_dr_flac_seek_origin_start)) { + if (!pFlac->bs.onSeek(pFlac->bs.pUserData, (int)pFlac->firstFLACFramePosInBytes, MA_DR_FLAC_SEEK_SET)) { ma_dr_flac__free_from_callbacks(pFlac, &allocationCallbacks); return NULL; } @@ -87950,8 +89954,31 @@ static size_t ma_dr_flac__on_read_stdio(void* pUserData, void* bufferOut, size_t } static ma_bool32 ma_dr_flac__on_seek_stdio(void* pUserData, int offset, ma_dr_flac_seek_origin origin) { - MA_DR_FLAC_ASSERT(offset >= 0); - return fseek((FILE*)pUserData, offset, (origin == ma_dr_flac_seek_origin_current) ? 
SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_FLAC_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_FLAC_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; +} +static ma_bool32 ma_dr_flac__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_FLAC_ASSERT(pFileStdio != NULL); + MA_DR_FLAC_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; } MA_API ma_dr_flac* ma_dr_flac_open_file(const char* pFileName, const ma_allocation_callbacks* pAllocationCallbacks) { @@ -87960,7 +89987,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file(const char* pFileName, const ma_allocati if (ma_fopen(&pFile, pFileName, "rb") != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return NULL; @@ -87975,7 +90002,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_w(const wchar_t* pFileName, const ma_all if (ma_wfopen(&pFile, pFileName, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return NULL; @@ -87990,7 +90017,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata(const char* pFileName, ma_ if (ma_fopen(&pFile, pFileName, "rb") != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return pFlac; @@ -88005,7 +90032,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata_w(const wchar_t* pFileName if (ma_wfopen(&pFile, pFileName, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return pFlac; @@ -88033,24 +90060,34 @@ static size_t ma_dr_flac__on_read_memory(void* pUserData, void* bufferOut, size_ static ma_bool32 ma_dr_flac__on_seek_memory(void* pUserData, int offset, ma_dr_flac_seek_origin origin) { ma_dr_flac__memory_stream* memoryStream = (ma_dr_flac__memory_stream*)pUserData; + ma_int64 newCursor; MA_DR_FLAC_ASSERT(memoryStream != NULL); - MA_DR_FLAC_ASSERT(offset >= 0); - if (offset > (ma_int64)memoryStream->dataSize) { + if 
(origin == MA_DR_FLAC_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_FLAC_SEEK_CUR) { + newCursor = (ma_int64)memoryStream->currentReadPos; + } else if (origin == MA_DR_FLAC_SEEK_END) { + newCursor = (ma_int64)memoryStream->dataSize; + } else { + MA_DR_FLAC_ASSERT(!"Invalid seek origin"); return MA_FALSE; } - if (origin == ma_dr_flac_seek_origin_current) { - if (memoryStream->currentReadPos + offset <= memoryStream->dataSize) { - memoryStream->currentReadPos += offset; - } else { - return MA_FALSE; - } - } else { - if ((ma_uint32)offset <= memoryStream->dataSize) { - memoryStream->currentReadPos = offset; - } else { - return MA_FALSE; - } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; } + if ((size_t)newCursor > memoryStream->dataSize) { + return MA_FALSE; + } + memoryStream->currentReadPos = (size_t)newCursor; + return MA_TRUE; +} +static ma_bool32 ma_dr_flac__on_tell_memory(void* pUserData, ma_int64* pCursor) +{ + ma_dr_flac__memory_stream* memoryStream = (ma_dr_flac__memory_stream*)pUserData; + MA_DR_FLAC_ASSERT(memoryStream != NULL); + MA_DR_FLAC_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)memoryStream->currentReadPos; return MA_TRUE; } MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) @@ -88060,7 +90097,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, co memoryStream.data = (const ma_uint8*)pData; memoryStream.dataSize = dataSize; memoryStream.currentReadPos = 0; - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, &memoryStream, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, ma_dr_flac__on_tell_memory, &memoryStream, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -88085,7 +90122,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_ memoryStream.data = (const ma_uint8*)pData; memoryStream.dataSize = dataSize; memoryStream.currentReadPos = 0; - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, onMeta, ma_dr_flac_container_unknown, &memoryStream, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, ma_dr_flac__on_tell_memory, onMeta, ma_dr_flac_container_unknown, &memoryStream, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -88103,21 +90140,21 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_ } return pFlac; } -MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, NULL, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, NULL, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* 
ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, NULL, container, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, NULL, container, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onMeta, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, onMeta, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onMeta, container, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, onMeta, container, pUserData, pUserData, pAllocationCallbacks); } MA_API void ma_dr_flac_close(ma_dr_flac* pFlac) { @@ -90345,56 +92382,41 @@ static type* ma_dr_flac__full_read_and_close_ ## extension (ma_dr_flac* pFlac, u { \ type* pSampleData = NULL; \ ma_uint64 totalPCMFrameCount; \ + type buffer[4096]; \ + ma_uint64 pcmFramesRead; \ + size_t sampleDataBufferSize = sizeof(buffer); \ \ MA_DR_FLAC_ASSERT(pFlac != NULL); \ \ - totalPCMFrameCount = pFlac->totalPCMFrameCount; \ + totalPCMFrameCount = 0; \ \ - if (totalPCMFrameCount == 0) { \ - type buffer[4096]; \ - ma_uint64 pcmFramesRead; \ - size_t sampleDataBufferSize = sizeof(buffer); \ + pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks); \ + if (pSampleData == NULL) { \ + goto on_error; \ + } \ \ - pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks); \ - if (pSampleData == NULL) { \ - goto on_error; \ - } \ + while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) { \ + if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) { \ + type* pNewSampleData; \ + size_t newSampleDataBufferSize; \ \ - while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) { \ - if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) { \ - type* pNewSampleData; \ - size_t 
newSampleDataBufferSize; \ - \ - newSampleDataBufferSize = sampleDataBufferSize * 2; \ - pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks); \ - if (pNewSampleData == NULL) { \ - ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks); \ - goto on_error; \ - } \ - \ - sampleDataBufferSize = newSampleDataBufferSize; \ - pSampleData = pNewSampleData; \ + newSampleDataBufferSize = sampleDataBufferSize * 2; \ + pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks); \ + if (pNewSampleData == NULL) { \ + ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks); \ + goto on_error; \ } \ \ - MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type))); \ - totalPCMFrameCount += pcmFramesRead; \ + sampleDataBufferSize = newSampleDataBufferSize; \ + pSampleData = pNewSampleData; \ } \ \ + MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type))); \ + totalPCMFrameCount += pcmFramesRead; \ + } \ + \ \ - MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type))); \ - } else { \ - ma_uint64 dataSize = totalPCMFrameCount*pFlac->channels*sizeof(type); \ - if (dataSize > (ma_uint64)MA_SIZE_MAX) { \ - goto on_error; \ - } \ - \ - pSampleData = (type*)ma_dr_flac__malloc_from_callbacks((size_t)dataSize, &pFlac->allocationCallbacks); \ - if (pSampleData == NULL) { \ - goto on_error; \ - } \ - \ - totalPCMFrameCount = ma_dr_flac_read_pcm_frames_##extension(pFlac, pFlac->totalPCMFrameCount, pSampleData); \ - } \ + MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type))); \ \ if (sampleRateOut) *sampleRateOut = pFlac->sampleRate; \ if (channelsOut) *channelsOut = pFlac->channels; \ @@ -90410,7 +92432,7 @@ on_error: MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(s32, ma_int32) MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(s16, ma_int16) MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(f32, float) -MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90422,13 +92444,13 @@ MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc on if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } return ma_dr_flac__full_read_and_close_s32(pFlac, channelsOut, sampleRateOut, totalPCMFrameCountOut); } -MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc 
onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90440,13 +92462,13 @@ MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc on if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } return ma_dr_flac__full_read_and_close_s16(pFlac, channelsOut, sampleRateOut, totalPCMFrameCountOut); } -MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90458,7 +92480,7 @@ MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRea if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -90680,12 +92702,9 @@ MA_API const char* ma_dr_mp3_version_string(void) #define MA_DR_MP3_NO_SIMD #endif #define MA_DR_MP3_OFFSET_PTR(p, offset) ((void*)((ma_uint8*)(p) + (offset))) -#define MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE 2304 #ifndef MA_DR_MP3_MAX_FRAME_SYNC_MATCHES #define MA_DR_MP3_MAX_FRAME_SYNC_MATCHES 10 #endif -#define MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE -#define MA_DR_MP3_MAX_BITRESERVOIR_BYTES 511 #define MA_DR_MP3_SHORT_BLOCK_TYPE 2 #define MA_DR_MP3_STOP_BLOCK_TYPE 3 #define MA_DR_MP3_MODE_MONO 3 @@ -90735,7 +92754,7 @@ MA_API const char* ma_dr_mp3_version_string(void) #define MA_DR_MP3_VMUL_S(x, s) _mm_mul_ps(x, _mm_set1_ps(s)) #define MA_DR_MP3_VREV(x) _mm_shuffle_ps(x, x, _MM_SHUFFLE(0, 1, 2, 3)) typedef __m128 ma_dr_mp3_f4; -#if defined(_MSC_VER) || defined(MA_DR_MP3_ONLY_SIMD) +#if (defined(_MSC_VER) || defined(MA_DR_MP3_ONLY_SIMD)) && !defined(__clang__) #define ma_dr_mp3_cpuid __cpuid #else static __inline__ __attribute__((always_inline)) void ma_dr_mp3_cpuid(int CPUInfo[], const int InfoType) @@ -90851,11 +92870,6 @@ static __inline__ __attribute__((always_inline)) ma_int32 ma_dr_mp3_clip_int16_a #define MA_DR_MP3_FREE(p) free((p)) #endif typedef struct -{ - const ma_uint8 *buf; - int pos, limit; -} ma_dr_mp3_bs; -typedef struct { float scf[3*64]; ma_uint8 total_bands, stereo_bands, bitalloc[64], scfcod[64]; @@ -90864,22 +92878,6 @@ typedef struct { ma_uint8 tab_offset, code_tab_width, band_count; } ma_dr_mp3_L12_subband_alloc; -typedef struct -{ - const ma_uint8 *sfbtab; - ma_uint16 part_23_length, 
big_values, scalefac_compress; - ma_uint8 global_gain, block_type, mixed_block_flag, n_long_sfb, n_short_sfb; - ma_uint8 table_select[3], region_count[3], subblock_gain[3]; - ma_uint8 preflag, scalefac_scale, count1_table, scfsi; -} ma_dr_mp3_L3_gr_info; -typedef struct -{ - ma_dr_mp3_bs bs; - ma_uint8 maindata[MA_DR_MP3_MAX_BITRESERVOIR_BYTES + MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES]; - ma_dr_mp3_L3_gr_info gr_info[4]; - float grbuf[2][576], scf[40], syn[18 + 15][2*32]; - ma_uint8 ist_pos[2][39]; -} ma_dr_mp3dec_scratch; static void ma_dr_mp3_bs_init(ma_dr_mp3_bs *bs, const ma_uint8 *data, int bytes) { bs->buf = data; @@ -91262,6 +93260,10 @@ static float ma_dr_mp3_L3_ldexp_q2(float y, int exp_q2) } while ((exp_q2 -= e) > 0); return y; } +#if (defined(__GNUC__) && (__GNUC__ >= 13)) && !defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstringop-overflow" +#endif static void ma_dr_mp3_L3_decode_scalefactors(const ma_uint8 *hdr, ma_uint8 *ist_pos, ma_dr_mp3_bs *bs, const ma_dr_mp3_L3_gr_info *gr, float *scf, int ch) { static const ma_uint8 g_scf_partitions[3][28] = { @@ -91320,7 +93322,10 @@ static void ma_dr_mp3_L3_decode_scalefactors(const ma_uint8 *hdr, ma_uint8 *ist_ scf[i] = ma_dr_mp3_L3_ldexp_q2(gain, iscf[i] << scf_shift); } } -static const float g_ma_dr_mp3_pow43[129 + 16] = { +#if (defined(__GNUC__) && (__GNUC__ >= 13)) && !defined(__clang__) + #pragma GCC diagnostic pop +#endif +static const float ma_dr_mp3_g_pow43[129 + 16] = { 0,-1,-2.519842f,-4.326749f,-6.349604f,-8.549880f,-10.902724f,-13.390518f,-16.000000f,-18.720754f,-21.544347f,-24.463781f,-27.473142f,-30.567351f,-33.741992f,-36.993181f, 0,1,2.519842f,4.326749f,6.349604f,8.549880f,10.902724f,13.390518f,16.000000f,18.720754f,21.544347f,24.463781f,27.473142f,30.567351f,33.741992f,36.993181f,40.317474f,43.711787f,47.173345f,50.699631f,54.288352f,57.937408f,61.644865f,65.408941f,69.227979f,73.100443f,77.024898f,81.000000f,85.024491f,89.097188f,93.216975f,97.382800f,101.593667f,105.848633f,110.146801f,114.487321f,118.869381f,123.292209f,127.755065f,132.257246f,136.798076f,141.376907f,145.993119f,150.646117f,155.335327f,160.060199f,164.820202f,169.614826f,174.443577f,179.305980f,184.201575f,189.129918f,194.090580f,199.083145f,204.107210f,209.162385f,214.248292f,219.364564f,224.510845f,229.686789f,234.892058f,240.126328f,245.389280f,250.680604f,256.000000f,261.347174f,266.721841f,272.123723f,277.552547f,283.008049f,288.489971f,293.998060f,299.532071f,305.091761f,310.676898f,316.287249f,321.922592f,327.582707f,333.267377f,338.976394f,344.709550f,350.466646f,356.247482f,362.051866f,367.879608f,373.730522f,379.604427f,385.501143f,391.420496f,397.362314f,403.326427f,409.312672f,415.320884f,421.350905f,427.402579f,433.475750f,439.570269f,445.685987f,451.822757f,457.980436f,464.158883f,470.357960f,476.577530f,482.817459f,489.077615f,495.357868f,501.658090f,507.978156f,514.317941f,520.677324f,527.056184f,533.454404f,539.871867f,546.308458f,552.764065f,559.238575f,565.731879f,572.243870f,578.774440f,585.323483f,591.890898f,598.476581f,605.080431f,611.702349f,618.342238f,625.000000f,631.675540f,638.368763f,645.079578f }; @@ -91330,7 +93335,7 @@ static float ma_dr_mp3_L3_pow_43(int x) int sign, mult = 256; if (x < 129) { - return g_ma_dr_mp3_pow43[16 + x]; + return ma_dr_mp3_g_pow43[16 + x]; } if (x < 1024) { @@ -91339,7 +93344,7 @@ static float ma_dr_mp3_L3_pow_43(int x) } sign = 2*x & 64; frac = (float)((x & 63) - sign) / ((x & ~63) + sign); - return g_ma_dr_mp3_pow43[16 + ((x + sign) >> 6)]*(1.f + 
frac*((4.f/3) + frac*(2.f/9)))*mult; + return ma_dr_mp3_g_pow43[16 + ((x + sign) >> 6)]*(1.f + frac*((4.f/3) + frac*(2.f/9)))*mult; } static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L3_gr_info *gr_info, const float *scf, int layer3gr_limit) { @@ -91409,7 +93414,7 @@ static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L *dst = one*ma_dr_mp3_L3_pow_43(lsb)*((ma_int32)bs_cache < 0 ? -1: 1); } else { - *dst = g_ma_dr_mp3_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; + *dst = ma_dr_mp3_g_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; } MA_DR_MP3_FLUSH_BITS(lsb ? 1 : 0); } @@ -91437,7 +93442,7 @@ static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L for (j = 0; j < 2; j++, dst++, leaf >>= 4) { int lsb = leaf & 0x0F; - *dst = g_ma_dr_mp3_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; + *dst = ma_dr_mp3_g_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; MA_DR_MP3_FLUSH_BITS(lsb ? 1 : 0); } MA_DR_MP3_CHECK_BITS; @@ -92245,7 +94250,6 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int int i = 0, igr, frame_size = 0, success = 1; const ma_uint8 *hdr; ma_dr_mp3_bs bs_frame[1]; - ma_dr_mp3dec_scratch scratch; if (mp3_bytes > 4 && dec->header[0] == 0xff && ma_dr_mp3_hdr_compare(dec->header, mp3)) { frame_size = ma_dr_mp3_hdr_frame_bytes(mp3, dec->free_format_bytes) + ma_dr_mp3_hdr_padding(mp3); @@ -92268,7 +94272,7 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int MA_DR_MP3_COPY_MEMORY(dec->header, hdr, MA_DR_MP3_HDR_SIZE); info->frame_bytes = i + frame_size; info->channels = MA_DR_MP3_HDR_IS_MONO(hdr) ? 1 : 2; - info->hz = ma_dr_mp3_hdr_sample_rate_hz(hdr); + info->sample_rate = ma_dr_mp3_hdr_sample_rate_hz(hdr); info->layer = 4 - MA_DR_MP3_HDR_GET_LAYER(hdr); info->bitrate_kbps = ma_dr_mp3_hdr_bitrate_kbps(hdr); ma_dr_mp3_bs_init(bs_frame, hdr + MA_DR_MP3_HDR_SIZE, frame_size - MA_DR_MP3_HDR_SIZE); @@ -92278,23 +94282,23 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int } if (info->layer == 3) { - int main_data_begin = ma_dr_mp3_L3_read_side_info(bs_frame, scratch.gr_info, hdr); + int main_data_begin = ma_dr_mp3_L3_read_side_info(bs_frame, dec->scratch.gr_info, hdr); if (main_data_begin < 0 || bs_frame->pos > bs_frame->limit) { ma_dr_mp3dec_init(dec); return 0; } - success = ma_dr_mp3_L3_restore_reservoir(dec, bs_frame, &scratch, main_data_begin); + success = ma_dr_mp3_L3_restore_reservoir(dec, bs_frame, &dec->scratch, main_data_begin); if (success && pcm != NULL) { for (igr = 0; igr < (MA_DR_MP3_HDR_TEST_MPEG1(hdr) ? 
2 : 1); igr++, pcm = MA_DR_MP3_OFFSET_PTR(pcm, sizeof(ma_dr_mp3d_sample_t)*576*info->channels)) { - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); - ma_dr_mp3_L3_decode(dec, &scratch, scratch.gr_info + igr*info->channels, info->channels); - ma_dr_mp3d_synth_granule(dec->qmf_state, scratch.grbuf[0], 18, info->channels, (ma_dr_mp3d_sample_t*)pcm, scratch.syn[0]); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); + ma_dr_mp3_L3_decode(dec, &dec->scratch, dec->scratch.gr_info + igr*info->channels, info->channels); + ma_dr_mp3d_synth_granule(dec->qmf_state, dec->scratch.grbuf[0], 18, info->channels, (ma_dr_mp3d_sample_t*)pcm, dec->scratch.syn[0]); } } - ma_dr_mp3_L3_save_reservoir(dec, &scratch); + ma_dr_mp3_L3_save_reservoir(dec, &dec->scratch); } else { #ifdef MA_DR_MP3_ONLY_MP3 @@ -92305,15 +94309,15 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int return ma_dr_mp3_hdr_frame_samples(hdr); } ma_dr_mp3_L12_read_scale_info(hdr, bs_frame, sci); - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); for (i = 0, igr = 0; igr < 3; igr++) { - if (12 == (i += ma_dr_mp3_L12_dequantize_granule(scratch.grbuf[0] + i, bs_frame, sci, info->layer | 1))) + if (12 == (i += ma_dr_mp3_L12_dequantize_granule(dec->scratch.grbuf[0] + i, bs_frame, sci, info->layer | 1))) { i = 0; - ma_dr_mp3_L12_apply_scf_384(sci, sci->scf + igr, scratch.grbuf[0]); - ma_dr_mp3d_synth_granule(dec->qmf_state, scratch.grbuf[0], 12, info->channels, (ma_dr_mp3d_sample_t*)pcm, scratch.syn[0]); - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); + ma_dr_mp3_L12_apply_scf_384(sci, sci->scf + igr, dec->scratch.grbuf[0]); + ma_dr_mp3d_synth_granule(dec->qmf_state, dec->scratch.grbuf[0], 12, info->channels, (ma_dr_mp3d_sample_t*)pcm, dec->scratch.syn[0]); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); pcm = MA_DR_MP3_OFFSET_PTR(pcm, sizeof(ma_dr_mp3d_sample_t)*384*info->channels); } if (bs_frame->pos > bs_frame->limit) @@ -92491,19 +94495,41 @@ static ma_allocation_callbacks ma_dr_mp3_copy_allocation_callbacks_or_defaults(c } static size_t ma_dr_mp3__on_read(ma_dr_mp3* pMP3, void* pBufferOut, size_t bytesToRead) { - size_t bytesRead = pMP3->onRead(pMP3->pUserData, pBufferOut, bytesToRead); + size_t bytesRead; + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pMP3->onRead != NULL); + if (bytesToRead == 0) { + return 0; + } + bytesRead = pMP3->onRead(pMP3->pUserData, pBufferOut, bytesToRead); pMP3->streamCursor += bytesRead; return bytesRead; } +static size_t ma_dr_mp3__on_read_clamped(ma_dr_mp3* pMP3, void* pBufferOut, size_t bytesToRead) +{ + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pMP3->onRead != NULL); + if (pMP3->streamLength == MA_UINT64_MAX) { + return ma_dr_mp3__on_read(pMP3, pBufferOut, bytesToRead); + } else { + ma_uint64 bytesRemaining; + bytesRemaining = (pMP3->streamLength - pMP3->streamCursor); + if (bytesToRead > bytesRemaining) { + bytesToRead = (size_t)bytesRemaining; + } + return ma_dr_mp3__on_read(pMP3, pBufferOut, bytesToRead); + } +} static ma_bool32 ma_dr_mp3__on_seek(ma_dr_mp3* pMP3, int offset, ma_dr_mp3_seek_origin origin) { MA_DR_MP3_ASSERT(offset >= 0); + MA_DR_MP3_ASSERT(origin == MA_DR_MP3_SEEK_SET || origin == MA_DR_MP3_SEEK_CUR); if (!pMP3->onSeek(pMP3->pUserData, offset, origin)) { return MA_FALSE; } - if (origin == ma_dr_mp3_seek_origin_start) { + if (origin == MA_DR_MP3_SEEK_SET) { pMP3->streamCursor = 
(ma_uint64)offset; - } else { + } else{ pMP3->streamCursor += offset; } return MA_TRUE; @@ -92513,18 +94539,18 @@ static ma_bool32 ma_dr_mp3__on_seek_64(ma_dr_mp3* pMP3, ma_uint64 offset, ma_dr_ if (offset <= 0x7FFFFFFF) { return ma_dr_mp3__on_seek(pMP3, (int)offset, origin); } - if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } offset -= 0x7FFFFFFF; while (offset > 0) { if (offset <= 0x7FFFFFFF) { - if (!ma_dr_mp3__on_seek(pMP3, (int)offset, ma_dr_mp3_seek_origin_current)) { + if (!ma_dr_mp3__on_seek(pMP3, (int)offset, MA_DR_MP3_SEEK_CUR)) { return MA_FALSE; } offset = 0; } else { - if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, ma_dr_mp3_seek_origin_current)) { + if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, MA_DR_MP3_SEEK_CUR)) { return MA_FALSE; } offset -= 0x7FFFFFFF; @@ -92532,7 +94558,18 @@ static ma_bool32 ma_dr_mp3__on_seek_64(ma_dr_mp3* pMP3, ma_uint64 offset, ma_dr_ } return MA_TRUE; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static void ma_dr_mp3__on_meta(ma_dr_mp3* pMP3, ma_dr_mp3_metadata_type type, const void* pRawData, size_t rawDataSize) +{ + if (pMP3->onMeta) { + ma_dr_mp3_metadata metadata; + MA_DR_MP3_ZERO_OBJECT(&metadata); + metadata.type = type; + metadata.pRawData = pRawData; + metadata.rawDataSize = rawDataSize; + pMP3->onMeta(pMP3->pUserDataMeta, &metadata); + } +} +static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { ma_uint32 pcmFramesRead = 0; MA_DR_MP3_ASSERT(pMP3 != NULL); @@ -92559,7 +94596,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d pMP3->pData = pNewData; pMP3->dataCapacity = newDataCap; } - bytesRead = ma_dr_mp3__on_read(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); + bytesRead = ma_dr_mp3__on_read_clamped(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); if (bytesRead == 0) { if (pMP3->dataSize == 0) { pMP3->atEnd = MA_TRUE; @@ -92578,16 +94615,20 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d return 0; } pcmFramesRead = ma_dr_mp3dec_decode_frame(&pMP3->decoder, pMP3->pData + pMP3->dataConsumed, (int)pMP3->dataSize, pPCMFrames, &info); - if (info.frame_bytes > 0) { - pMP3->dataConsumed += (size_t)info.frame_bytes; - pMP3->dataSize -= (size_t)info.frame_bytes; - } + pMP3->dataConsumed += (size_t)info.frame_bytes; + pMP3->dataSize -= (size_t)info.frame_bytes; if (pcmFramesRead > 0) { pcmFramesRead = ma_dr_mp3_hdr_frame_samples(pMP3->decoder.header); pMP3->pcmFramesConsumedInMP3Frame = 0; pMP3->pcmFramesRemainingInMP3Frame = pcmFramesRead; pMP3->mp3FrameChannels = info.channels; - pMP3->mp3FrameSampleRate = info.hz; + pMP3->mp3FrameSampleRate = info.sample_rate; + if (pMP3FrameInfo != NULL) { + *pMP3FrameInfo = info; + } + if (ppMP3FrameData != NULL) { + *ppMP3FrameData = pMP3->pData + pMP3->dataConsumed - (size_t)info.frame_bytes; + } break; } else if (info.frame_bytes == 0) { size_t bytesRead; @@ -92604,7 +94645,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d pMP3->pData = pNewData; pMP3->dataCapacity = newDataCap; } - bytesRead = ma_dr_mp3__on_read(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); + bytesRead = ma_dr_mp3__on_read_clamped(pMP3, pMP3->pData + 
pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); if (bytesRead == 0) { pMP3->atEnd = MA_TRUE; return 0; @@ -92614,7 +94655,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d }; return pcmFramesRead; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { ma_uint32 pcmFramesRead = 0; ma_dr_mp3dec_frame_info info; @@ -92630,36 +94671,44 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_m pMP3->pcmFramesConsumedInMP3Frame = 0; pMP3->pcmFramesRemainingInMP3Frame = pcmFramesRead; pMP3->mp3FrameChannels = info.channels; - pMP3->mp3FrameSampleRate = info.hz; + pMP3->mp3FrameSampleRate = info.sample_rate; + if (pMP3FrameInfo != NULL) { + *pMP3FrameInfo = info; + } + if (ppMP3FrameData != NULL) { + *ppMP3FrameData = pMP3->memory.pData + pMP3->memory.currentReadPos; + } break; } else if (info.frame_bytes > 0) { pMP3->memory.currentReadPos += (size_t)info.frame_bytes; + pMP3->streamCursor += (size_t)info.frame_bytes; } else { break; } } pMP3->memory.currentReadPos += (size_t)info.frame_bytes; + pMP3->streamCursor += (size_t)info.frame_bytes; return pcmFramesRead; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static ma_uint32 ma_dr_mp3_decode_next_frame_ex(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { if (pMP3->memory.pData != NULL && pMP3->memory.dataSize > 0) { - return ma_dr_mp3_decode_next_frame_ex__memory(pMP3, pPCMFrames); + return ma_dr_mp3_decode_next_frame_ex__memory(pMP3, pPCMFrames, pMP3FrameInfo, ppMP3FrameData); } else { - return ma_dr_mp3_decode_next_frame_ex__callbacks(pMP3, pPCMFrames); + return ma_dr_mp3_decode_next_frame_ex__callbacks(pMP3, pPCMFrames, pMP3FrameInfo, ppMP3FrameData); } } static ma_uint32 ma_dr_mp3_decode_next_frame(ma_dr_mp3* pMP3) { MA_DR_MP3_ASSERT(pMP3 != NULL); - return ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames); + return ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames, NULL, NULL); } #if 0 static ma_uint32 ma_dr_mp3_seek_next_frame(ma_dr_mp3* pMP3) { ma_uint32 pcmFrameCount; MA_DR_MP3_ASSERT(pMP3 != NULL); - pcmFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFrameCount == 0) { return 0; } @@ -92669,33 +94718,252 @@ static ma_uint32 ma_dr_mp3_seek_next_frame(ma_dr_mp3* pMP3) return pcmFrameCount; } #endif -static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { + ma_dr_mp3dec_frame_info firstFrameInfo; + const ma_uint8* pFirstFrameData; + ma_uint32 firstFramePCMFrameCount; + ma_uint32 detectedMP3FrameCount = 0xFFFFFFFF; MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(onRead != NULL); ma_dr_mp3dec_init(&pMP3->decoder); pMP3->onRead = onRead; pMP3->onSeek = onSeek; + pMP3->onMeta = onMeta; 
pMP3->pUserData = pUserData; + pMP3->pUserDataMeta = pUserDataMeta; pMP3->allocationCallbacks = ma_dr_mp3_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); if (pMP3->allocationCallbacks.onFree == NULL || (pMP3->allocationCallbacks.onMalloc == NULL && pMP3->allocationCallbacks.onRealloc == NULL)) { return MA_FALSE; } - if (ma_dr_mp3_decode_next_frame(pMP3) == 0) { + pMP3->streamCursor = 0; + pMP3->streamLength = MA_UINT64_MAX; + pMP3->streamStartOffset = 0; + pMP3->delayInPCMFrames = 0; + pMP3->paddingInPCMFrames = 0; + pMP3->totalPCMFrameCount = MA_UINT64_MAX; + #if 1 + if (onSeek != NULL && onTell != NULL) { + if (onSeek(pUserData, 0, MA_DR_MP3_SEEK_END)) { + ma_int64 streamLen; + int streamEndOffset = 0; + if (onTell(pUserData, &streamLen)) { + if (streamLen > 128) { + char id3[3]; + if (onSeek(pUserData, streamEndOffset - 128, MA_DR_MP3_SEEK_END)) { + if (onRead(pUserData, id3, 3) == 3 && id3[0] == 'T' && id3[1] == 'A' && id3[2] == 'G') { + streamEndOffset -= 128; + streamLen -= 128; + if (onMeta != NULL) { + ma_uint8 tag[128]; + tag[0] = 'T'; tag[1] = 'A'; tag[2] = 'G'; + if (onRead(pUserData, tag + 3, 125) == 125) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_ID3V1, tag, 128); + } + } + } else { + } + } else { + } + } else { + } + if (streamLen > 32) { + char ape[32]; + if (onSeek(pUserData, streamEndOffset - 32, MA_DR_MP3_SEEK_END)) { + if (onRead(pUserData, ape, 32) == 32 && ape[0] == 'A' && ape[1] == 'P' && ape[2] == 'E' && ape[3] == 'T' && ape[4] == 'A' && ape[5] == 'G' && ape[6] == 'E' && ape[7] == 'X') { + ma_uint32 tagSize = + ((ma_uint32)ape[24] << 0) | + ((ma_uint32)ape[25] << 8) | + ((ma_uint32)ape[26] << 16) | + ((ma_uint32)ape[27] << 24); + if (32 + tagSize < streamLen) { + streamEndOffset -= 32 + tagSize; + streamLen -= 32 + tagSize; + if (onMeta != NULL) { + if (onSeek(pUserData, streamEndOffset, MA_DR_MP3_SEEK_END)) { + size_t apeTagSize = (size_t)tagSize + 32; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(apeTagSize, pAllocationCallbacks); + if (pTagData != NULL) { + if (onRead(pUserData, pTagData, apeTagSize) == apeTagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_APE, pTagData, apeTagSize); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); + } + } + } + } else { + } + } + } + } else { + } + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + pMP3->streamLength = (ma_uint64)streamLen; + if (pMP3->memory.pData != NULL) { + pMP3->memory.dataSize = (size_t)pMP3->streamLength; + } + } else { + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + } + } else { + } + } else { + } + #endif + #if 1 + { + char header[10]; + if (onRead(pUserData, header, 10) == 10) { + if (header[0] == 'I' && header[1] == 'D' && header[2] == '3') { + ma_uint32 tagSize = + (((ma_uint32)header[6] & 0x7F) << 21) | + (((ma_uint32)header[7] & 0x7F) << 14) | + (((ma_uint32)header[8] & 0x7F) << 7) | + (((ma_uint32)header[9] & 0x7F) << 0); + if (header[5] & 0x10) { + tagSize += 10; + } + if (onMeta != NULL) { + size_t tagSizeWithHeader = 10 + tagSize; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(tagSizeWithHeader, pAllocationCallbacks); + if (pTagData != NULL) { + MA_DR_MP3_COPY_MEMORY(pTagData, header, 10); + if (onRead(pUserData, pTagData + 10, tagSize) == tagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_ID3V2, pTagData, tagSizeWithHeader); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); + } + } else { + if (onSeek != NULL) { + if (!onSeek(pUserData, tagSize, MA_DR_MP3_SEEK_CUR)) { + return 
MA_FALSE; + } + } else { + char discard[1024]; + while (tagSize > 0) { + size_t bytesToRead = tagSize; + if (bytesToRead > sizeof(discard)) { + bytesToRead = sizeof(discard); + } + if (onRead(pUserData, discard, bytesToRead) != bytesToRead) { + return MA_FALSE; + } + tagSize -= (ma_uint32)bytesToRead; + } + } + } + pMP3->streamStartOffset += 10 + tagSize; + pMP3->streamCursor = pMP3->streamStartOffset; + } else { + if (onSeek != NULL) { + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + } else { + } + } + } else { + return MA_FALSE; + } + } + #endif + firstFramePCMFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames, &firstFrameInfo, &pFirstFrameData); + if (firstFramePCMFrameCount > 0) { + MA_DR_MP3_ASSERT(pFirstFrameData != NULL); + #if 1 + MA_DR_MP3_ASSERT(firstFrameInfo.frame_bytes > 0); + { + ma_dr_mp3_bs bs; + ma_dr_mp3_L3_gr_info grInfo[4]; + ma_dr_mp3_bs_init(&bs, pFirstFrameData + MA_DR_MP3_HDR_SIZE, firstFrameInfo.frame_bytes - MA_DR_MP3_HDR_SIZE); + if (MA_DR_MP3_HDR_IS_CRC(pFirstFrameData)) { + ma_dr_mp3_bs_get_bits(&bs, 16); + } + if (ma_dr_mp3_L3_read_side_info(&bs, grInfo, pFirstFrameData) >= 0) { + ma_bool32 isXing = MA_FALSE; + ma_bool32 isInfo = MA_FALSE; + const ma_uint8* pTagData; + const ma_uint8* pTagDataBeg; + pTagDataBeg = pFirstFrameData + MA_DR_MP3_HDR_SIZE + (bs.pos/8); + pTagData = pTagDataBeg; + isXing = (pTagData[0] == 'X' && pTagData[1] == 'i' && pTagData[2] == 'n' && pTagData[3] == 'g'); + isInfo = (pTagData[0] == 'I' && pTagData[1] == 'n' && pTagData[2] == 'f' && pTagData[3] == 'o'); + if (isXing || isInfo) { + ma_uint32 bytes = 0; + ma_uint32 flags = pTagData[7]; + pTagData += 8; + if (flags & 0x01) { + detectedMP3FrameCount = (ma_uint32)pTagData[0] << 24 | (ma_uint32)pTagData[1] << 16 | (ma_uint32)pTagData[2] << 8 | (ma_uint32)pTagData[3]; + pTagData += 4; + } + if (flags & 0x02) { + bytes = (ma_uint32)pTagData[0] << 24 | (ma_uint32)pTagData[1] << 16 | (ma_uint32)pTagData[2] << 8 | (ma_uint32)pTagData[3]; + (void)bytes; + pTagData += 4; + } + if (flags & 0x04) { + pTagData += 100; + } + if (flags & 0x08) { + pTagData += 4; + } + if (pTagData[0]) { + pTagData += 21; + if (pTagData - pFirstFrameData + 14 < firstFrameInfo.frame_bytes) { + int delayInPCMFrames; + int paddingInPCMFrames; + delayInPCMFrames = (( (ma_uint32)pTagData[0] << 4) | ((ma_uint32)pTagData[1] >> 4)) + (528 + 1); + paddingInPCMFrames = ((((ma_uint32)pTagData[1] & 0xF) << 8) | ((ma_uint32)pTagData[2] )) - (528 + 1); + if (paddingInPCMFrames < 0) { + paddingInPCMFrames = 0; + } + pMP3->delayInPCMFrames = (ma_uint32)delayInPCMFrames; + pMP3->paddingInPCMFrames = (ma_uint32)paddingInPCMFrames; + } + } + if (isXing) { + pMP3->isVBR = MA_TRUE; + } else if (isInfo) { + pMP3->isCBR = MA_TRUE; + } + if (onMeta != NULL) { + ma_dr_mp3_metadata_type metadataType = isXing ? 
MA_DR_MP3_METADATA_TYPE_XING : MA_DR_MP3_METADATA_TYPE_VBRI; + size_t tagDataSize; + tagDataSize = (size_t)firstFrameInfo.frame_bytes; + tagDataSize -= (size_t)(pTagDataBeg - pFirstFrameData); + ma_dr_mp3__on_meta(pMP3, metadataType, pTagDataBeg, tagDataSize); + } + pMP3->pcmFramesRemainingInMP3Frame = 0; + pMP3->streamStartOffset += (ma_uint32)(firstFrameInfo.frame_bytes); + pMP3->streamCursor = pMP3->streamStartOffset; + ma_dr_mp3dec_init(&pMP3->decoder); + } + } else { + } + } + #endif + } else { ma_dr_mp3__free_from_callbacks(pMP3->pData, &pMP3->allocationCallbacks); return MA_FALSE; } + if (detectedMP3FrameCount != 0xFFFFFFFF) { + pMP3->totalPCMFrameCount = detectedMP3FrameCount * firstFramePCMFrameCount; + } pMP3->channels = pMP3->mp3FrameChannels; pMP3->sampleRate = pMP3->mp3FrameSampleRate; return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { if (pMP3 == NULL || onRead == NULL) { return MA_FALSE; } MA_DR_MP3_ZERO_OBJECT(pMP3); - return ma_dr_mp3_init_internal(pMP3, onRead, onSeek, pUserData, pAllocationCallbacks); + return ma_dr_mp3_init_internal(pMP3, onRead, onSeek, onTell, onMeta, pUserData, pUserData, pAllocationCallbacks); } static size_t ma_dr_mp3__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead) { @@ -92716,29 +94984,39 @@ static size_t ma_dr_mp3__on_read_memory(void* pUserData, void* pBufferOut, size_ static ma_bool32 ma_dr_mp3__on_seek_memory(void* pUserData, int byteOffset, ma_dr_mp3_seek_origin origin) { ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; + ma_int64 newCursor; MA_DR_MP3_ASSERT(pMP3 != NULL); - if (origin == ma_dr_mp3_seek_origin_current) { - if (byteOffset > 0) { - if (pMP3->memory.currentReadPos + byteOffset > pMP3->memory.dataSize) { - byteOffset = (int)(pMP3->memory.dataSize - pMP3->memory.currentReadPos); - } - } else { - if (pMP3->memory.currentReadPos < (size_t)-byteOffset) { - byteOffset = -(int)pMP3->memory.currentReadPos; - } - } - pMP3->memory.currentReadPos += byteOffset; + if (origin == MA_DR_MP3_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_MP3_SEEK_CUR) { + newCursor = (ma_int64)pMP3->memory.currentReadPos; + } else if (origin == MA_DR_MP3_SEEK_END) { + newCursor = (ma_int64)pMP3->memory.dataSize; } else { - if ((ma_uint32)byteOffset <= pMP3->memory.dataSize) { - pMP3->memory.currentReadPos = byteOffset; - } else { - pMP3->memory.currentReadPos = pMP3->memory.dataSize; - } + MA_DR_MP3_ASSERT(!"Invalid seek origin"); + return MA_FALSE; } + newCursor += byteOffset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pMP3->memory.dataSize) { + return MA_FALSE; + } + pMP3->memory.currentReadPos = (size_t)newCursor; return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3__on_tell_memory(void* pUserData, ma_int64* pCursor) { + ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)pMP3->memory.currentReadPos; + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_memory_with_metadata(ma_dr_mp3* pMP3, const void* pData, size_t 
dataSize, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) +{ + ma_bool32 result; if (pMP3 == NULL) { return MA_FALSE; } @@ -92749,7 +95027,21 @@ MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_ pMP3->memory.pData = (const ma_uint8*)pData; pMP3->memory.dataSize = dataSize; pMP3->memory.currentReadPos = 0; - return ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_memory, ma_dr_mp3__on_seek_memory, pMP3, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_memory, ma_dr_mp3__on_seek_memory, ma_dr_mp3__on_tell_memory, onMeta, pMP3, pUserDataMeta, pAllocationCallbacks); + if (result == MA_FALSE) { + return MA_FALSE; + } + if (pMP3->streamLength <= (ma_uint64)MA_SIZE_MAX) { + pMP3->memory.dataSize = (size_t)pMP3->streamLength; + } + if (pMP3->streamStartOffset > (ma_uint64)MA_SIZE_MAX) { + return MA_FALSE; + } + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_memory_with_metadata(pMP3, pData, dataSize, NULL, NULL, pAllocationCallbacks); } #ifndef MA_DR_MP3_NO_STDIO #include @@ -92760,36 +95052,76 @@ static size_t ma_dr_mp3__on_read_stdio(void* pUserData, void* pBufferOut, size_t } static ma_bool32 ma_dr_mp3__on_seek_stdio(void* pUserData, int offset, ma_dr_mp3_seek_origin origin) { - return fseek((FILE*)pUserData, offset, (origin == ma_dr_mp3_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_MP3_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_MP3_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; } -MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_MP3_ASSERT(pFileStdio != NULL); + MA_DR_MP3_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata(ma_dr_mp3* pMP3, const char* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; FILE* pFile; + if (pMP3 == NULL) { + return MA_FALSE; + } + MA_DR_MP3_ZERO_OBJECT(pMP3); if (ma_fopen(&pFile, pFilePath, "rb") != MA_SUCCESS) { return MA_FALSE; } - result = ma_dr_mp3_init(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, ma_dr_mp3__on_tell_stdio, onMeta, (void*)pFile, pUserDataMeta, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; } return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; FILE* pFile; + if (pMP3 == NULL) { + return MA_FALSE; + } + 
MA_DR_MP3_ZERO_OBJECT(pMP3); if (ma_wfopen(&pFile, pFilePath, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return MA_FALSE; } - result = ma_dr_mp3_init(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, ma_dr_mp3__on_tell_stdio, onMeta, (void*)pFile, pUserDataMeta, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; } return MA_TRUE; } +MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_file_with_metadata(pMP3, pFilePath, NULL, NULL, pAllocationCallbacks); +} +MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_file_with_metadata_w(pMP3, pFilePath, NULL, NULL, pAllocationCallbacks); +} #endif MA_API void ma_dr_mp3_uninit(ma_dr_mp3* pMP3) { @@ -92859,17 +95191,38 @@ static ma_uint64 ma_dr_mp3_read_pcm_frames_raw(ma_dr_mp3* pMP3, ma_uint64 frames MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(pMP3->onRead != NULL); while (framesToRead > 0) { - ma_uint32 framesToConsume = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, framesToRead); + ma_uint32 framesToConsume; + if (pMP3->currentPCMFrame < pMP3->delayInPCMFrames) { + ma_uint32 framesToSkip = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, pMP3->delayInPCMFrames - pMP3->currentPCMFrame); + pMP3->currentPCMFrame += framesToSkip; + pMP3->pcmFramesConsumedInMP3Frame += framesToSkip; + pMP3->pcmFramesRemainingInMP3Frame -= framesToSkip; + } + framesToConsume = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, framesToRead); + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX && pMP3->totalPCMFrameCount > pMP3->paddingInPCMFrames) { + if (pMP3->currentPCMFrame < (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames)) { + ma_uint64 framesRemainigToPadding = (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames) - pMP3->currentPCMFrame; + if (framesToConsume > framesRemainigToPadding) { + framesToConsume = (ma_uint32)framesRemainigToPadding; + } + } else { + break; + } + } if (pBufferOut != NULL) { - #if defined(MA_DR_MP3_FLOAT_OUTPUT) - float* pFramesOutF32 = (float*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(float) * totalFramesRead * pMP3->channels); - float* pFramesInF32 = (float*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(float) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); - MA_DR_MP3_COPY_MEMORY(pFramesOutF32, pFramesInF32, sizeof(float) * framesToConsume * pMP3->channels); - #else - ma_int16* pFramesOutS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(ma_int16) * totalFramesRead * pMP3->channels); - ma_int16* pFramesInS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(ma_int16) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); - MA_DR_MP3_COPY_MEMORY(pFramesOutS16, pFramesInS16, sizeof(ma_int16) * framesToConsume * pMP3->channels); - #endif + #if defined(MA_DR_MP3_FLOAT_OUTPUT) + { + float* pFramesOutF32 = (float*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(float) * totalFramesRead * pMP3->channels); + float* pFramesInF32 = (float*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(float) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); + MA_DR_MP3_COPY_MEMORY(pFramesOutF32, pFramesInF32, sizeof(float) * framesToConsume * pMP3->channels); + } + #else + { + ma_int16* 
pFramesOutS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(ma_int16) * totalFramesRead * pMP3->channels); + ma_int16* pFramesInS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(ma_int16) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); + MA_DR_MP3_COPY_MEMORY(pFramesOutS16, pFramesInS16, sizeof(ma_int16) * framesToConsume * pMP3->channels); + } + #endif } pMP3->currentPCMFrame += framesToConsume; pMP3->pcmFramesConsumedInMP3Frame += framesToConsume; @@ -92879,6 +95232,9 @@ static ma_uint64 ma_dr_mp3_read_pcm_frames_raw(ma_dr_mp3* pMP3, ma_uint64 frames if (framesToRead == 0) { break; } + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX && pMP3->totalPCMFrameCount > pMP3->paddingInPCMFrames && pMP3->currentPCMFrame >= (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames)) { + break; + } MA_DR_MP3_ASSERT(pMP3->pcmFramesRemainingInMP3Frame == 0); if (ma_dr_mp3_decode_next_frame(pMP3) == 0) { break; @@ -92958,7 +95314,7 @@ static ma_bool32 ma_dr_mp3_seek_to_start_of_stream(ma_dr_mp3* pMP3) { MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(pMP3->onSeek != NULL); - if (!ma_dr_mp3__on_seek(pMP3, 0, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek_64(pMP3, pMP3->streamStartOffset, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } ma_dr_mp3_reset(pMP3); @@ -93024,7 +95380,7 @@ static ma_bool32 ma_dr_mp3_seek_to_pcm_frame__seek_table(ma_dr_mp3* pMP3, ma_uin seekPoint.mp3FramesToDiscard = 0; seekPoint.pcmFramesToDiscard = 0; } - if (!ma_dr_mp3__on_seek_64(pMP3, seekPoint.seekPosInBytes, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek_64(pMP3, seekPoint.seekPosInBytes, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } ma_dr_mp3_reset(pMP3); @@ -93035,7 +95391,7 @@ static ma_bool32 ma_dr_mp3_seek_to_pcm_frame__seek_table(ma_dr_mp3* pMP3, ma_uin if (iMP3Frame == seekPoint.mp3FramesToDiscard-1) { pPCMFrames = (ma_dr_mp3d_sample_t*)pMP3->pcmFrames; } - pcmFramesRead = ma_dr_mp3_decode_next_frame_ex(pMP3, pPCMFrames); + pcmFramesRead = ma_dr_mp3_decode_next_frame_ex(pMP3, pPCMFrames, NULL, NULL); if (pcmFramesRead == 0) { return MA_FALSE; } @@ -93077,7 +95433,7 @@ MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint6 totalMP3FrameCount = 0; for (;;) { ma_uint32 pcmFramesInCurrentMP3Frame; - pcmFramesInCurrentMP3Frame = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3Frame = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3Frame == 0) { break; } @@ -93101,10 +95457,26 @@ MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint6 MA_API ma_uint64 ma_dr_mp3_get_pcm_frame_count(ma_dr_mp3* pMP3) { ma_uint64 totalPCMFrameCount; - if (!ma_dr_mp3_get_mp3_and_pcm_frame_count(pMP3, NULL, &totalPCMFrameCount)) { + if (pMP3 == NULL) { return 0; } - return totalPCMFrameCount; + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX) { + totalPCMFrameCount = pMP3->totalPCMFrameCount; + if (totalPCMFrameCount >= pMP3->delayInPCMFrames) { + totalPCMFrameCount -= pMP3->delayInPCMFrames; + } else { + } + if (totalPCMFrameCount >= pMP3->paddingInPCMFrames) { + totalPCMFrameCount -= pMP3->paddingInPCMFrames; + } else { + } + return totalPCMFrameCount; + } else { + if (!ma_dr_mp3_get_mp3_and_pcm_frame_count(pMP3, NULL, &totalPCMFrameCount)) { + return 0; + } + return totalPCMFrameCount; + } } MA_API ma_uint64 ma_dr_mp3_get_mp3_frame_count(ma_dr_mp3* pMP3) { @@ -93174,7 +95546,7 @@ MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSe MA_DR_MP3_ASSERT(pMP3->streamCursor >= 
pMP3->dataSize); mp3FrameInfo[iMP3Frame].bytePos = pMP3->streamCursor - pMP3->dataSize; mp3FrameInfo[iMP3Frame].pcmFrameIndex = runningPCMFrameCount; - pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3FrameIn == 0) { return MA_FALSE; } @@ -93198,7 +95570,7 @@ MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSe } mp3FrameInfo[MA_DR_MP3_COUNTOF(mp3FrameInfo)-1].bytePos = pMP3->streamCursor - pMP3->dataSize; mp3FrameInfo[MA_DR_MP3_COUNTOF(mp3FrameInfo)-1].pcmFrameIndex = runningPCMFrameCount; - pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3FrameIn == 0) { pSeekPoints[iSeekPoint].seekPosInBytes = mp3FrameInfo[0].bytePos; pSeekPoints[iSeekPoint].pcmFrameIndex = nextTargetPCMFrame; @@ -93264,6 +95636,8 @@ static float* ma_dr_mp3__full_read_and_close_f32(ma_dr_mp3* pMP3, ma_dr_mp3_conf pNewFrames = (float*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks); if (pNewFrames == NULL) { ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks); + pFrames = NULL; + totalFramesRead = 0; break; } pFrames = pNewFrames; @@ -93315,6 +95689,8 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c pNewFrames = (ma_int16*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks); if (pNewFrames == NULL) { ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks); + pFrames = NULL; + totalFramesRead = 0; break; } pFrames = pNewFrames; @@ -93336,18 +95712,18 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c } return pFrames; } -MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_mp3 mp3; - if (!ma_dr_mp3_init(&mp3, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_mp3_init(&mp3, onRead, onSeek, onTell, NULL, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_mp3__full_read_and_close_f32(&mp3, pConfig, pTotalFrameCount); } -MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_mp3 mp3; - if (!ma_dr_mp3_init(&mp3, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_mp3_init(&mp3, onRead, onSeek, onTell, NULL, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_mp3__full_read_and_close_s16(&mp3, pConfig, pTotalFrameCount); diff 
--git a/examples/server/server.cpp b/examples/server/server.cpp index 866ac4ea..f6a7a831 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -103,6 +103,7 @@ struct whisper_params { bool no_timestamps = false; bool use_gpu = true; bool flash_attn = true; + int32_t gpu_device = 0; bool suppress_nst = false; bool no_context = true; bool no_language_probabilities = false; @@ -179,6 +180,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false"); @@ -198,6 +200,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para } bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) { + if (const char * env_device = std::getenv("WHISPER_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -237,6 +243,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } @@ -643,6 +650,7 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { @@ -740,9 +748,9 @@ int main(int argc, char ** argv) {

Whisper.cpp Server

-/inference
+)" + sparams.request_path + sparams.inference_path + R"(

-    curl 127.0.0.1:)" + std::to_string(sparams.port) + R"(/inference \
+    curl 127.0.0.1:)" + std::to_string(sparams.port) + sparams.request_path + sparams.inference_path + R"( \
     -H "Content-Type: multipart/form-data" \
     -F file="@<file-path>" \
     -F temperature="0.0" \
@@ -759,7 +767,7 @@ int main(int argc, char ** argv) {
 
         

Try it out

-
+
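The hunks above replace the hard-coded "/inference" strings in the server's embedded HTML help page (the endpoint heading, the example curl command, and the "Try it out" section) with the configurable sparams.request_path + sparams.inference_path. As a rough sketch of how that composed route is meant to be used (assuming cpp-httplib's httplib::Server, and assuming request_path defaults to an empty string and inference_path to "/inference" -- neither default is shown in these hunks), the handler registration would build the same string:

// Sketch only -- not the actual server.cpp code. Assumes cpp-httplib and the
// server_params fields referenced by the hunks above.
#include "httplib.h"
#include <string>

struct server_params_sketch {
    std::string request_path   = "";           // assumed default
    std::string inference_path = "/inference"; // assumed default
};

static void register_inference_endpoint(httplib::Server & svr, const server_params_sketch & sparams) {
    // The HTML page and curl example shown above stay in sync with this route,
    // e.g. "/api" + "/inference" when a custom request path is configured.
    svr.Post(sparams.request_path + sparams.inference_path,
             [](const httplib::Request &, httplib::Response & res) {
                 res.set_content("{\"text\":\"...\"}", "application/json");
             });
}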
@@ -803,6 +811,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: no 'file' field in the request\n"); const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } @@ -829,6 +838,7 @@ int main(int argc, char ** argv) { std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}"; const bool is_converted = convert_to_wav(temp_filename, error_resp); if (!is_converted) { + res.status = 500; res.set_content(error_resp, "application/json"); return; } @@ -838,6 +848,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str()); const std::string error_resp = "{\"error\":\"failed to read WAV file\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); std::remove(temp_filename.c_str()); return; @@ -849,6 +860,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: failed to read audio data\n"); const std::string error_resp = "{\"error\":\"failed to read audio data\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } @@ -927,7 +939,7 @@ int main(int argc, char ** argv) { wparams.logprob_thold = params.logprob_thold; wparams.no_timestamps = params.no_timestamps; - wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; + wparams.token_timestamps = !params.no_timestamps; wparams.no_context = params.no_context; wparams.suppress_nst = params.suppress_nst; @@ -1119,6 +1131,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: no 'model' field in the request\n"); const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } @@ -1127,6 +1140,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: 'model': %s not found!\n", model.c_str()); const std::string error_resp = "{\"error\":\"model not found!\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index cac46705..1adeef8f 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -22,18 +22,19 @@ if (WHISPER_SDL2) llama-kv-cache-iswa.cpp llama-memory-recurrent.cpp llama-memory-hybrid.cpp + llama-memory-hybrid-iswa.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp llama-model-saver.cpp llama-model.cpp llama-quant.cpp - llama-sampling.cpp + llama-sampler.cpp llama-vocab.cpp unicode.cpp unicode-data.cpp ${SRC_MODELS}) - target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS}) + target_include_directories(${TARGET} PRIVATE . 
${SDL2_INCLUDE_DIRS}) target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index bdc24c2d..d6a5800e 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -146,11 +146,9 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) { return nullptr; } -static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_lora & adapter) { +static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) { LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); - llama_model & model = adapter.model; - ggml_context * ctx_init; gguf_init_params meta_gguf_params = { /* .no_alloc = */ true, @@ -413,17 +411,17 @@ static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_l } } - // update number of nodes used - model.n_lora_nodes += adapter.get_n_nodes(); + // register adapter with model + model.loras.insert(&adapter); LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { - llama_adapter_lora * adapter = new llama_adapter_lora(*model); + llama_adapter_lora * adapter = new llama_adapter_lora(); try { - llama_adapter_lora_init_impl(path_lora, *adapter); + llama_adapter_lora_init_impl(*model, path_lora, *adapter); return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); @@ -473,12 +471,8 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, return snprintf(buf, buf_size, "%s", it->second.c_str()); } -void llama_adapter_lora_free(llama_adapter_lora * adapter) { - // update number of nodes used - GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes()); - adapter->model.n_lora_nodes -= adapter->get_n_nodes(); - - delete adapter; +void llama_adapter_lora_free(llama_adapter_lora *) { + // deprecated: adapters are freed by llama_model's destructor } uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) { diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index 42d64a6e..aa3ab63a 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -39,6 +39,8 @@ private: std::vector tensors; // per layer }; +using llama_adapter_cvec_ptr = std::shared_ptr; + // // llama_adapter_lora // @@ -59,8 +61,6 @@ struct llama_adapter_lora_weight { }; struct llama_adapter_lora { - llama_model & model; - // map tensor name to lora_a_b std::unordered_map ab_map; @@ -75,7 +75,7 @@ struct llama_adapter_lora { // activated lora (aLoRA) std::vector alora_invocation_tokens; - llama_adapter_lora(llama_model & model) : model(model) {} + llama_adapter_lora() = default; ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); @@ -86,3 +86,4 @@ struct llama_adapter_lora { }; using llama_adapter_loras = std::unordered_map; +using llama_adapter_loras_ptr = std::unique_ptr; diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index f736ee67..799d1616 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -4,6 +4,7 @@ #include #include 
+#include static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize @@ -26,6 +27,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" }, + { LLM_ARCH_EUROBERT, "eurobert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -37,6 +39,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, + { LLM_ARCH_QWEN35, "qwen35" }, + { LLM_ARCH_QWEN35MOE, "qwen35moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -72,15 +76,18 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM_DSA, "glm-dsa" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_JAIS2, "jais2" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, + { LLM_ARCH_EXAONE_MOE, "exaone-moe" }, { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV7, "rwkv7" }, @@ -116,9 +123,12 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, - { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_PADDLEOCR, "paddleocr" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, + { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -160,6 +170,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, + { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, + { LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, @@ -173,6 +185,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" }, { LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, @@ -190,6 +203,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + { LLM_KV_FULL_ATTENTION_INTERVAL, "%s.full_attention_interval" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -217,22 +231,28 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, 
"%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_KEY_LENGTH_SWA, "%s.attention.key_length_swa" }, + { LLM_KV_ATTENTION_VALUE_LENGTH_SWA, "%s.attention.value_length_swa" }, + { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, + { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, + { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, - { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, - { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -245,6 +265,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_KDA_HEAD_DIM, "%s.kda.head_dim" }, + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, @@ -332,6 +354,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, @@ -343,6 +366,8 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_EXP_PROBS_B, 
"blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_LATENT_DOWN, "blk.%d.ffn_latent_down" }, + { LLM_TENSOR_FFN_LATENT_UP, "blk.%d.ffn_latent_up" }, { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, @@ -353,12 +378,14 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_CLS, "cls" }, { LLM_TENSOR_CLS_OUT, "cls.output" }, + { LLM_TENSOR_CLS_NORM, "cls.norm" }, { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" }, { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, @@ -370,6 +397,15 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_CONV1D_Q, "blk.%d.ssm_conv1d_q" }, + { LLM_TENSOR_SSM_CONV1D_K, "blk.%d.ssm_conv1d_k" }, + { LLM_TENSOR_SSM_CONV1D_V, "blk.%d.ssm_conv1d_v" }, + { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" }, + { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" }, + { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, @@ -496,6 +532,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" }, + { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, + { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, + { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; static std::set llm_get_tensor_names(llm_arch arch) { @@ -709,6 +749,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { case LLM_ARCH_INTERNLM2: case LLM_ARCH_GRANITE: case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_PADDLEOCR: case LLM_ARCH_SMOLLM3: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: @@ -787,6 +828,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, }; + case LLM_ARCH_EUROBERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; case LLM_ARCH_MODERN_BERT: return { LLM_TENSOR_TOKEN_EMBD, @@ -800,6 +855,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, }; case LLM_ARCH_JINA_BERT_V2: return { @@ -952,11 +1008,11 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_OUT, LLM_TENSOR_ATTN_QKV, LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ 
-969,6 +1025,64 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, }; + case LLM_ARCH_QWEN35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_QWEN35MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; case LLM_ARCH_QWEN3VL: case LLM_ARCH_CHAMELEON: case LLM_ARCH_HUNYUAN_DENSE: @@ -976,6 +1090,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_OUTPUT, + LLM_TENSOR_CLS_OUT, LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_Q, LLM_TENSOR_ATTN_Q_NORM, @@ -1497,6 +1612,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -1549,6 +1665,12 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; case LLM_ARCH_GLM4_MOE: return { @@ -1581,6 +1703,46 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; + case LLM_ARCH_GLM_DSA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + 
LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_BITNET: return { LLM_TENSOR_TOKEN_EMBD, @@ -1659,6 +1821,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, }; + case LLM_ARCH_JAIS2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; case LLM_ARCH_NEMOTRON_H: return { LLM_TENSOR_TOKEN_EMBD, @@ -1706,6 +1882,8 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_LATENT_DOWN, + LLM_TENSOR_FFN_LATENT_UP, // MoE shared expert layer LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -1728,6 +1906,38 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_POST_NORM, }; + case LLM_ARCH_EXAONE_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_RWKV6: return { LLM_TENSOR_TOKEN_EMBD, @@ -2234,6 +2444,35 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, }; + case LLM_ARCH_STEP35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: return { @@ -2256,6 +2495,54 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, }; + case LLM_ARCH_KIMI_LINEAR: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + // Dense FFN (layer 0 only) + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + // MoE FFN (layers 1+) + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + // Shared experts + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + // KDA (using SSM_ enum prefix, 
keeping GGUF names for backward compat) + LLM_TENSOR_SSM_CONV1D_Q, + LLM_TENSOR_SSM_CONV1D_K, + LLM_TENSOR_SSM_CONV1D_V, + LLM_TENSOR_SSM_F_A, + LLM_TENSOR_SSM_F_B, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_G_A, + LLM_TENSOR_SSM_G_B, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_NORM, + // MLA + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_KV_A_NORM, + }; default: GGML_ABORT("unknown architecture for tensor mapping"); } @@ -2279,6 +2566,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2331,6 +2619,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2359,6 +2648,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // Kimi KDA - Conv tensors are 4D [d_conv, 1, d_inner, 1], reshaped to 2D at runtime + {LLM_TENSOR_SSM_CONV1D_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2401,6 +2699,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, @@ -2448,6 +2747,10 @@ static const std::map LLM_TENSOR_INFOS = { 
{LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, @@ -2456,6 +2759,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Nemotron 3 Super + {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} @@ -2493,6 +2799,15 @@ std::string LLM_TN_IMPL::str() const { return name; } +std::vector llm_arch_all() { + std::vector ret; + ret.reserve(LLM_ARCH_NAMES.size()); + for (const auto & [arch, _] : LLM_ARCH_NAMES) { + ret.push_back(arch); + } + return ret; +} + const char * llm_arch_name(llm_arch arch) { auto it = LLM_ARCH_NAMES.find(arch); if (it == LLM_ARCH_NAMES.end()) { @@ -2540,6 +2855,9 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_KIMI_LINEAR: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return true; default: return false; diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 68ec6a18..b1b1dcf1 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -4,6 +4,7 @@ #include #include +#include // // gguf constants (sync with gguf.py) @@ -30,6 +31,7 @@ enum llm_arch { LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, LLM_ARCH_JINA_BERT_V3, + LLM_ARCH_EUROBERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -41,6 +43,8 @@ enum llm_arch { LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, + LLM_ARCH_QWEN35, + LLM_ARCH_QWEN35MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, @@ -76,15 +80,18 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + LLM_ARCH_GLM_DSA, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, + LLM_ARCH_JAIS2, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, + LLM_ARCH_EXAONE_MOE, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, @@ -120,9 +127,12 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_PADDLEOCR, LLM_ARCH_MIMO2, + LLM_ARCH_STEP35, LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, + LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, }; @@ -164,6 +174,8 @@ enum llm_kv { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, + LLM_KV_SWIGLU_CLAMP_EXP, + LLM_KV_SWIGLU_CLAMP_SHEXP, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -177,6 +189,7 @@ enum llm_kv { LLM_KV_EXPERT_GROUP_SCALE, 
LLM_KV_EXPERTS_PER_GROUP, LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_MOE_LATENT_SIZE, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_NUM_DEEPSTACK_LAYERS, LLM_KV_POOLING_TYPE, @@ -194,6 +207,7 @@ enum llm_kv { LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + LLM_KV_FULL_ATTENTION_INTERVAL, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -221,8 +235,14 @@ enum llm_kv { LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_KEY_LENGTH_SWA, + LLM_KV_ATTENTION_VALUE_LENGTH_SWA, + LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, + LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, + LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_DIMENSION_COUNT_SWA, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_FREQ_BASE_SWA, @@ -249,6 +269,8 @@ enum llm_kv { LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_KDA_HEAD_DIM, + LLM_KV_WKV_HEAD_SIZE, LLM_KV_TOKENIZER_MODEL, @@ -356,6 +378,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -363,6 +386,8 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_CHEXPS, LLM_TENSOR_FFN_UP_CHEXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_LATENT_DOWN, + LLM_TENSOR_FFN_LATENT_UP, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, @@ -397,6 +422,16 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + LLM_TENSOR_SSM_ALPHA, // qwen3.5 + // Kimi Linear KDA (using SSM_ prefix for consistency) + LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight + LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight + LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight + LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A + LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient and qwen3.5 + LLM_TENSOR_SSM_G_A, // kimi: output gate projection A + LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, @@ -473,6 +508,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, LLM_TENSOR_CONV1D, LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, @@ -497,6 +533,10 @@ enum llm_tensor { LLM_TENSOR_VISEXP_FFN_GATE, LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, @@ -575,6 +615,8 @@ struct llm_tensor_info { ggml_op op; }; +std::vector llm_arch_all(); + const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); diff --git a/examples/talk-llama/llama-batch.cpp b/examples/talk-llama/llama-batch.cpp index 386fab04..6bf76939 100644 --- a/examples/talk-llama/llama-batch.cpp +++ b/examples/talk-llama/llama-batch.cpp @@ -394,11 +394,13 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t clear(); split_reset(); + const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd; + auto udata = std::make_shared(); udata->token .resize(n_tokens); udata->embd .clear(); - udata->pos .resize(n_tokens); + udata->pos .resize(n_pos_all); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); 
udata->seq_id_unq.resize(0); diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index b54ebbd1..c415a998 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -57,6 +57,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, + { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, @@ -137,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("[gMASK]")) { return LLM_CHAT_TEMPLATE_CHATGLM_4; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { + if (tmpl_contains("<|tool_declare|>")) { + return LLM_CHAT_TEMPLATE_EXAONE_MOE; + } return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) { return LLM_CHAT_TEMPLATE_GLMEDGE; @@ -229,7 +233,7 @@ int32_t llm_chat_apply_template( llm_chat_template tmpl, const std::vector & chat, std::string & dest, bool add_ass) { - // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + // Taken from the research: https://github.com/ggml-org/llama.cpp/issues/5527 std::stringstream ss; if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { // chatml template @@ -576,6 +580,22 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[|assistant|]"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "user") { + ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "assistant") { + ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "tool") { + ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n"; + } + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token for (size_t i = 0; i < chat.size(); i++) { diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index e1f79524..9ed1db12 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -36,6 +36,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_EXAONE_4, + LLM_CHAT_TEMPLATE_EXAONE_MOE, LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index f220010a..1f7a52d7 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -7,6 +7,7 @@ #include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-ext.h" #include #include @@ -22,6 +23,8 @@ llama_context::llama_context( const llama_model & model, llama_context_params params) : model(model), + cvec(std::make_unique()), + loras(std::make_unique()), balloc(std::make_unique(model.hparams.n_pos_per_embd())) { // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, // may need to be backend-dependent @@ -146,6 +149,11 @@ llama_context::llama_context( } 
cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; + + cparams.fused_gdn_ar = true; + cparams.fused_gdn_ch = true; + cparams.auto_fgdn = true; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -155,6 +163,9 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + // initialized later + cparams.pipeline_parallel = false; + { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -249,11 +260,7 @@ llama_context::llama_context( // graph outputs buffer { - // resized during inference when a batch uses more outputs - // Create a dummy batch for initialization. - llama_batch dummy_batch = {}; - dummy_batch.n_tokens = 0; - if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) { + if (output_reserve(params.n_seq_max) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -302,16 +309,6 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - const size_t max_nodes = this->graph_max_nodes(n_tokens); - - LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); - - gf_res_prev.reset(new llm_graph_result(max_nodes)); - gf_res_reserve.reset(new llm_graph_result(max_nodes)); - // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = @@ -327,6 +324,7 @@ llama_context::llama_context( auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { // ignore CPU backend + // TODO: should we ignore ACCEL types too? 
continue; } auto * dev = ggml_backend_get_device(backend.get()); @@ -340,142 +338,26 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); + cparams.pipeline_parallel = pipeline_parallel; - if (pipeline_parallel) { - LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); - } + if (cparams.pipeline_parallel) { + LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); - llama_memory_context_ptr mctx; - if (memory) { - LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); - mctx = memory->init_full(); - if (!mctx) { - throw std::runtime_error("failed to initialize memory module"); + if (!graph_reuse_disable) { + // TODO: figure out a way to make graph reuse work with pipeline parallelism + // ref: https://github.com/ggml-org/llama.cpp/pull/20463 + LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__); + + graph_reuse_disable = true; } } - cross.v_embd.clear(); + sched_reserve(); - // avoid reserving graphs with zero outputs - assume one output per sequence - n_outputs = n_seqs; - - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); - - // resolve automatic Flash Attention use - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { - auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); - if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); + if (!cparams.flash_attn) { + if (ggml_is_quantized(params.type_v)) { + throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); } - - const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; - bool fa_device_mismatch = false; - for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { - ggml_tensor * n = ggml_graph_node(gf, i); - if (n->op != GGML_OP_FLASH_ATTN_EXT) { - continue; - } - ggml_backend_dev_t device_fa = ggml_backend_get_device( - ggml_backend_sched_get_tensor_backend(sched.get(), n)); - - // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer - GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); - const int il = std::stoi(n->name + prefix_len); - ggml_backend_dev_t device_kv = model.dev_layer(il); - if (device_fa != device_kv) { - LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " - "is assigned to device %s (usually due to missing support)\n", - __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); - // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways - fa_device_mismatch = true; - break; - } - } - if (fa_device_mismatch) { - cparams.flash_attn = false; - LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); - if (ggml_is_quantized(params.type_v)) { - throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); - } - } else { - cparams.flash_attn = true; - LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); - } - } - - // reserve worst-case graph - int n_splits_pp = -1; - int n_nodes_pp = -1; - - int n_splits_tg = -1; - int n_nodes_tg = -1; - - // reserve pp (prompt processing) graph first so that buffers are only allocated once - { - auto * gf = 
graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), - model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); - if (!gf) { - if (pipeline_parallel) { - LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); - } - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } - } - - n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_pp = ggml_graph_n_nodes(gf); - } - - // reserve with tg (token generation) graph to get the number of splits and nodes - { - auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute tg buffers"); - } - - n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_tg = ggml_graph_n_nodes(gf); - } - - // reserve again with pp graph to avoid ggml-alloc reallocations during inference - { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: - // - // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); - // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } - } - - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; - if (!model.hparams.no_alloc) { - backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); - } - if (backend_buf_exp_size[i] > 1) { - LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - backend_buf_exp_size[i] / 1024.0 / 1024.0); - } - } - - if (n_nodes_pp == n_nodes_tg) { - LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); - } else { - LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); - } - - if (n_splits_pp == n_splits_tg) { - LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); - } else { - LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); } } @@ -510,7 +392,254 @@ llama_context::~llama_context() { ggml_opt_free(opt_ctx); } +void llama_context::sched_reserve() { + if (!sched_need_reserve) { + return; + } + + sched_need_reserve = false; + + LLAMA_LOG_INFO("%s: reserving ...\n", __func__); + + synchronize(); + + const int64_t t_start_us = ggml_time_us(); + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); + + LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); + + gf_res_prev.reset(new llm_graph_result(max_nodes)); + gf_res_reserve.reset(new llm_graph_result(max_nodes)); + + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); + + llama_memory_context_ptr mctx; + if (memory) { + LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); + mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory module"); + } + } + + // avoid reserving 
graphs with zero outputs - assume one output per sequence + const int n_outputs = n_seqs; + + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + + // resolve automatic Flash Attention use + if (cparams.auto_fa) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for Flash Attention check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; + bool fa_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_FLASH_ATTN_EXT) { + continue; + } + ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_fa != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); + // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways + fa_device_mismatch = true; + break; + } + } + + if (fa_device_mismatch) { + cparams.flash_attn = false; + LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); + } else { + cparams.flash_attn = true; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + } + + cparams.auto_fa = false; + } + + if (cparams.auto_fgdn) { + LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__); + + if (cparams.fused_gdn_ar) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; + } + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; + break; + } + } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ar = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__); + } else { + LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__); + } + } + + if (cparams.fused_gdn_ch) { + // more than one token in the batch per sequence in order to take the chunked path + // note: n_outputs must match n_tokens for embedding models with mean/rank pooling, + // because build_pooling creates inp_mean 
with shape [n_tokens, n_seqs] and multiplies + // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, + // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553). + const uint32_t n_tokens_ch = 16*n_seqs; + auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; + } + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; + break; + } + } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ch = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__); + } else { + LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__); + } + } + + cparams.auto_fgdn = false; + } + + // reserve worst-case graph + int n_splits_pp = -1; + int n_nodes_pp = -1; + + int n_splits_tg = -1; + int n_nodes_tg = -1; + + // reserve pp (prompt processing) graph first so that buffers are only allocated once + { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + model.hparams.no_alloc, model.hparams.no_alloc ? 
backend_buf_exp_size.data() : nullptr); + if (!gf) { + if (cparams.pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + cparams.pipeline_parallel = false; + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + } + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } + + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } + + // reserve with tg (token generation) graph to get the number of splits and nodes + { + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute tg buffers"); + } + + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); + } + + // reserve again with pp graph to avoid ggml-alloc reallocations during inference + { + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } + + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + if (!model.hparams.no_alloc) { + backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); + } + if (backend_buf_exp_size[i] > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + backend_buf_exp_size[i] / 1024.0 / 1024.0); + } + } + + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); + } else { + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } + + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + } + + const int64_t t_end_us = ggml_time_us(); + + LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n", + __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get())); +} + void llama_context::synchronize() { + if (!sched) { + return; + } + ggml_backend_sched_synchronize(sched.get()); // FIXME: if multiple single tokens are evaluated without a synchronization, @@ -645,7 +774,7 @@ enum llama_pooling_type llama_context::pooling_type() const { float * llama_context::get_logits() { output_reorder(); - return logits; + return logits.data; } int64_t llama_context::output_resolve_row(int32_t i) const { @@ -678,36 +807,15 @@ int64_t llama_context::output_resolve_row(int32_t i) const { } float * llama_context::get_logits_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { - if (logits == nullptr) { + if (logits.data == nullptr) { throw std::runtime_error("no logits"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw 
std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - - return logits + j*model.vocab.n_tokens(); + const int64_t j = output_resolve_row(i); + return logits.data + j*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -721,45 +829,24 @@ float * llama_context::get_logits_ith(int32_t i) { float * llama_context::get_embeddings() { output_reorder(); - return embd; + return embd.data; } llama_token * llama_context::get_sampled_tokens() const{ - return sampling.sampled; + return sampling.sampled.data; } float * llama_context::get_embeddings_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { - if (embd == nullptr) { + if (embd.data == nullptr) { throw std::runtime_error("no embeddings"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - - const uint32_t n_embd_out = model.hparams.get_n_embd_out(); - return embd + j*n_embd_out; + const int64_t j = output_resolve_row(i); + const uint32_t n_embd_out = model.hparams.n_embd_out(); + return embd.data + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -782,14 +869,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); - if (sampling.sampled == nullptr) { + if (!sampling.sampled.has_data()) { return LLAMA_TOKEN_NULL; } try { const int64_t row = output_resolve_row(idx); - GGML_ASSERT(row < (int64_t) sampling.sampled_size); - return sampling.sampled[row]; + GGML_ASSERT(row < (int64_t) sampling.sampled.size); + return sampling.sampled.data[row]; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); return LLAMA_TOKEN_NULL; @@ -799,7 +886,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) { float * llama_context::get_sampled_probs_ith(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return nullptr; } @@ -808,7 +895,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { return nullptr; } - return sampling.probs + row*model.vocab.n_tokens(); + return sampling.probs.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -818,7 +905,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { float * 
llama_context::get_sampled_logits_ith(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return nullptr; } @@ -827,7 +914,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) { if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { return nullptr; } - return sampling.logits + row*model.vocab.n_tokens(); + return sampling.logits.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -839,13 +926,14 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { try { const int64_t row = output_resolve_row(idx); - if (sampling.candidates != nullptr && + if (sampling.candidates.has_data() && (size_t) row < sampling.candidates_count.size() && sampling.candidates_count[row] > 0) { - return sampling.candidates + row*model.vocab.n_tokens(); + return sampling.candidates.data + row*model.vocab.n_tokens(); } } catch (const std::exception & err) { // fallback to full vocab list + GGML_UNUSED(err); } return sampling.token_ids_full_vocab.data(); @@ -854,7 +942,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { size_t llama_context::get_sampled_candidates_count(int32_t idx) { output_reorder(); - if (sampling.candidates == nullptr) { + if (!sampling.candidates.has_data()) { return 0; } @@ -873,7 +961,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) { size_t llama_context::get_sampled_logits_count(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return model.vocab.n_tokens(); } @@ -892,7 +980,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) { size_t llama_context::get_sampled_probs_count(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return 0; } @@ -951,21 +1039,41 @@ void llama_context::set_embeddings(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); cparams.embeddings = value; + + // TODO: not sure yet if we want to reserve here + //sched_need_reserve = true; } void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.causal_attn == value) { + return; + } + cparams.causal_attn = value; + + sched_need_reserve = true; } void llama_context::set_warmup(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.warmup == value) { + return; + } + cparams.warmup = value; + + // warmups are usually with small batches, so no need to reserve + //sched_need_reserve = true; } bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + if (!sampler && sampling.samplers.count(seq_id) == 0) { + return true; + } + LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); const bool can_offload = @@ -975,22 +1083,24 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { llama_sampler_chain_n(sampler) > 0; if (sampler && can_offload) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); - auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); - if (host_buft) { - buft = host_buft; - } + auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); sampler->iface->backend_init(sampler, buft); sampling.samplers[seq_id] = sampler; + sched_need_reserve = true; 
+ return true; } if (sampler && !can_offload) { LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); return false; @@ -998,37 +1108,56 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { sampling.samplers.erase(seq_id); + sched_need_reserve = true; + return true; } -void llama_context::set_adapter_lora( - llama_adapter_lora * adapter, - float scale) { - LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); +void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - loras[adapter] = scale; -} - -bool llama_context::rm_adapter_lora( - llama_adapter_lora * adapter) { - LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); - - auto pos = loras.find(adapter); - if (pos != loras.end()) { - loras.erase(pos); - return true; + if (adapters_lora_are_same(adapters, n_adapters, scales)) { + return; } - return false; + loras.reset(new llama_adapter_loras()); + + for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] != 0.0f) { + loras->insert({adapters[i], scales[i]}); + } + } + + sched_need_reserve = true; } -void llama_context::clear_adapter_lora() { - LLAMA_LOG_DEBUG("%s: call\n", __func__); +bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - loras.clear(); + // Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison. + size_t n_non_zero = 0; + + for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] == 0.0f) { + continue; + } + n_non_zero++; + + auto it = loras->find(adapters[i]); + + if (it == loras->end() || it->second != scales[i]) { + return false; + } + } + + if (n_non_zero != loras->size()) { + return false; + } + + return true; } -bool llama_context::apply_adapter_cvec( +bool llama_context::set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -1036,7 +1165,9 @@ bool llama_context::apply_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); - return cvec.apply(model, data, len, n_embd, il_start, il_end); + // TODO: should we reserve? 
+ + return cvec->apply(model, data, len, n_embd, il_start, il_end); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1086,6 +1217,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll { //const auto t_start_us = ggml_time_us(); + // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated res->set_inputs(&ubatch); //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); @@ -1138,10 +1270,12 @@ int llama_context::encode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; // reserve output buffer - if (output_reserve(n_tokens, batch_inp) < n_tokens) { + if (output_reserve(n_tokens) < n_tokens) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); return -2; }; @@ -1177,16 +1311,16 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits && t_logits) { + if (logits.data && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); } // extract embeddings - if (embd && t_embd) { + if (embd.data && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1194,11 +1328,11 @@ int llama_context::encode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); - const uint32_t n_embd_out = hparams.get_n_embd_out(); + GGML_ASSERT(embd.data != nullptr); + const uint32_t n_embd_out = hparams.n_embd_out(); - GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1246,7 +1380,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cross.n_embd = t_embd->ne[0]; cross.n_enc = t_embd->ne[1]; cross.v_embd.resize(cross.n_embd*cross.n_enc); - memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); + memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); const auto & batch = balloc->get_batch(); @@ -1286,11 +1420,10 @@ static std::map build_seq_to_output_row(const llama_ubat static void copy_tensor_async_ints( const std::map & tensor_map, - llama_token * sampled, - size_t sampled_size, + const buffer_view & sampled, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (sampled == nullptr) { + if (!sampled.has_data()) { return; } @@ -1301,23 +1434,23 @@ static void copy_tensor_async_ints( } const uint32_t row = it->second; - GGML_ASSERT(row < sampled_size); + GGML_ASSERT(row < sampled.size); GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled 
tokens tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); } } static void copy_tensor_async_floats( const std::map & tensor_map, - float * dst, + const buffer_view & dst, size_t stride, std::vector & counts, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1333,7 +1466,7 @@ static void copy_tensor_async_floats( GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - float * row_ptr = dst + (size_t) row * stride; + float * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of logits/probabilities that were written for this row. @@ -1343,12 +1476,12 @@ static void copy_tensor_async_floats( static void copy_tensor_async_candidates( const std::map & tensor_map, - llama_token * dst, + const buffer_view & dst, size_t stride, std::vector & counts, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1364,7 +1497,7 @@ static void copy_tensor_async_candidates( GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - llama_token * row_ptr = dst + (size_t) row * stride; + llama_token * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of candidates that were written. @@ -1372,6 +1505,23 @@ static void copy_tensor_async_candidates( } } +static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map & samplers) { + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (!ubatch.output[i]) { + continue; + } + + // Check if the output token has at least one sequence without a backend sampler. + for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) { + llama_seq_id seq_id = ubatch.seq_id[i][j]; + if (samplers.find(seq_id) == samplers.end()) { + return true; + } + } + } + return false; // all sequences use backend sampling +} + int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT @@ -1451,6 +1601,8 @@ int llama_context::decode(const llama_batch & batch_inp) { embd_seq.clear(); output_swaps.clear(); + sched_reserve(); + bool did_optimize = false; // handle any pending shifts/copies @@ -1502,7 +1654,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // reserve output buffer - if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { + if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); return -2; }; @@ -1575,25 +1727,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - // For multi-sequence batches that mix backend samplers and CPU sampler - // this is currently inefficient as we copy all logits even for the - // backend sampled tokens. 
- if (logits && t_logits && n_outputs > 0) { + if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + float * logits_out = logits.data + n_outputs_prev*n_vocab; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } } // extract embeddings - if (embd && t_embd && n_outputs > 0) { + if (embd.data && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1601,13 +1750,13 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); - const uint32_t n_embd_out = hparams.get_n_embd_out(); - float * embd_out = embd + n_outputs_prev*n_embd_out; + GGML_ASSERT(embd.data != nullptr); + const uint32_t n_embd_out = hparams.n_embd_out(); + float * embd_out = embd.data + n_outputs_prev*n_embd_out; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); } } break; @@ -1648,16 +1797,13 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // This flag indicates whether a backend sampler has actually sampled a specific - // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings. - const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); - - if (has_samplers && has_sampled) { + // Copy backend sampling output if this ubatch produced any sampling tensors. 
+ if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); const auto stride = n_vocab; // async copy the sampling data from the backend to the host - copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); + copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); @@ -1727,7 +1873,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // output // -uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) { +uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1735,7 +1881,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd_out = hparams.get_n_embd_out(); + const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; bool has_embd = cparams.embeddings; @@ -1746,52 +1892,18 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba has_embd = true; } - // Check which sampling modes are needed for the current batch. - // TODO: avoid this branching by working with the worst-case - bool has_sampling = false; - bool cpu_logits = false; - - if (batch.logits) { - for (int32_t i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { - continue; - } - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - llama_seq_id seq_id = batch.seq_id[i][j]; - if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { - has_sampling = true; - } else { - cpu_logits = true; - } - } - } - } else { - // When batch.logits is nullptr (when loading state with a dummy batch), - // allocate CPU logits. - cpu_logits = true; - } size_t backend_float_count = 0; size_t backend_token_count = 0; - // Allocate CPU logits buffer only if needed by sequences in this batch - logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; - // TODO: avoid this branching by working with the worst-case - if (!has_sampling) { - sampling.logits_size = 0; - sampling.probs_size = 0; - sampling.sampled_size = 0; - sampling.candidates_size = 0; - } else { - sampling.logits_size = n_vocab*n_outputs_max; - sampling.probs_size = n_vocab*n_outputs_max; - sampling.sampled_size = n_outputs_max; - sampling.candidates_size = n_vocab*n_outputs_max; - - backend_float_count = sampling.logits_size + sampling.probs_size; - backend_token_count = sampling.sampled_size + sampling.candidates_size; + // Allocate backend sampling output buffers if there are backend samplers configured. 
+ const bool has_sampling = !sampling.samplers.empty(); + if (has_sampling) { + backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs + backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates } if (output_ids.empty()) { @@ -1801,7 +1913,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits_size + embd_size + backend_float_count) * sizeof(float) + + (logits.size + embd.size + backend_float_count) * sizeof(float) + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required @@ -1816,8 +1928,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba // TODO: not needed? buf_output = nullptr; - logits = nullptr; - embd = nullptr; + logits.data = nullptr; + embd.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1836,35 +1948,27 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = nullptr; - embd = nullptr; - size_t offset = 0; uint8_t * base = (uint8_t *) output_base; - logits = (has_logits && cpu_logits) ? output_base : nullptr; - offset += logits_size * sizeof(float); + logits = has_logits ? buffer_view{output_base, logits.size} : buffer_view{nullptr, 0}; + offset += logits.size * sizeof(float); - embd = has_embd ? (float *) (base + offset) : nullptr; - offset += embd_size * sizeof(float); - - sampling.logits = nullptr; - sampling.probs = nullptr; - sampling.sampled = nullptr; - sampling.candidates = nullptr; + embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; + offset += embd.size * sizeof(float); if (has_sampling) { - sampling.logits = (float *) (base + offset); - offset += sampling.logits_size * sizeof(float); + sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.logits.size * sizeof(float); - sampling.probs = (float *) (base + offset); - offset += sampling.probs_size * sizeof(float); + sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.probs.size * sizeof(float); - sampling.sampled = (llama_token *) (base + offset); - offset += sampling.sampled_size * sizeof(llama_token); + sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max}; + offset += sampling.sampled.size * sizeof(llama_token); - sampling.candidates = (llama_token *) (base + offset); - offset += sampling.candidates_size * sizeof(llama_token); + sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.candidates.size * sizeof(llama_token); // The count vectors keep track of the actual number of logits/probs/candidates // copied from the backend for each output row. 
@@ -1877,7 +1981,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); - std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); + } else { + sampling.logits = {nullptr, 0}; + sampling.probs = {nullptr, 0}; + sampling.sampled = {nullptr, 0}; + sampling.candidates = {nullptr, 0}; + + sampling.logits_count.clear(); + sampling.probs_count.clear(); + sampling.candidates_count.clear(); } // set all ids as invalid (negative) @@ -1896,49 +2009,42 @@ void llama_context::output_reorder() { const uint64_t i0 = output_swaps[s].i0; const uint64_t i1 = output_swaps[s].i1; - if (logits_size > 0) { + if (logits.size > 0) { for (uint64_t k = 0; k < n_vocab; k++) { - std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]); } } - if (embd_size > 0) { + if (embd.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]); } } - if (sampling.logits && sampling.logits_size > 0) { + if (!sampling.samplers.empty()) { + assert(sampling.logits.size > 0); + assert(sampling.probs.size > 0); + assert(sampling.candidates.size > 0); + assert(sampling.sampled.size > 0); + assert(sampling.logits_count.size() > 0); + assert(sampling.probs_count.size() > 0); + assert(sampling.candidates_count.size() > 0); + for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); } - } - if (sampling.probs && sampling.probs_size > 0) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); } - } - if (sampling.candidates && sampling.candidates_size > 0) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); } - } - if (sampling.sampled && sampling.sampled_size > 0) { - std::swap(sampling.sampled[i0], sampling.sampled[i1]); - } - - if (!sampling.logits_count.empty()) { - std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); - } - - if (!sampling.probs_count.empty()) { - std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); - } - - if (!sampling.candidates_count.empty()) { + std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); } } @@ -1951,11 +2057,13 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); - res += model.n_lora_nodes; 
+ for (const auto & lora : model.loras) { + res += lora->get_n_nodes(); + } return res; } @@ -1977,7 +2085,7 @@ ggml_cgraph * llama_context::graph_reserve( ggml_backend_sched_reset(sched.get()); - // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that + // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that gf_res_prev->reset(); // store the n_outputs as it is, and restore it afterwards @@ -2037,8 +2145,8 @@ llm_graph_params llama_context::graph_params( /*.gtype =*/ gtype, /*.sched =*/ sched.get(), /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, + /*.cvec =*/ cvec.get(), + /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, @@ -2085,13 +2193,6 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } - if (!cparams.offload_kqv) { - if (strcmp(name, "kqv_merged_cont") == 0) { - // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); - } - } - // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; @@ -2443,63 +2544,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { // TODO: add more model-specific info which should prevent loading the session file if not identical } - // write output ids - { - LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - - const auto n_outputs = this->n_outputs; - const auto & output_ids = this->output_ids; - - std::vector w_output_pos; - - w_output_pos.resize(n_outputs); - - // build a more compact representation of the output ids - for (size_t i = 0; i < n_batch(); ++i) { - // map an output id to a position in the batch - int64_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT(pos < n_outputs); - w_output_pos[pos] = i; - } - } - - io.write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - // write logits - { - LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - - const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); - - io.write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - io.write(logits, logits_size * sizeof(float)); - } - } - - // write embeddings - { - LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - - const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); - - io.write(&embd_size, sizeof(embd_size)); - - if (embd_size) { - io.write(embd, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); @@ -2525,73 +2569,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { // TODO: add more info which needs to be identical but which is not verified otherwise } - // read output ids - { - LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); - - auto n_outputs = this->n_outputs; - io.read_to(&n_outputs, sizeof(n_outputs)); - - // Create a dummy batch for state loading. 
- llama_batch dummy_batch = {}; - dummy_batch.n_tokens = 0; - if (n_outputs > output_reserve(n_outputs, dummy_batch)) { - throw std::runtime_error("could not reserve outputs"); - } - - std::vector output_pos; - - if (n_outputs) { - output_pos.resize(n_outputs); - io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= n_batch()) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); - } - this->output_ids[id] = i; - } - - this->n_outputs = n_outputs; - } - } - - // read logits - { - LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); - - uint64_t logits_size; - io.read_to(&logits_size, sizeof(logits_size)); - - if (this->logits_size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - io.read_to(this->logits, logits_size * sizeof(float)); - } - } - - // read embeddings - { - LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); - - uint64_t embd_size; - io.read_to(&embd_size, sizeof(embd_size)); - - if (this->embd_size < embd_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embd_size) { - io.read_to(this->embd, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory) { LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); @@ -2724,6 +2701,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params llama_set_param(model->cls_b, param_filter, param_filter_ud); llama_set_param(model->cls_out, param_filter, param_filter_ud); llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + llama_set_param(model->cls_norm, param_filter, param_filter_ud); for (struct llama_layer & layer : model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { @@ -2780,7 +2758,7 @@ void llama_context::opt_epoch_iter( } // reserve output buffer - if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { + if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); GGML_ABORT("TODO: handle this error"); }; @@ -2815,7 +2793,7 @@ void llama_context::opt_epoch_iter( }; ctx_compute_opt = ggml_init(params); } - ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); res->set_inputs(&ubatch); @@ -2957,19 +2935,23 @@ llama_context * llama_init_from_model( if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); - if (model->hparams.n_embd_head_k % blck_size != 0) { - LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", - __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); - return nullptr; + for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + if (model->hparams.n_embd_head_k(il) % blck_size != 0) { + LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", + __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); + return nullptr; + } } } if (params.flash_attn_type == 
LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); - if (model->hparams.n_embd_head_v % blck_size != 0) { - LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n", - __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); - return nullptr; + for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + if (model->hparams.n_embd_head_v(il) % blck_size != 0) { + LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", + __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); + return nullptr; + } } } @@ -3161,37 +3143,43 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { return static_cast(ctx->get_sampled_probs_count(i)); } +struct ggml_cgraph * llama_graph_reserve( + struct llama_context * ctx, + uint32_t n_tokens, + uint32_t n_seqs, + uint32_t n_outputs) { + auto * memory = ctx->get_memory(); + llama_memory_context_ptr mctx; + if (memory) { + mctx = memory->init_full(); + } + return ctx->graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get()); +} + // llama adapter API -int32_t llama_set_adapter_lora( +int32_t llama_set_adapters_lora( llama_context * ctx, - llama_adapter_lora * adapter, - float scale) { - ctx->set_adapter_lora(adapter, scale); + llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales) { + if (adapters == nullptr || scales == nullptr) { + GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call"); + } + + ctx->set_adapters_lora(adapters, n_adapters, scales); return 0; } -int32_t llama_rm_adapter_lora( - llama_context * ctx, - llama_adapter_lora * adapter) { - bool res = ctx->rm_adapter_lora(adapter); - - return res ? 0 : -1; -} - -void llama_clear_adapter_lora(llama_context * ctx) { - ctx->clear_adapter_lora(); -} - -int32_t llama_apply_adapter_cvec( +int32_t llama_set_adapter_cvec( llama_context * ctx, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end) { - bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end); return res ? 0 : -1; } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index b29edf4d..e0d0085c 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -4,6 +4,7 @@ #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" +#include "llama-impl.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -40,6 +41,14 @@ struct llama_context { ~llama_context(); + // reserve a new backend scheduler (if needed) + // for example, when: + // - changing loras + // - changing samplers + // - changing attention type + // - etc. 
+ void sched_reserve(); + void synchronize(); const llama_model & get_model() const; @@ -96,16 +105,11 @@ struct llama_context { void set_causal_attn(bool value); void set_warmup(bool value); - void set_adapter_lora( - llama_adapter_lora * adapter, - float scale); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - bool rm_adapter_lora( - llama_adapter_lora * adapter); + bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - void clear_adapter_lora(); - - bool apply_adapter_cvec( + bool set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -204,7 +208,7 @@ private: // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch); + uint32_t output_reserve(int32_t n_outputs); void output_reorder(); @@ -252,43 +256,36 @@ private: const llama_model & model; - llama_cparams cparams; - llama_adapter_cvec cvec; - llama_adapter_loras loras; + llama_cparams cparams; + + llama_adapter_cvec_ptr cvec; + llama_adapter_loras_ptr loras; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) - size_t logits_size = 0; // capacity (of floats) for logits - float * logits = nullptr; + buffer_view logits = {nullptr, 0}; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - size_t embd_size = 0; // capacity (of floats) for embeddings - float * embd = nullptr; + buffer_view embd = {nullptr, 0}; - // TODO: simplify struct sampling_info { + // !samplers.empty() to check if any samplers are active std::map samplers; - float * logits = nullptr; - size_t logits_size = 0; - - llama_token * sampled = nullptr; - size_t sampled_size = 0; - - float * probs = nullptr; - size_t probs_size = 0; - - llama_token * candidates = nullptr; - size_t candidates_size = 0; + buffer_view logits = {nullptr, 0}; + buffer_view sampled = {nullptr, 0}; + buffer_view probs = {nullptr, 0}; + buffer_view candidates = {nullptr, 0}; std::vector logits_count; std::vector probs_count; std::vector candidates_count; + // optimization std::vector token_ids_full_vocab; }; @@ -314,6 +311,8 @@ private: ggml_backend_sched_ptr sched; + bool sched_need_reserve = true; + ggml_backend_t backend_cpu = nullptr; std::vector backends; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index fcef8fa9..9d359474 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -30,10 +30,15 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool auto_fa; + bool fused_gdn_ar; // use fused gated delta net (autoregressive) + bool fused_gdn_ch; // use fused gated delta net (chunked) + bool auto_fgdn; bool no_perf; bool warmup; bool op_offload; bool kv_unified; + bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h new file mode 100644 index 00000000..13ced783 --- /dev/null +++ b/examples/talk-llama/llama-ext.h @@ -0,0 +1,12 @@ +#pragma once + +#include "llama-context.h" +#include "ggml.h" +#include "stdint.h" + +// Reserve a new compute graph. It is valid until the next call to llama_graph_reserve. 
+LLAMA_API struct ggml_cgraph * llama_graph_reserve( + struct llama_context * ctx, + uint32_t n_tokens, + uint32_t n_seqs, + uint32_t n_outputs); diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index 64ea2fd0..aac0d41f 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -2,7 +2,7 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-sampling.h" +#include "llama-sampler.h" #include #include @@ -601,7 +601,7 @@ const char * llama_grammar_parser::parse_sequence( throw std::runtime_error(std::string("expecting an int at ") + pos); } const char * int_end = parse_int(pos); - uint64_t min_times = std::stoul(std::string(pos, int_end - pos)); + uint64_t min_times = std::stoull(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); uint64_t max_times = UINT64_MAX; // default: no max limit @@ -614,7 +614,7 @@ const char * llama_grammar_parser::parse_sequence( if (is_digit_char(*pos)) { const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); + max_times = std::stoull(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); } @@ -1160,13 +1160,13 @@ struct llama_grammar * llama_grammar_init_impl( // if there is a grammar, parse it // rules will be empty (default) if there are parse errors if (!parser.parse(grammar_str) || parser.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); + LLAMA_LOG_ERROR("failed to parse grammar\n"); return nullptr; } - // Ensure that there is a "root" node. - if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + // Ensure that the grammar contains the start symbol + if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) { + LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root); return nullptr; } @@ -1195,7 +1195,7 @@ struct llama_grammar * llama_grammar_init_impl( continue; } if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { - LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i); return nullptr; } } diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 374ff1eb..9a215bb7 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -7,13 +7,51 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include #include #include +#include +#include #include +// dedup helpers + +static ggml_tensor * build_kq_mask( + ggml_context * ctx, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); +} + +static bool can_reuse_kq_mask( + ggml_tensor * kq_mask, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; + + bool res = true; + + res &= (kq_mask->ne[0] == n_kv); + res &= (kq_mask->ne[1] == n_tokens/n_stream); + res &= (kq_mask->ne[2] == 1); + res &= (kq_mask->ne[3] == n_stream); + + return res; +} + +// impl + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -22,7 +60,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { } if (ubatch->embd) { - const int64_t n_embd = embd->ne[0]; + GGML_ASSERT(n_embd == embd->ne[0]); + const int64_t n_tokens = ubatch->n_tokens; ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); @@ -32,8 +71,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { bool res = true; - res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); - res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); return res; } @@ -96,11 +135,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { int32_t * data = (int32_t *) pos_bucket->data; - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_tokens; ++i) { - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); - } + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); } } } @@ -148,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { } void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.embeddings && + (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) { + const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; @@ -210,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const bool last = ( cparams.pooling_type == LLAMA_POOLING_TYPE_LAST || - (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token + (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token ); for (int i = 0; i < n_tokens; ++i) { @@ -323,34 +363,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; - const llama_pos p1 = ubatch->pos[i1]; + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const llama_pos p1 = ubatch->pos[i1]; - const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv; + const uint64_t idst = i1*n_kv; - for (int i0 = 0; i0 < n_tokens; ++i0) { - const llama_seq_id s0 = ubatch->seq_id[i0][0]; - const 
llama_pos p0 = ubatch->pos[i0]; + for (int i0 = 0; i0 < n_tokens; ++i0) { + const llama_seq_id s0 = ubatch->seq_id[i0][0]; + const llama_pos p0 = ubatch->pos[i0]; - // mask different sequences - if (s0 != s1) { - continue; - } - - // mask future tokens - if (cparams.causal_attn && p0 > p1) { - continue; - } - - // apply SWA if any - if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { - continue; - } - - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + // mask different sequences + if (s0 != s1) { + continue; } + + // mask future tokens + if (cparams.causal_attn && p0 > p1) { + continue; + } + + // apply SWA if any + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + continue; + } + + data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; } } }; @@ -403,8 +441,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); + + return res; +} + +void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) { + mctx->set_input_k_idxs(self_k_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); +} + +bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -434,11 +491,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; - - res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); return res; } @@ -454,27 +508,20 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { float * data = (float *) cross_kq_mask->data; - for (int h = 0; h < 1; ++h) { - for (int i = 0; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - float f = -INFINITY; + for (int i = 0; i < n_tokens; ++i) { + GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); + for (int j = 0; j < n_enc; ++j) { + float f = -INFINITY; - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; - if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { - f = 0.0f; - } + if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { + f = 0.0f; } - - data[h*(n_enc*n_tokens) + i*n_enc + j] = f; } - } - for (int i = n_tokens; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; - } + 
data[i*n_enc + j] = f; } } } @@ -508,8 +555,118 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + +// TODO: Hybrid input classes are a bit redundant. +// Instead of creating a hybrid input, the graph can simply create 2 separate inputs. +// Refactoring is required in the future. +void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) { + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + +void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); + attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + + attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); + } + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // 
assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); + } res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -575,7 +732,8 @@ int64_t llm_graph_result::get_max_nodes() const { } void llm_graph_result::reset() { - t_tokens = nullptr; + t_inp_tokens = nullptr; + t_inp_embd = nullptr; t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; @@ -691,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : ubatch (params.ubatch), n_embd (hparams.n_embd), n_layer (hparams.n_layer), - n_rot (hparams.n_rot), + n_rot (hparams.n_rot()), n_ctx (cparams.n_ctx), n_head (hparams.n_head()), n_head_kv (hparams.n_head_kv()), - n_embd_head_k (hparams.n_embd_head_k), + n_embd_head_k (hparams.n_embd_head_k()), n_embd_k_gqa (hparams.n_embd_k_gqa()), - n_embd_head_v (hparams.n_embd_head_v), + n_embd_head_v (hparams.n_embd_head_v()), n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), n_expert_used (cparams.warmup ? 
hparams.n_expert : hparams.n_expert_used), @@ -742,7 +900,8 @@ ggml_tensor * llm_graph_context::build_cvec( ggml_tensor * llm_graph_context::build_lora_mm( ggml_tensor * w, - ggml_tensor * cur) const { + ggml_tensor * cur, + ggml_tensor * w_s) const { ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); for (const auto & lora : *loras) { @@ -763,6 +922,10 @@ ggml_tensor * llm_graph_context::build_lora_mm( res = ggml_add(ctx0, res, ab_cur); } + if (w_s) { + res = ggml_mul(ctx0, res, w_s); + } + return res; } @@ -888,6 +1051,26 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { + // Step35: HF clamps gate (after SiLU) and up before multiplication + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_shexp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_silu_clamped", il); + + tmp = ggml_clamp(ctx0, tmp, -limit, limit); + cb(tmp, "ffn_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, tmp); + cb(cur, "ffn_swiglu_limited", il); + type_gate = LLM_FFN_SEQ; + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, tmp); cb(cur, "ffn_swiglu", il); type_gate = LLM_FFN_SEQ; @@ -951,8 +1134,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -984,11 +1167,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps, + ggml_tensor * up_exps_s, + ggml_tensor * gate_exps_s, + ggml_tensor * down_exps_s) const { return build_moe_ffn( cur, gate_inp, /* gate_inp_b */ nullptr, @@ -1000,11 +1186,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( n_expert_used, type_op, norm_w, - scale_w, w_scale, gating_op, il, - probs_in + probs_in, + gate_up_exps, + /* gate_up_exps_b */ nullptr, + up_exps_s, + gate_exps_s, + down_exps_s ); } @@ -1023,11 +1213,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps, + ggml_tensor * gate_up_exps_b, + ggml_tensor * up_exps_s, + ggml_tensor * gate_exps_s, + ggml_tensor * down_exps_s) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN @@ -1149,7 +1343,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); } - if (scale_w) { + if (w_scale != 0.0f && w_scale != 1.0f) { weights = ggml_scale(ctx0, weights, w_scale); cb(weights, "ffn_moe_weights_scaled", il); } @@ -1166,30 +1360,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, 
"ffn_moe_weighted", il); } - ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); - - if (up_exps_b) { - up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); - cb(up, "ffn_moe_up_biased", il); - } - + ggml_tensor * up = nullptr; ggml_tensor * experts = nullptr; - if (gate_exps) { - cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + + if (gate_up_exps) { + // merged gate_up path: one mul_mat_id, then split into gate and up views + ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens] + cb(gate_up, "ffn_moe_gate_up", il); + + if (gate_up_exps_b) { + gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts); + cb(gate_up, "ffn_moe_gate_up_biased", il); + } + + // apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused) + if (up_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + gate_up = ggml_mul(ctx0, gate_up, s); + cb(gate_up, "ffn_moe_gate_up_scaled", il); + } + + const int64_t n_ff = gate_up->ne[0] / 2; + cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0); cb(cur, "ffn_moe_gate", il); + up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]); + cb(up, "ffn_moe_up", il); } else { - cur = up; + // separate gate and up path + up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + if (up_exps_b) { + up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); + cb(up, "ffn_moe_up_biased", il); + } + + // apply per-expert scale2 to up + if (up_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + up = ggml_mul(ctx0, up, s); + cb(up, "ffn_moe_up_scaled", il); + } + + if (gate_exps) { + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); + } else { + cur = up; + } + + if (gate_exps_b) { + cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); + cb(cur, "ffn_moe_gate_biased", il); + } + + // apply per-expert scale2 to gate + if (gate_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + cur = ggml_mul(ctx0, cur, s); + cb(cur, "ffn_moe_gate_scaled", il); + } } - if (gate_exps_b) { - cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); - cb(cur, "ffn_moe_gate_biased", il); - } + const bool has_gate = gate_exps || gate_up_exps; switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { + // Step35: per-layer clamp for routed experts + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_exp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_moe_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_moe_silu_clamped", il); + + up = ggml_clamp(ctx0, up, -limit, limit); + cb(up, 
"ffn_moe_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, up); + cb(cur, "ffn_moe_swiglu_limited", il); + break; + } + } + } + + if (has_gate) { cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { @@ -1197,7 +1461,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: - if (gate_exps) { + if (has_gate) { cur = ggml_geglu_split(ctx0, cur, up); cb(cur, "ffn_moe_geglu", il); } else { @@ -1213,7 +1477,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_swiglu_oai", il); } break; case LLM_FFN_RELU: - if (gate_exps) { + if (has_gate) { cur = ggml_reglu_split(ctx0, cur, up); cb(cur, "ffn_moe_reglu", il); } else { @@ -1221,7 +1485,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_relu", il); } break; case LLM_FFN_RELU_SQR: - if (gate_exps) { + if (has_gate) { // TODO: add support for gated squared relu GGML_ABORT("fatal error: gated squared relu not implemented"); } else { @@ -1241,6 +1505,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(experts, "ffn_moe_down_biased", il); } + // apply per-expert scale2 to down + if (down_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + experts = ggml_mul(ctx0, experts, s); + cb(experts, "ffn_moe_down_scaled", il); + } + if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); cb(cur, "ffn_moe_weighted", il); @@ -1279,17 +1552,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { - const int64_t n_embd = hparams.n_embd_inp(); + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd = hparams.n_embd; - auto inp = std::make_unique(); + assert(n_embd_inp >= n_embd); - ggml_tensor * cur = nullptr; + auto inp = std::make_unique(n_embd_inp); - if (ubatch.token) { - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp->tokens, "inp_tokens", -1); - ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); + cb(inp->embd, "inp_embd", -1); + ggml_set_input(inp->embd); + + // select one of the 2 inputs, based on the batch contents + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 + std::array inps; + + // token embeddings path (ubatch.token != nullptr) + { + auto & cur = inps[0]; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); @@ -1310,19 +1595,36 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { cur = ggml_add(ctx0, cur, inpL_delta); } - } else { - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - ggml_set_input(inp->embd); + + if (n_embd_inp != n_embd) { + cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0); + } + } + + // vector embeddings path (ubatch.embd != nullptr) + { + auto & cur = inps[1]; cur = inp->embd; } + assert(ggml_are_same_shape (inps[0], inps[1])); + assert(ggml_are_same_stride(inps[0], inps[1])); + + ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 
0 : 1); + + if (n_embd_inp != n_embd) { + cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0); + } + + res->t_inp_embd = cur; + // For Granite architecture if (hparams.f_embedding_scale != 0.0f) { cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } - cb(cur, "inp_embd", -1); + cb(cur, "embd", -1); res->add_input(std::move(inp)); @@ -1354,6 +1656,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { // this need to be 1x1xN for broadcasting cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens); ggml_set_input(cur); + ggml_set_name(cur, "attn_scale"); res->add_input(std::move(inp)); @@ -1363,7 +1666,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { ggml_tensor * llm_graph_context::build_inp_out_ids() const { // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls, // but this would make the graph topology depend on the number of output tokens, which can interere with - // features that require constant topology such as pipline parallelism + // features that require constant topology such as pipeline parallelism // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471 //if (n_outputs < n_tokens) { // return nullptr; @@ -1421,7 +1724,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { //} const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp(); - const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); ggml_set_input(cur); @@ -1499,7 +1802,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * cur; - if (cparams.flash_attn && kq_b == nullptr) { + const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr; + if (use_flash_attn) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1525,7 +1829,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( if (v_mla) { #if 0 // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. - // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. + // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_mul_mat(ctx0, v_mla, cur); #else @@ -1695,14 +1999,11 @@ static std::unique_ptr build_attn_inp_kv_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); + ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1728,9 +2029,11 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_b, ggml_tensor * sinks, - ggml_tensor * v_mla, + ggml_tensor * v_mla, // TODO: remove float kq_scale, int il) const { + GGML_ASSERT(v_mla == nullptr); + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -1758,6 +2061,89 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +static std::unique_ptr build_attn_inp_k_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx_cur) { + + auto inp = std::make_unique(hparams, cparams, mctx_cur); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); + + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return inp; +} + +llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + + return (llm_graph_input_attn_k *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx; + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + if (wo) { cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { @@ -1903,15 +2289,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const auto inp = std::make_unique(hparams, cparams, mctx_cur); - const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; - { - const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); ggml_set_input(inp->self_kq_mask); ggml_set_name(inp->self_kq_mask, "self_kq_mask"); @@ -1922,12 +2304,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); - const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); ggml_set_input(inp->self_kq_mask_swa); ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); @@ -2068,10 +2448,57 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp)); +} + +llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + + // build iswa attention input + const auto * attn_ctx = mctx_cur->get_attn(); + + auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); + + { + inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); + ggml_set_input(inp_attn->self_kq_mask); + + inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + } + + { + inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); + ggml_set_input(inp_attn->self_kq_mask_swa); + + inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + } + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const { - if (!cparams.embeddings || !(dense_2 || dense_3)) { + if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) { return; } ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; @@ -2080,6 +2507,9 @@ void llm_graph_context::build_dense_out( if (dense_2) { cur = ggml_mul_mat(ctx0, dense_2, cur); } + if (dense_2_b) { + cur = ggml_add(ctx0, cur, dense_2_b); + } if (dense_3) { cur = ggml_mul_mat(ctx0, dense_3, cur); } @@ -2093,7 +2523,8 @@ void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const { + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const { if (!cparams.embeddings) { return; } @@ -2132,8 +2563,15 @@ void llm_graph_context::build_pooling( } break; case LLAMA_POOLING_TYPE_RANK: { - ggml_tensor * inp_cls = build_inp_cls(); - cur = ggml_get_rows(ctx0, inp, inp_cls); + if (arch == LLM_ARCH_MODERN_BERT) { + // modern bert gte reranker builds mean first then applies prediction head and classifier + // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411 + ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } else { + ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } // classification head // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 @@ -2142,7 +2580,15 @@ void llm_graph_context::build_pooling( if (cls_b) { cur = ggml_add(ctx0, cur, cls_b); } - cur = ggml_tanh(ctx0, cur); + if (arch == LLM_ARCH_MODERN_BERT) { + cur = ggml_gelu(ctx0, cur); + } else { + cur = ggml_tanh(ctx0, cur); + } + if (cls_norm) { + // head norm + cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1); + } } // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en @@ -2157,7 +2603,7 @@ void llm_graph_context::build_pooling( } // softmax for qwen3 reranker - if (arch == LLM_ARCH_QWEN3) { + if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) { cur = ggml_soft_max(ctx0, cur); } } break; @@ -2178,6 +2624,9 @@ void llm_graph_context::build_sampling() const { return; } + std::array outs; + outs[0] = res->t_logits; + auto inp_sampling = std::make_unique(samplers); res->add_input(std::move(inp_sampling)); @@ -2198,14 +2647,14 @@ void llm_graph_context::build_sampling() const { // add a dummy row of logits // this trick makes the graph static, regardless of which samplers are activated // this is important in order to minimize graph reallocations - // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); for (const auto & [seq_id, sampler] : samplers) { const auto it = seq_to_logit_row.find(seq_id); // inactive samplers always work on the first row - const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? 
it->second : 0; + const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; + const int i_out = it != seq_to_logit_row.end() ? 1 : 0; ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); @@ -2222,22 +2671,26 @@ void llm_graph_context::build_sampling() const { if (data.sampled != nullptr) { res->t_sampled[seq_id] = data.sampled; - ggml_build_forward_expand(gf, data.sampled); + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.probs != nullptr) { res->t_sampled_probs[seq_id] = data.probs; - ggml_build_forward_expand(gf, data.probs); + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.logits != nullptr) { res->t_sampled_logits[seq_id] = data.logits; - ggml_build_forward_expand(gf, data.logits); + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.candidates != nullptr) { res->t_candidates[seq_id] = data.candidates; - ggml_build_forward_expand(gf, data.candidates); + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 503ffd69..4855685e 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -24,6 +24,7 @@ class llama_kv_cache_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; +class llama_memory_hybrid_iswa_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -105,7 +106,7 @@ using llm_graph_input_ptr = std::unique_ptr; class llm_graph_input_embd : public llm_graph_input_i { public: - llm_graph_input_embd() = default; + llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {} virtual ~llm_graph_input_embd() = default; void set_input(const llama_ubatch * ubatch) override; @@ -114,6 +115,8 @@ public: ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; }; class llm_graph_input_pos : public llm_graph_input_i { @@ -314,6 +317,39 @@ public: const llama_kv_cache_context * mctx; }; +// V-less input for the KV cache +// ref: https://github.com/ggml-org/llama.cpp/pull/19067 +class llm_graph_input_attn_k : public llm_graph_input_i { +public: + llm_graph_input_attn_k( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -397,6 +433,62 @@ public: const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_k : public 
llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_k( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_context * mctx; +}; + +class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_iswa( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_iswa_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_iswa_context * mctx; +}; + class llm_graph_input_sampling : public llm_graph_input_i { public: llm_graph_input_sampling(std::map samplers) : @@ -537,7 +629,7 @@ public: virtual ~llm_graph_result() = default; - ggml_tensor * get_tokens() const { return t_tokens; } + ggml_tensor * get_inp_tokens() const { return t_inp_tokens; } ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } @@ -564,7 +656,8 @@ public: void set_params(const llm_graph_params & params); // important graph nodes - ggml_tensor * t_tokens = nullptr; + ggml_tensor * t_inp_tokens = nullptr; + ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; @@ -671,10 +764,11 @@ struct llm_graph_context { ggml_tensor * cur, int il) const; - // do mat_mul, while optionally apply lora + // do mat_mul, while optionally apply lora and per-tensor scale ggml_tensor * build_lora_mm( ggml_tensor * w, - ggml_tensor * cur) const; + ggml_tensor * cur, + ggml_tensor * w_s = nullptr) const; // do mat_mul_id, while optionally apply lora ggml_tensor * build_lora_mm_id( @@ -717,11 +811,14 @@ struct llm_graph_context { int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr, + ggml_tensor * up_exps_s = nullptr, + ggml_tensor * gate_exps_s = nullptr, + ggml_tensor * down_exps_s = nullptr) const; ggml_tensor * build_moe_ffn( ggml_tensor * cur, @@ -738,11 +835,15 @@ struct llm_graph_context { int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * 
gate_up_exps = nullptr, + ggml_tensor * gate_up_exps_b = nullptr, + ggml_tensor * up_exps_s = nullptr, + ggml_tensor * gate_exps_s = nullptr, + ggml_tensor * down_exps_s = nullptr) const; // // inputs @@ -801,6 +902,21 @@ struct llm_graph_context { ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove + float kq_scale, + int il) const; + + llm_graph_input_attn_k * build_attn_inp_k() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -880,6 +996,9 @@ struct llm_graph_context { // llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const; + + llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; // // pooling @@ -889,7 +1008,8 @@ struct llm_graph_context { ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const; + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const; // // sampling (backend sampling) @@ -903,6 +1023,7 @@ struct llm_graph_context { void build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const; }; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index c847ef91..002d15d4 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -62,6 +62,14 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const { return n_head/n_head_kv; } +uint32_t llama_hparams::n_rot(uint32_t il) const { + if (il < n_layer) { + return is_swa(il) ? n_rot_swa : n_rot_full; + } + + GGML_ABORT("fatal error"); +} + uint32_t llama_hparams::n_embd_inp() const { uint32_t n_embd_inp = n_embd; @@ -72,20 +80,36 @@ uint32_t llama_hparams::n_embd_inp() const { return n_embd_inp; } -uint32_t llama_hparams::get_n_embd_out() const { - return n_embd_out > 0 ? n_embd_out : n_embd; +uint32_t llama_hparams::n_embd_out() const { + return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd; +} + +uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { + if (il < n_layer) { + return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_embd_head_v(uint32_t il) const { + if (il < n_layer) { + return is_swa(il) ? 
n_embd_head_v_swa : n_embd_head_v_full; + } + + GGML_ABORT("fatal error"); } uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); - return n_embd_head_k * n_head_kv; + return n_embd_head_k(il) * n_head_kv; } uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); - return n_embd_head_v * n_head_kv; + return n_embd_head_v(il) * n_head_kv; } bool llama_hparams::is_n_embd_k_gqa_variable() const { @@ -139,6 +163,13 @@ uint32_t llama_hparams::n_embd_r() const { return n_embd * (n_shortconv_l_cache - 1); } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim + const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096 + return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -151,6 +182,13 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Full recurrent state: head_dim * head_dim * n_head + // h tensor shape for delta attention: [head_dim, head_dim, n_head] + return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288 + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -175,6 +213,21 @@ bool llama_hparams::is_swa(uint32_t il) const { GGML_ABORT("fatal error"); } +bool llama_hparams::is_mla() const { + assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) || + (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0)); + + return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0; +} + +uint32_t llama_hparams::n_embd_head_k_mla() const { + return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k(); +} + +uint32_t llama_hparams::n_embd_head_v_mla() const { + return is_mla() ? 
n_embd_head_v_mla_impl : n_embd_head_v(); +} + bool llama_hparams::has_kv(uint32_t il) const { if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { @@ -200,42 +253,6 @@ uint32_t llama_hparams::n_layer_kv() const { return res; } -bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - case LLAMA_SWA_TYPE_SYMMETRIC: - { - const int32_t half_n_swa = (int32_t) n_swa / 2; - const int32_t pos_diff = p1 - p0; - - // Mask if outside the symmetric window - if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { - return true; - } - } break; - } - - return false; -} - bool llama_hparams::use_mrope() const { return rope_sections[0] > 0 && rope_sections[1] > 0; } diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 7ae3ec29..78c0bc27 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include // bump if necessary #define LLAMA_MAX_LAYERS 512 @@ -41,19 +42,25 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_embd_features = 0; uint32_t n_layer; int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache - uint32_t n_rot; - uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads - uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // different head size for full_attention and SWA layers + uint32_t n_embd_head_k_full; // dimension of keys (d_k). 
d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads + uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head + uint32_t n_embd_head_k_swa; + uint32_t n_embd_head_v_swa; + + // different RoPE dimensions for full_attention and SWA layers + uint32_t n_rot_full; + uint32_t n_rot_swa; + // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - uint32_t n_embd_head_k_mla = 0; - uint32_t n_embd_head_v_mla = 0; + uint32_t n_embd_head_k_mla_impl = 0; + uint32_t n_embd_head_v_mla_impl = 0; // for WavTokenizer struct llama_hparams_posnet posnet; @@ -82,6 +89,7 @@ struct llama_hparams { bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; + uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; float f_norm_eps; @@ -136,6 +144,9 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; + // for Kimi Linear KDA + uint32_t n_embd_head_kda = 0; + // for hybrid state space models std::array recurrent_layer_arr; @@ -163,7 +174,7 @@ struct llama_hparams { uint32_t n_cls_out = 1; // output embedding dimension (0 = use n_embd) - uint32_t n_embd_out = 0; + uint32_t n_embd_out_impl = 0; // llama4 smallthinker uint32_t n_moe_layer_step = 0; @@ -190,11 +201,16 @@ struct llama_hparams { std::array xielu_beta; std::array xielu_eps; + // DSA (deepseek sparse attention) + uint32_t indexer_n_head = 0; + uint32_t indexer_head_size = 0; + uint32_t indexer_top_k = 0; + // qwen3vl deepstack uint32_t n_deepstack_layers = 0; // needed by encoder-decoder models (e.g. T5, FLAN-T5) - // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; uint32_t dec_n_layer = 0; @@ -202,6 +218,11 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + // Step35: optional per-layer clamps for (Swi)GLU + std::array swiglu_clamp_exp; // clamping for expert FFN + std::array swiglu_clamp_shexp; // shared expert + // this value n_pattern means that every nth layer is dense (i.e. 
non-SWA) // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA @@ -234,11 +255,17 @@ struct llama_hparams { uint32_t n_gqa(uint32_t il = 0) const; + uint32_t n_rot(uint32_t il = 0) const; + // dimension of main + auxiliary input embeddings uint32_t n_embd_inp() const; // dimension of output embeddings - uint32_t get_n_embd_out() const; + uint32_t n_embd_out() const; + + // dimension of key/value embeddings for each head (per layer) + uint32_t n_embd_head_k(uint32_t il = 0) const; + uint32_t n_embd_head_v(uint32_t il = 0) const; // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; @@ -268,15 +295,57 @@ struct llama_hparams { bool is_swa(uint32_t il) const; + // note: currently only support if either all or none of the layers are MLA + bool is_mla() const; + + uint32_t n_embd_head_k_mla() const; + uint32_t n_embd_head_v_mla() const; + bool has_kv(uint32_t il) const; // number of layers for which has_kv() returns true uint32_t n_layer_kv() const; // note that this function uses different SWA parameters from those in the hparams + // note: inlined on purpose for performance reasons // TODO: think of a better place for this function // TODO: pack the SWA params in a struct? - static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); + static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; + } + bool use_mrope() const; }; diff --git a/examples/talk-llama/llama-impl.cpp b/examples/talk-llama/llama-impl.cpp index 8e3e7b22..4c0188ee 100644 --- a/examples/talk-llama/llama-impl.cpp +++ b/examples/talk-llama/llama-impl.cpp @@ -100,18 +100,18 @@ std::string format(const char * fmt, ...) 
{ std::string llama_format_tensor_shape(const std::vector & ne) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + snprintf(buf, sizeof(buf), "%6" PRId64, ne.at(0)); for (size_t i = 1; i < ne.size(); i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, ne.at(i)); } return buf; } std::string llama_format_tensor_shape(const struct ggml_tensor * t) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + snprintf(buf, sizeof(buf), "%6" PRId64, t->ne[0]); for (int i = 1; i < GGML_MAX_DIMS; i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, t->ne[i]); } return buf; } diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h index c3391e79..e4f35c8e 100644 --- a/examples/talk-llama/llama-impl.h +++ b/examples/talk-llama/llama-impl.h @@ -49,6 +49,16 @@ struct time_meas { int64_t & t_acc; }; +template +struct buffer_view { + T * data; + size_t size = 0; + + bool has_data() const { + return data && size > 0; + } +}; + void replace_all(std::string & s, const std::string & search, const std::string & replace); // TODO: rename to llama_format ? @@ -60,4 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t); std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); -#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__" +#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__" diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 3a34102a..26e2cb42 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -218,7 +218,9 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, } bool llama_kv_cache_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); + return kv_base->get_can_shift() && + kv_swa->get_can_shift() && + kv_base->get_size() == kv_swa->get_size(); } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 3186242d..01166fac 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache( __func__, hparams.n_embd_v_gqa_max()); } + const bool is_mla = hparams.is_mla(); + for (uint32_t il = 0; il < hparams.n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); @@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + const bool has_k = true; + const bool has_v = !is_mla; - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); + ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr; + ggml_tensor * v = has_v ? 
ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr; + + has_k && ggml_format_name(k, "cache_k_l%d", il); + has_v && ggml_format_name(v, "cache_v_l%d", il); std::vector k_stream; std::vector v_stream; for (uint32_t s = 0; s < n_stream; ++s) { - k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); - v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr); + v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr); } map_layer_ids[il] = layers.size(); @@ -578,7 +583,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector 1) { + return false; + } return true; } @@ -1018,8 +1033,8 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; return ggml_view_4d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns, - ggml_row_size(k->type, hparams.n_embd_head_k), + hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(k->type, hparams.n_embd_head_k(il)), ggml_row_size(k->type, n_embd_k_gqa), ggml_row_size(k->type, n_embd_k_gqa*kv_size), ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); @@ -1041,8 +1056,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k if (!v_trans) { // note: v->nb[1] <= v->nb[2] return ggml_view_4d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1] ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); @@ -1050,8 +1065,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k // note: v->nb[1] > v->nb[2] return ggml_view_4d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, - ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns, + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1] ggml_row_size(v->type, kv_size), // v->nb[2] ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); @@ -1237,6 +1252,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { } } +struct args_set_input_kq_mask { + const llama_hparams & hparams; + const llama_ubatch * ubatch; + + const std::vector & v_cells; + const std::vector & seq_to_stream; + + uint32_t n_swa; + llama_swa_type swa_type; + + int64_t n_kv; + int64_t n_stream; + int64_t n_tps; +}; + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + //const auto & hparams = args.hparams; + const auto & ubatch = args.ubatch; + + const auto & v_cells = args.v_cells; + const auto & seq_to_stream = args.seq_to_stream; + + const uint32_t n_swa = args.n_swa; + const llama_swa_type swa_type = args.swa_type; + + const int64_t n_kv = args.n_kv; + const int64_t n_stream = args.n_stream; + const int64_t n_tps = args.n_tps; + + // the min position in the batch for each sequence + llama_pos seq_pos_min[LLAMA_MAX_SEQ]; + std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); + + for (uint32_t 
i = 0; i < ubatch->n_tokens; ++i) { + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]); + } + + for (uint32_t s = 0; s < n_stream; ++s) { + // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence + std::unordered_map seq_srct; + std::unordered_map> seq_idxs; + + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; + + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + const auto & cells = v_cells.at(seq_to_stream[seq_id]); + + llama_pos p0 = -1; + const llama_pos p1 = ubatch->pos[i]; + + // for M-RoPE + const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; + const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; + + const uint64_t idst = n_kv*i; + + // for tokens of the same sequence, the mask is mostly the same, so we can reuse it + // the only cells that could change are the ones that are with similar positions as the + // ones in the batch (i.e. due to causal masking, SWA, etc.) + // keep track of those cells and shortcut the loop to save time + // note: this optimization is not compatible with Alibi position encoding + // ref: https://github.com/ggml-org/llama.cpp/pull/18842 + bool prev = false; + + auto & idxs = seq_idxs[seq_id]; + + if (!alibi) { + if (seq_srct.find(seq_id) != seq_srct.end()) { + const uint32_t srct = seq_srct[seq_id]; + + const uint64_t idst_prev = n_kv*srct; + + std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst); + + prev = true; + } else { + idxs.clear(); + idxs.reserve(ubatch->n_tokens + n_swa + 32); + + seq_srct[seq_id] = i; + } + } + + for (uint32_t jj = 0; jj < n_kv; ++jj) { + uint32_t j = jj; + + // we have an exiting mask for this sequence -> update just seq_idxs + if (!alibi) { + if (prev) { + if (jj >= idxs.size()) { + break; + } + + j = idxs[jj]; + } + } + + if (cells.is_empty(j)) { + goto skip; + } + + // mask the token if not the same sequence + if (!cells.seq_has(j, seq_id)) { + goto skip; + } + + p0 = cells.pos_get(j); + + if (!alibi) { + if (!prev) { + // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32 + if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) { + idxs.push_back(j); + } + } + } + + if (causal) { + // mask future tokens + if (p0 > p1) { + goto skip; + } + + // M-RoPE causal mask + if (is_2d) { + if (p0 == p1) { + const auto & p0_ext = cells.ext_get(j); + + if (p0_ext.is_2d_gt(p1_x, p1_y)) { + goto skip; + } + } + } + } + + // apply SWA if any + if (swa) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + goto skip; + } + } + + if (alibi) { + data[idst + j] = -std::abs(p0 - p1); + } else { + data[idst + j] = 0.0f; + } + + continue; +skip: + data[idst + j] = -INFINITY; + } + } + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool alibi = args.hparams.use_alibi; + if (alibi) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool is_2d = args.ubatch->is_pos_2d(); + if (is_2d) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; + if (swa) { + set_input_kq_mask_impl (args, data); + } else { + 
set_input_kq_mask_impl(args, data); + } +} + void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const uint32_t n_tokens = ubatch->n_tokens; @@ -1251,74 +1457,29 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u // n_tps == n_tokens_per_stream const int64_t n_tps = n_tokens/n_stream; - std::fill(data, data + ggml_nelements(dst), -INFINITY); + //const int64_t t_start = ggml_time_us(); - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - // TODO: optimize this section - for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t s = 0; s < n_stream; ++s) { - for (uint32_t ii = 0; ii < n_tps; ++ii) { - const uint32_t i = s*n_tps + ii; + const args_set_input_kq_mask args = { + /*.hparams =*/ hparams, + /*.ubatch =*/ ubatch, + /*.v_cells =*/ v_cells, + /*.seq_to_stream =*/ seq_to_stream, + /*.n_swa =*/ n_swa, + /*.swa_type =*/ swa_type, + /*.n_kv =*/ n_kv, + /*.n_stream =*/ n_stream, + /*.n_tps =*/ n_tps, + }; - const llama_seq_id seq_id = ubatch->seq_id[i][0]; - - const auto & cells = v_cells[seq_to_stream[seq_id]]; - - const llama_pos p1 = ubatch->pos[i]; - - // for M-RoPE - const bool is_2d = ubatch->is_pos_2d(); - const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; - const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; - - const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii); - - for (uint32_t j = 0; j < n_kv; ++j) { - if (cells.is_empty(j)) { - continue; - } - - // mask the token if not the same sequence - if (!cells.seq_has(j, seq_id)) { - continue; - } - - const llama_pos p0 = cells.pos_get(j); - - // mask future tokens - if (causal_attn && p0 > p1) { - continue; - } - - // M-RoPE causal mask - if (causal_attn && is_2d && p0 == p1) { - const auto & p0_ext = cells.ext_get(j); - if (p0_ext.is_2d_gt(p1_x, p1_y)) { - continue; - } - } - - // apply SWA if any - if (is_masked_swa(p0, p1)) { - continue; - } - - data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; - } - } - } + if (causal_attn) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); } + + //const int64_t t_end = ggml_time_us(); + + //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0); } void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { @@ -1370,7 +1531,7 @@ size_t llama_kv_cache::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); + size_v_bytes += layer.v ? 
ggml_nbytes(layer.v) : 0; } return size_v_bytes; @@ -1383,7 +1544,8 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ggml_tensor * shift, ggml_tensor * factors, float freq_base, - float freq_scale) const { + float freq_scale, + uint32_t il) const { const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; const auto & yarn_ext_factor = cparams.yarn_ext_factor; @@ -1391,7 +1553,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( const auto & yarn_beta_slow = cparams.yarn_beta_slow; const auto & yarn_attn_factor = cparams.yarn_attn_factor; - const auto & n_rot = hparams.n_rot; + const auto & n_rot = hparams.n_rot(il); const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE // @ngxson : this is a workaround // for M-RoPE, we want to rotate the whole vector when doing KV shift @@ -1445,9 +1607,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - auto inp = std::make_unique(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); @@ -1461,6 +1620,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const auto n_rot = hparams.n_rot(il); + const auto n_embd_head_k = hparams.n_embd_head_k(il); + const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; + const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); @@ -1468,12 +1631,12 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, get_size()*n_stream, + n_rot, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); + ggml_row_size(layer.k->type, n_embd_nope)); - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il); ggml_build_forward_expand(gf, cur); } @@ -1483,10 +1646,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co return gf; } -bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { - return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1); -} - void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { GGML_UNUSED(flags); @@ -1599,8 +1758,10 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t io.write(&pos, sizeof(pos)); io.write(&n_seq_id, sizeof(n_seq_id)); - // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it - // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350 + if (hparams.n_pos_per_embd() > 1) { + const llama_kv_cell_ext ext = cells.ext_get(i); + io.write(&ext, sizeof(ext)); + } for (const auto & seq_id : seq_ids) { io.write(&seq_id, sizeof(seq_id)); @@ -1618,8 +1779,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t io.write(&v_trans, sizeof(v_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all 
the keys first, each row is a cell // Get whole range at a time for (const auto & layer : layers) { @@ -1637,7 +1796,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Read each range of cells of k_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; @@ -1652,6 +1811,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1661,7 +1823,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); - // Read each range of cells of v_size length each into tmp_buf and write out + // Read each range of cells of v_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; @@ -1678,6 +1840,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1692,7 +1857,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size_el length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; @@ -1730,6 +1895,14 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 return false; } + if (hparams.n_pos_per_embd() > 1) { + llama_kv_cell_ext ext; + io.read_to(&ext, sizeof(ext)); + + ubatch.pos[i + ubatch.n_tokens] = ext.y; + ubatch.pos[i + ubatch.n_tokens*2] = ext.x; + } + // read the sequence id, but directly discard it - we will use dest_seq_id instead { llama_seq_id seq_id; @@ -1780,6 +1953,12 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 cells.pos_set(i, pos); + if (hparams.n_pos_per_embd() > 1) { + llama_kv_cell_ext ext; + io.read_to(&ext, sizeof(ext)); + cells.ext_set(i, ext); + } + for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; io.read_to(&seq_id, sizeof(seq_id)); @@ -1881,6 +2060,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; @@ -1922,6 +2104,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; diff --git a/examples/talk-llama/llama-kv-cache.h 
b/examples/talk-llama/llama-kv-cache.h index 0c4ed648..33c78c5f 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -257,8 +257,6 @@ private: size_t size_k_bytes() const; size_t size_v_bytes() const; - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - ggml_tensor * build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, @@ -266,7 +264,8 @@ private: ggml_tensor * shift, ggml_tensor * factors, float freq_base, - float freq_scale) const; + float freq_scale, + uint32_t il) const; ggml_cgraph * build_graph_shift( llm_graph_result * res, diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp new file mode 100644 index 00000000..41176967 --- /dev/null +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -0,0 +1,275 @@ +#include "llama-memory-hybrid-iswa.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_memory_hybrid_iswa +// + +llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : + hparams(model.hparams), + mem_attn(new llama_kv_cache_iswa( + model, + type_k, + type_v, + v_trans, + offload, + swa_full, + unified, + kv_size, + n_seq_max, + n_ubatch, + n_pad, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recurrent(il); } + : filter_attn, + nullptr + )), + mem_recr(new llama_memory_recurrent( + model, + type_r, + type_s, + offload, + rs_size, + n_seq_max, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recurrent(il); } + : filter_recr + )) {} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { + do { + balloc.split_reset(); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + + while (true) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + // prepare the recurrent batches first + if (!mem_recr->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined context at this point? 
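// ---------------------------------------------------------------------------
// [editorial aside, not part of the upstream patch]
// init_batch() above prepares the recurrent cache first and only then the
// base/SWA attention caches, returning a hybrid context on success. Below is a
// minimal, hedged sketch of how a caller might consume that context, based on
// the get_status()/apply()/get_ubatch()/next() methods defined later in this
// file; `mem`, `balloc` and `n_ubatch` are assumed caller-side names:
//
//     llama_memory_context_ptr mctx = mem->init_batch(balloc, n_ubatch, /*embd_all=*/false);
//
//     if (!mctx || llama_memory_status_is_fail(mctx->get_status())) {
//         // preparation failed (e.g. LLAMA_MEMORY_STATUS_FAILED_PREPARE) -> caller falls back
//     } else {
//         do {
//             const llama_ubatch & ub = mctx->get_ubatch(); // current prepared ubatch
//
//             if (!mctx->apply()) {                          // write ub into both the attention and recurrent caches
//                 break;
//             }
//
//             // ... build and evaluate the graph for `ub` ...
//         } while (mctx->next());                            // advance to the next prepared ubatch
//     }
// ---------------------------------------------------------------------------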
+ LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache (iswa version returns both base and swa slot infos) + auto sinfos_base = mem_attn->get_base()->prepare(ubatches); + if (sinfos_base.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches); + if (sinfos_swa.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); + } while(false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() { + return std::make_unique(this); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_memory_hybrid_iswa::get_can_shift() const { + // Shifting is trivially supported for recurrent + return mem_attn->get_can_shift(); +} + +void llama_memory_hybrid_iswa::clear(bool data) { + mem_attn->clear(data); + mem_recr->clear(data); +} + +bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!mem_recr->seq_rm(seq_id, p0, p1)) { + return false; + } + return mem_attn->seq_rm(seq_id, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1); + mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) { + mem_attn->seq_keep(seq_id); + mem_recr->seq_keep(seq_id); +} + +void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + mem_attn->seq_add(seq_id, p0, p1, shift); + mem_recr->seq_add(seq_id, p0, p1, shift); +} + +void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + mem_attn->seq_div(seq_id, p0, p1, d); + mem_recr->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id)); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); +} + +std::map llama_memory_hybrid_iswa::memory_breakdown() const { + std::map mb = mem_attn->memory_breakdown(); + for (const auto & buft_size : mem_recr->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + mem_attn->state_write(io, seq_id, flags); + mem_recr->state_write(io, seq_id, flags); +} + +void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + 
mem_attn->state_read(io, seq_id, flags); + mem_recr->state_read(io, seq_id, flags); +} + +llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const { + return mem_attn.get(); +} + +llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const { + return mem_recr.get(); +} + +// +// llama_memory_hybrid_iswa_context +// + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize) : + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +bool llama_memory_hybrid_iswa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_attn->next(); + ctx_recr->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_memory_hybrid_iswa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); + + return res; +} + +llama_memory_status llama_memory_hybrid_iswa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const { + return static_cast(ctx_attn.get()); +} + +const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const { + return static_cast(ctx_recr.get()); +} diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.h b/examples/talk-llama/llama-memory-hybrid-iswa.h new file mode 100644 index 00000000..807c8aac --- /dev/null +++ b/examples/talk-llama/llama-memory-hybrid-iswa.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache-iswa.h" +#include "llama-memory.h" +#include "llama-memory-recurrent.h" + +#include +#include + +// +// llama_memory_hybrid_iswa +// + +// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to +// support models where each layer may be either attention-based (with SWA support) or recurrent + +class llama_memory_hybrid_iswa : public llama_memory_i { +public: + llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + 
uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); + + ~llama_memory_hybrid_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_memory_hybrid_iswa specific API + // + + llama_kv_cache_iswa * get_mem_attn() const; + llama_memory_recurrent * get_mem_recr() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr mem_attn; + const std::unique_ptr mem_recr; +}; + +class llama_memory_hybrid_iswa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // init failure + explicit llama_memory_hybrid_iswa_context(llama_memory_status status); + + // init full + explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem); + + // init update + explicit llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize); + + // init success + llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches); + + ~llama_memory_hybrid_iswa_context() = default; + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_memory_hybrid_iswa_context + // + + const llama_kv_cache_iswa_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; + +private: + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; + + const llama_memory_status status; +}; diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 812bf253..6e8413f4 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos const auto & cell = cells[tail_id]; // partial intersection is invalid if it includes the final pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - 
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); return false; } // invalidate tails which will be cleared @@ -785,23 +785,21 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: io.write(&s_trans, sizeof(s_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell + // Iterate and write all the R tensors first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (r_l[il] == nullptr) continue; - // Write key type + // Write R tensor type const int32_t r_type_i = (int32_t)r_l[il]->type; io.write(&r_type_i, sizeof(r_type_i)); - // Write row size of key + // Write row size of R tensor const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Write each range of cells of r_size_row length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -814,15 +812,15 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (s_l[il] == nullptr) continue; - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); - // Write row size of value + // Write row size of S tensor const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Read each range of cells of s_size length each into tmp_buf and write out + // Write each range of S tensor rows for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -830,7 +828,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: } } } else { - // When v is transposed, we also need the element size and get the element ranges from each row + // When S tensor is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) @@ -838,7 +836,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_s = hparams.n_embd_s(); - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); @@ -851,7 +849,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; diff --git 
a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index 2da857b3..c03228e9 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -244,11 +244,14 @@ struct llama_file::impl { } errno = 0; if (fd == -1) { - std::size_t ret = std::fread(ptr, len, 1, fp); + const size_t curr_off = tell(); + const size_t to_read = std::min(len, size - curr_off); + + std::size_t ret = std::fread(ptr, to_read, 1, fp); if (ferror(fp)) { throw std::runtime_error(format("read error: %s", strerror(errno))); } - if (ret != 1) { + if (to_read > 0 && ret != 1) { throw std::runtime_error("unexpectedly reached end of file"); } } else { @@ -262,7 +265,8 @@ struct llama_file::impl { continue; // Interrupted by signal, retry } // Fallback to std::fread in case the DMA controller cannot access the buffer - if (errno == EFAULT) { + if (errno == EFAULT || errno == EINVAL) { + LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno)); auto curr_off = tell(); close(fd); fd = -1; @@ -381,6 +385,9 @@ int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); #else + if (pimpl->fd != -1) { + return pimpl->fd; + } #if defined(fileno) return fileno(pimpl->fp); #else @@ -497,6 +504,8 @@ struct llama_mmap::impl { } } #elif defined(_WIN32) + HANDLE hMapping = nullptr; + impl(struct llama_file * file, size_t prefetch, bool numa) { GGML_UNUSED(numa); @@ -504,7 +513,7 @@ struct llama_mmap::impl { HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); - HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); if (hMapping == NULL) { DWORD error = GetLastError(); @@ -513,9 +522,9 @@ struct llama_mmap::impl { addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); DWORD error = GetLastError(); - CloseHandle(hMapping); if (addr == NULL) { + CloseHandle(hMapping); throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); } @@ -547,9 +556,17 @@ struct llama_mmap::impl { } ~impl() { - if (!UnmapViewOfFile(addr)) { - LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); + if (hMapping) { + if (addr) { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + if (!CloseHandle(hMapping)) { + LLAMA_LOG_WARN("warning: CloseHandle failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } } } #else @@ -611,9 +628,9 @@ struct llama_mlock::impl { char* errmsg = std::strerror(errno); bool suggest = (errno == ENOMEM); -#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) - // visionOS/tvOS dont't support RLIMIT_MEMLOCK - // Skip resource limit checks on visionOS/tvOS +#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) || defined(__HAIKU__) + // visionOS/tvOS/Haiku don't support RLIMIT_MEMLOCK + // Skip resource limit checks on these platforms suggest = false; #else struct rlimit lock_limit; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index e66febaa..413f34c2 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -1,11 +1,17 @@ #include "llama-model-loader.h" +#include "ggml-alloc.h" #include "ggml.h" +#include "gguf.h" +#include "llama-hparams.h" +#include #include #include +#include #include #include +#include static 
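The buffered-read change in the llama-mmap.cpp hunk above is easy to miss: the requested length is clamped to the bytes remaining in the file, and a short read only counts as an error when there was actually something left to read; the same hunk also widens the direct-I/O fallback to EINVAL and logs the reason. A self-contained sketch of the clamped read (a standalone helper, not the real llama_file member; assumes curr_off <= file_size):

#include <algorithm>
#include <cstdio>
#include <stdexcept>

// Read up to `len` bytes at the current offset, tolerating reads that start at EOF.
void read_clamped(std::FILE * fp, void * dst, size_t len, size_t file_size, size_t curr_off) {
    const size_t to_read = std::min(len, file_size - curr_off);
    const size_t ret = std::fread(dst, to_read, 1, fp); // one element of to_read bytes
    if (std::ferror(fp)) {
        throw std::runtime_error("read error");
    }
    if (to_read > 0 && ret != 1) {
        throw std::runtime_error("unexpectedly reached end of file");
    }
}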
const size_t kiB = 1024; static const size_t MiB = 1024*kiB; @@ -36,6 +42,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE"; + case LLAMA_FTYPE_MOSTLY_NVFP4: return "NVFP4"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; @@ -262,7 +269,7 @@ namespace GGUFMeta { template typename std::enable_if::value, bool>::type llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const int kid = gguf_find_key(metadata, key.c_str()); if (kid < 0) { if (required) { @@ -272,7 +279,7 @@ namespace GGUFMeta { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta.get(), kid); + GGUFMeta::GKV::get_kv(metadata, kid); result = arr_info.length; @@ -289,7 +296,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr(const std::string & key, std::vector & result, bool required) { - const gguf_context * ctx = meta.get(); + const gguf_context * ctx = metadata; const int kid = gguf_find_key(ctx, key.c_str()); if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { @@ -330,7 +337,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) { - const gguf_context * ctx = meta.get(); + const gguf_context * ctx = metadata; const int kid = gguf_find_key(ctx, key.c_str()); if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { @@ -344,6 +351,7 @@ namespace GGUFMeta { GGUFMeta::GKV::get_kv(ctx, kid); switch (arr_info.gt) { + case GGUF_TYPE_BOOL: case GGUF_TYPE_UINT32: case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || (std::is_same::value)); break; @@ -365,7 +373,13 @@ namespace GGUFMeta { result[i] = value; } } else { - std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + if (arr_info.gt == GGUF_TYPE_BOOL) { + std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(), [](bool x) { + return static_cast(x); + }); + } else { + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + } } return true; @@ -385,7 +399,7 @@ namespace GGUFMeta { const struct llama_model_kv_override * override = it != kv_overrides.end() ? 
&it->second : nullptr; - const bool found = GGUFMeta::GKV::set(meta.get(), key, result, override); + const bool found = GGUFMeta::GKV::set(metadata, key, result, override); if (required && !found) { throw std::runtime_error(format("key not found in model: %s", key.c_str())); @@ -419,7 +433,7 @@ namespace GGUFMeta { // get array of n <= N_MAX elements, or a single element repeated n times template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const int kid = gguf_find_key(metadata, key.c_str()); if (kid < 0) { if (required) { @@ -432,9 +446,9 @@ namespace GGUFMeta { throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); } - if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) { + if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) { struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta.get(), kid); + GGUFMeta::GKV::get_kv(metadata, kid); if (n != arr_info.length) { throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); @@ -465,7 +479,7 @@ namespace GGUFMeta { bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) { const std::string key = llm_kv(kid); - const int id = gguf_find_key(meta.get(), key.c_str()); + const int id = gguf_find_key(metadata, key.c_str()); if (id < 0) { if (required) { @@ -475,7 +489,7 @@ namespace GGUFMeta { } // throw and error if type is an array - if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) { + if (gguf_get_kv_type(metadata, id) == GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str())); } @@ -492,6 +506,9 @@ namespace GGUFMeta { llama_model_loader::llama_model_loader( + struct gguf_context * meta, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, const std::string & fname, std::vector & splits, bool use_mmap, @@ -499,7 +516,8 @@ llama_model_loader::llama_model_loader( bool check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, - const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) + : metadata(meta), set_tensor_data(set_tensor_data), set_tensor_data_ud(set_tensor_data_ud) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -513,130 +531,142 @@ llama_model_loader::llama_model_loader( tensor_buft_overrides = param_tensor_buft_overrides_p; - // Load the main GGUF - struct ggml_context * ctx = NULL; - struct gguf_init_params params = { - /*.no_alloc = */ true, - /*.ctx = */ &ctx, - }; + if (!fname.empty()) { + // Load the main GGUF + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; - meta.reset(gguf_init_from_file(fname.c_str(), params)); - if (!meta) { - throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); - } - - get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); - llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - - files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); - contexts.emplace_back(ctx); - - use_direct_io = use_direct_io && files.back()->has_direct_io(); - - // Disable mmap in case Direct I/O is enabled and available - if (use_direct_io 
&& use_mmap) { - use_mmap = false; - LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); - } - - // Save tensors data offset of the main file. - // For subsidiary files, `meta` tensor data offset must not be used, - // so we build a unified tensors index for weights. - for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - std::string tensor_name = std::string(cur->name); - // make sure there is no duplicated tensor names - if (weights_map.find(tensor_name) != weights_map.end()) { - throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); - } - n_elements += ggml_nelements(cur); - n_bytes += ggml_nbytes(cur); - weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur)); - } - uint16_t n_split = 0; - get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); - - // Load additional GGML contexts - if (n_split > 1) { - // make sure the main file is loaded first - uint16_t idx = 0; - const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); - get_key(kv_split_no, idx); - if (idx != 0) { - throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); + metadata_ptr.reset(gguf_init_from_file(fname.c_str(), params)); + metadata = metadata_ptr.get(); + if (metadata == nullptr) { + throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); } - // generate list of splits if needed - if (splits.empty()) { - splits = llama_get_list_splits(fname, idx, n_split); + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); + contexts.emplace_back(ctx); + + if (use_mmap && use_direct_io) { + if (files.back()->has_direct_io()) { + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + use_mmap = false; + } else { + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); + use_direct_io = false; + + // reopen file using std::fopen for mmap + files.pop_back(); + files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + } } - // in case user give a custom list of splits, check if it matches the expected number - if (n_split != (uint16_t)splits.size()) { - throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); + // Save tensors data offset of the main file. + // For subsidiary files, `meta` tensor data offset must not be used, + // so we build a unified tensors index for weights. 
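The constructor now resolves the mmap versus direct-I/O conflict explicitly: when both are requested, direct I/O wins if the file actually supports it, otherwise the loader falls back to mmap and reopens the file with buffered stdio. The decision reduces to something like the sketch below (a hypothetical free function; the real code mutates the loader's flags and file list in place, and the sketch only models the conflicting case handled above):

struct io_mode {
    bool use_mmap;
    bool use_direct_io;
};

io_mode resolve_io_mode(bool want_mmap, bool want_direct_io, bool direct_io_available) {
    if (want_mmap && want_direct_io) {
        if (direct_io_available) {
            return { /*use_mmap=*/false, /*use_direct_io=*/true };  // direct I/O wins, mmap disabled
        }
        return { /*use_mmap=*/true, /*use_direct_io=*/false };      // reopen buffered, use mmap instead
    }
    return { want_mmap, want_direct_io };  // non-conflicting requests pass through unchanged
}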
+ for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, metadata, cur)); } + uint16_t n_split = 0; + get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); - if (trace > 0) { - LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); - } - - // load other splits - for (idx = 1; idx < n_split; idx++) { - const char * fname_split = splits[idx].c_str(); - - struct gguf_init_params split_params = { - /*.no_alloc = */ true, - /*.ctx = */ &ctx, - }; - gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; - if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); + // Load additional GGML contexts + if (n_split > 1) { + // make sure the main file is loaded first + uint16_t idx = 0; + const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); + get_key(kv_split_no, idx); + if (idx != 0) { + throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); } - // check idx + // generate list of splits if needed + if (splits.empty()) { + splits = llama_get_list_splits(fname, idx, n_split); + } + + // in case user give a custom list of splits, check if it matches the expected number + if (n_split != (uint16_t)splits.size()) { + throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); + } + + if (trace > 0) { + LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); + } + + // load other splits + for (idx = 1; idx < n_split; idx++) { + const char * fname_split = splits[idx].c_str(); + + struct gguf_init_params split_params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); + } + + // check idx + { + const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); + if (kid < 0) { + throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + } + int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); + if (idx_gguf != idx) { + throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + } + } + + files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the shard. 
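As the comments above note, tensor offsets from the per-split GGUF metadata are folded into a single index keyed by tensor name, with duplicate names rejected and element/byte counts accumulated. A stripped-down sketch of that bookkeeping follows; the value type is simplified, whereas the real map stores llama_tensor_weight entries carrying the file handle and data offset:

#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>

struct tensor_entry {
    uint16_t split_idx; // which split file the tensor lives in
    size_t   offset;    // data offset within that file
    size_t   nbytes;    // tensor size in bytes
};

struct tensor_index {
    std::map<std::string, tensor_entry> by_name;
    size_t total_bytes = 0;

    void add(const std::string & name, const tensor_entry & e) {
        if (by_name.count(name) > 0) {
            throw std::runtime_error("invalid model: tensor '" + name + "' is duplicated");
        }
        total_bytes += e.nbytes;
        by_name.emplace(name, e);
    }
};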
+ for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); + } + } + + get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); + + // sanity check { - const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); - if (kid < 0) { - throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); - } - int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); - if (idx_gguf != idx) { - throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + const int n_tensors_loaded = (int) weights_map.size(); + if (n_tensors != n_tensors_loaded) { + throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); } } - files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); - contexts.emplace_back(ctx); - - // Save tensors data offset info of the shard. - for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - std::string tensor_name = std::string(cur->name); - // make sure there is no duplicated tensor names - if (weights_map.find(tensor_name) != weights_map.end()) { - throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); - } - n_elements += ggml_nelements(cur); - n_bytes += ggml_nbytes(cur); - weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); - } + LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); } - - get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); - - // sanity check - { - const int n_tensors_loaded = (int) weights_map.size(); - if (n_tensors != n_tensors_loaded) { - throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); - } - } - - LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); + } else { + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); } - n_kv = gguf_get_n_kv(meta.get()); + n_kv = gguf_get_n_kv(metadata); n_tensors = weights_map.size(); - fver = (enum llama_fver) gguf_get_version(meta.get()); + fver = (enum llama_fver) gguf_get_version(metadata); LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); @@ -695,6 +725,7 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_NVFP4: ftype = LLAMA_FTYPE_MOSTLY_NVFP4; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -715,14 +746,14 @@ llama_model_loader::llama_model_loader( LLAMA_LOG_INFO("%s: Dumping metadata keys/values. 
Note: KV overrides do not apply in this output.\n", __func__); for (int i = 0; i < n_kv; i++) { - const char * name = gguf_get_key(meta.get(), i); - const enum gguf_type type = gguf_get_kv_type(meta.get(), i); + const char * name = gguf_get_key(metadata, i); + const enum gguf_type type = gguf_get_kv_type(metadata, i); const std::string type_name = type == GGUF_TYPE_ARRAY - ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(metadata, i)), gguf_get_arr_n(metadata, i)) : gguf_type_name(type); - std::string value = gguf_kv_to_str(meta.get(), i); + std::string value = gguf_kv_to_str(metadata, i); const size_t MAX_VALUE_LEN = 40; if (value.size() > MAX_VALUE_LEN) { value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); @@ -824,15 +855,382 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri return cur; } -struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { - LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str()); - const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + const int n_expert_used = hparams.n_expert_used; + GGML_ASSERT(n_expert_used > 0); + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_ADD_ID: + { + const int n_expert_used = hparams.n_expert_used; + GGML_ASSERT(n_expert_used > 0); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_add_id(ctx, a, w, c); + } break; + case GGML_OP_MUL: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + const int n_embd_head = 
hparams.n_embd_head_v(); + const int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 123; + const int64_t H = 123; + const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd_inp = hparams.n_embd_inp(); + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + case GGML_OP_SCALE: + { + op_tensor = ggml_scale(ctx, w, 1.0f); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + + return op_supported; +} + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t * buft_list) { + GGML_ASSERT(!buft_list->empty()); + for (const auto & cur : *buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { + return cur_buft; + } + } + + return nullptr; +} + +struct ggml_tensor * 
llama_model_loader::create_tensor( + const llama_hparams & hparams, const buft_list_t * buft_list_cpu, const buft_list_t * buft_list_input, const buft_list_t * buft_list_output, + const buft_list_t * buft_list_layer, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) { + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // one ggml context per buffer type + int max_n_tensors = n_tensors; + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += hparams.n_layer*2; // duplicated rope freq tensors + if (files.empty()) { + max_n_tensors += hparams.n_layer*256; // this should be well above what any model actually uses + } + const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; + + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map.emplace(buft, ctx); + + return ctx; + } + return it->second.get(); + }; + + auto buft_for_tensor = [&](ggml_tensor * t_meta) -> ggml_backend_buffer_type_t { + if (!t_meta) { + if (flags & TENSOR_NOT_REQUIRED) { + return nullptr; + } + throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + } + + // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops + // the tensor is duplicated + // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor + llm_tensor tn_tensor = tn.tensor; + if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && (flags & TENSOR_DUPLICATED)) { + tn_tensor = LLM_TENSOR_OUTPUT; + } + + llm_tensor_info info; + try { + info = llm_tensor_info_for(tn_tensor); + } catch (const std::out_of_range & e) { + throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); + } + + // skip unused tensors + if (info.op == GGML_OP_NONE || (flags & TENSOR_SKIP)) { + const size_t nbytes = ggml_nbytes(t_meta); + LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); + + size_data -= nbytes; + n_created++; + + return nullptr; + } + + // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID + ggml_op op; + bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; + if (bias) { + if (info.op == GGML_OP_MUL_MAT_ID) { + op = GGML_OP_ADD_ID; + } else { + op = GGML_OP_ADD; + } + } else { + op = info.op; + } + + // sanity checks + if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (tn.bid != -1) { + GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); + } + } else { + if (tn.bid == -1) { + GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); + } + } + + // select the buffer type for this tensor + const buft_list_t * buft_list; + switch (info.layer) { + case LLM_TENSOR_LAYER_INPUT: + buft_list = buft_list_input; + break; + case LLM_TENSOR_LAYER_OUTPUT: + buft_list = buft_list_output; + break; + case LLM_TENSOR_LAYER_REPEATING: + GGML_ASSERT(buft_list_layer != nullptr); + buft_list = buft_list_layer; + break; + default: + GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); + } + + ggml_backend_buffer_type_t buft = nullptr; + + // check overrides + if (tensor_buft_overrides) { + 
std::string tensor_name = tn.str(); + for (const auto * overrides = tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { + std::regex pattern(overrides->pattern); + if (std::regex_search(tensor_name, pattern)) { + if (overrides->buft == ggml_backend_cpu_buffer_type()) { + // when overriding to a CPU buffer, consider the extra buffer types + buft = select_weight_buft(hparams, t_meta, op, buft_list_cpu); + } else { + buft = overrides->buft; + } + + LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", + tensor_name.c_str(), + ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), + ggml_backend_buft_name(buft)); + break; + } + } + } + + if (!buft) { + buft = select_weight_buft(hparams, t_meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } + } + + // avoid using a host buffer when using mmap + auto * buft_dev = ggml_backend_buft_get_device(buft); + if (use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error("no CPU backend found"); + } + buft = ggml_backend_dev_buffer_type(cpu_dev); + } + + if (buft != buft_list->front().second) { + if (n_tensors_moved == 0) { + first_tensor_moved_name = t_meta->name; + first_tensor_moved_type_name = ggml_type_name(t_meta->type); + first_moved_from_buft = buft_list->front().second; + first_moved_to_buft = buft; + } + n_tensors_moved++; + } + + return buft; + }; + + if (files.empty()) { + if (flags & TENSOR_SKIP_IF_VIRTUAL) { + return nullptr; + } + ggml_type type = GGML_TYPE_F32; + const int64_t tid = gguf_find_tensor(metadata, tn.str().c_str()); + if (tid != -1) { + type = gguf_get_tensor_type(metadata, tid); + } + + // for tensors that are not required some of the dimensions can be invalid: + if (flags & TENSOR_NOT_REQUIRED) { + for (size_t dim = 0; dim < ne.size(); dim++) { + if (ne.begin()[dim] <= 0) { + return nullptr; + } + } + } + + ggml_tensor t_meta; + memset(&t_meta, 0, sizeof(ggml_tensor)); + t_meta.type = type; + for (size_t dim = 0; dim < GGML_MAX_DIMS; dim++) { + t_meta.ne[dim] = dim < ne.size() ? ne.begin()[dim] : 1; + GGML_ASSERT(t_meta.ne[dim] >= 1); + t_meta.nb[dim] = dim == 0 ? 
ggml_type_size(type) : t_meta.ne[dim-1]*t_meta.nb[dim-1]; + GGML_ASSERT(t_meta.nb[dim] >= 1); + } + ggml_set_name(&t_meta, tn.str().c_str()); + + ggml_backend_buffer_type_t buft = buft_for_tensor(&t_meta); + GGML_ASSERT(buft != nullptr); + ggml_context * ctx = ctx_for_buft(buft); + ggml_tensor * ret = ggml_dup_tensor(ctx, &t_meta); + ggml_set_name(ret, tn.str().c_str()); + return ret; + } + + ggml_tensor * t_meta = get_tensor_meta(tn.str().c_str()); + ggml_backend_buffer_type_t buft = buft_for_tensor(t_meta); + if (buft == nullptr) { + return nullptr; // return type is ggml_tensor * + } + ggml_context * ctx = ctx_for_buft(buft); + + // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one + if (flags & TENSOR_DUPLICATED) { + ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); + if (t) { + return t; + } + } + + LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, tn.str().c_str()); + const struct ggml_tensor * cur = check_tensor_dims(tn.str(), ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { return NULL; } - bool duplicated = flags & TENSOR_DUPLICATED; + const bool duplicated = flags & TENSOR_DUPLICATED; struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); ggml_set_name(tensor, ggml_get_name(cur)); @@ -844,7 +1242,6 @@ struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx } return tensor; - } struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required) { @@ -879,6 +1276,11 @@ void llama_model_loader::done_getting_tensors() const { if (n_created != n_tensors) { throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); } + if (n_tensors_moved > 0) { + LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", + __func__, first_tensor_moved_name.c_str(), first_tensor_moved_type_name.c_str(), n_tensors_moved - 1, + ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); + } } void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps) { @@ -960,6 +1362,12 @@ bool llama_model_loader::load_all_data( llama_mlocks * lmlocks, llama_progress_callback progress_callback, void * progress_callback_user_data) { + if (files.empty()) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + set_tensor_data(t, set_tensor_data_ud); + } + return true; + } GGML_ASSERT(size_data != 0 && "call init_mappings() first"); std::vector> read_buf; diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 65953dd3..ed5de729 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -4,17 +4,22 @@ #include "llama-impl.h" #include "llama-arch.h" +#include "llama-hparams.h" #include "llama-mmap.h" #include "ggml-cpp.h" #include +#include #include #include #include using llama_buf_map = std::unordered_map; +// lists of buffer types used for each layer +using buft_list_t = std::vector>; + enum llama_fver { GGUF_FILE_VERSION_V1 = 1, GGUF_FILE_VERSION_V2 = 2, @@ -58,9 +63,10 @@ struct llama_model_loader { } }; - static const int TENSOR_NOT_REQUIRED = 1 << 0; - static const int TENSOR_DUPLICATED = 1 << 1; - static const int TENSOR_SKIP = 1 << 2; + static const int 
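A notable addition in this sync is the file-less path: when the loader is constructed with an external gguf_context and no file name, create_tensor synthesizes tensor metadata from the GGUF (or defaults to F32), and load_all_data skips disk I/O entirely, handing every tensor to the user-supplied set_tensor_data callback. The sketch below shows what such a callback could look like; the zero-fill behaviour is purely illustrative, and the signature is an assumption inferred from how load_all_data invokes the callback, not a documented API contract:

#include "ggml.h"
#include "ggml-backend.h"
#include <cstdint>
#include <vector>

// Illustrative callback: one call per tensor, with the user pointer passed through.
// A real callback would copy externally provided weights instead of zeros.
static void zero_fill_tensor(ggml_tensor * t, void * /*ud*/) {
    std::vector<uint8_t> zeros(ggml_nbytes(t), 0);
    // assumes the tensor has already been placed in an allocated backend buffer
    ggml_backend_tensor_set(t, zeros.data(), 0, zeros.size());
}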
TENSOR_NOT_REQUIRED = 1 << 0; + static const int TENSOR_DUPLICATED = 1 << 1; + static const int TENSOR_SKIP = 1 << 2; + static const int TENSOR_SKIP_IF_VIRTUAL = 1 << 3; int n_kv = 0; int n_tensors = 0; @@ -84,7 +90,10 @@ struct llama_model_loader { std::unordered_map kv_overrides; const llama_model_tensor_buft_override * tensor_buft_overrides; - gguf_context_ptr meta; + gguf_context_ptr metadata_ptr; + struct gguf_context * metadata; // either metadata_ptr.get() or externally set + llama_model_set_tensor_data_t set_tensor_data; + void * set_tensor_data_ud; std::vector contexts; std::string arch_name; @@ -94,7 +103,26 @@ struct llama_model_loader { size_t size_data = 0; std::vector> mmaps_used; + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + + std::map ctx_map; + + // track tensors that had to be moved for debugging: + size_t n_tensors_moved = 0; + std::string first_tensor_moved_name; + std::string first_tensor_moved_type_name; + ggml_backend_buffer_type_t first_moved_from_buft = nullptr; + ggml_backend_buffer_type_t first_moved_to_buft = nullptr; + llama_model_loader( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, const std::string & fname, std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, @@ -149,7 +177,9 @@ struct llama_model_loader { const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const; - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags = 0); + struct ggml_tensor * create_tensor( + const llama_hparams & hparams, const buft_list_t * buft_list_cpu, const buft_list_t * buft_list_input, const buft_list_t * buft_list_output, + const buft_list_t * buft_list_layer, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags); struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index ae27c71c..6f6538ae 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -7,14 +7,19 @@ #include "llama-model.h" #include "llama-vocab.h" +#include #include -llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) { - gguf_ctx = gguf_init_empty(); -} +llama_model_saver::llama_model_saver(const struct llama_model * model) : + gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) {} + +llama_model_saver::llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx) : + gguf_ctx(gguf_ctx == nullptr ? 
gguf_init_empty() : gguf_ctx), gguf_ctx_owned(gguf_ctx == nullptr), model(nullptr), llm_kv(arch) {} llama_model_saver::~llama_model_saver() { - gguf_free(gguf_ctx); + if (gguf_ctx_owned) { + gguf_free(gguf_ctx); + } } void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) { @@ -46,7 +51,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const char value) { template void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { - const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size(); + GGML_ASSERT(model != nullptr || !per_layer); + const size_t n_values = per_layer ? size_t(model->hparams.n_layer) : value.size(); GGML_ASSERT(n_values <= value.size()); if (n_values == 0) { @@ -83,6 +89,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c GGML_ABORT("fatal error"); } } +// instantiate for external usage: +template void llama_model_saver::add_kv>(const enum llm_kv, const std::vector &, const bool); void llama_model_saver::add_kv(const enum llm_kv key, const std::vector & value) { std::vector tmp(value.size()); @@ -104,37 +112,39 @@ void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { } void llama_model_saver::add_kv_from_model() { - const llama_hparams & hparams = model.hparams; - const llama_vocab & vocab = model.vocab; + const llama_hparams & hparams = model->hparams; + const llama_vocab & vocab = model->vocab; const int32_t n_vocab = vocab.n_tokens(); std::vector tokens(n_vocab); std::vector scores(n_vocab); std::vector token_types(n_vocab); - for (int32_t id = 0; id < n_vocab; ++id) { - const llama_vocab::token_data & token_data = vocab.get_token_data(id); + if (vocab.get_type() != LLAMA_VOCAB_TYPE_NONE) { + for (int32_t id = 0; id < n_vocab; ++id) { + const llama_vocab::token_data & token_data = vocab.get_token_data(id); - tokens[id] = token_data.text; - scores[id] = token_data.score; + tokens[id] = token_data.text; + scores[id] = token_data.score; - switch(token_data.attr) { - case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; - case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; - case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; - case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; - case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; - case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; - case LLAMA_TOKEN_ATTR_UNDEFINED: - default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + switch(token_data.attr) { + case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; + case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; + case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; + case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; + case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; + case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + case LLAMA_TOKEN_ATTR_UNDEFINED: + default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + } } } // add_kv(LLM_KV_GENERAL_TYPE, ???); - add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name()); + add_kv(LLM_KV_GENERAL_ARCHITECTURE, model->arch_name()); // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); - 
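The "instantiate for external usage" line above exists because the add_kv template is defined in the .cpp file: code in other translation units can only link against specializations that are explicitly instantiated there. A generic illustration of the pattern, with toy names unrelated to the llama_model_saver API:

// widget.h declares the template only:
//   template <typename T> void store(const T & value);

// widget.cpp defines it and instantiates the specializations other TUs need:
#include <string>
#include <vector>

template <typename T> void store(const T & value) {
    (void) value; // definition lives in this .cpp, invisible to other translation units
}

// Explicit instantiation: emits the symbol so callers elsewhere can link against
// store<std::vector<std::string>> without seeing the definition.
template void store<std::vector<std::string>>(const std::vector<std::string> &);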
add_kv(LLM_KV_GENERAL_NAME, model.name); + add_kv(LLM_KV_GENERAL_NAME, model->name); // add_kv(LLM_KV_GENERAL_AUTHOR, ???); // add_kv(LLM_KV_GENERAL_VERSION, ???); // add_kv(LLM_KV_GENERAL_URL, ???); @@ -146,8 +156,8 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - if (hparams.n_embd_out > 0) { - add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out); + if (hparams.n_embd_out_impl > 0) { + add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); } add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); @@ -176,8 +186,10 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); - add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k); - add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); @@ -189,7 +201,8 @@ void llama_model_saver::add_kv_from_model() { const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; - add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full); + add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa); add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); @@ -255,24 +268,25 @@ void llama_model_saver::add_kv_from_model() { } void llama_model_saver::add_tensors_from_model() { - if (std::string(model.output->name) != std::string(model.tok_embd->name)) { - add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output + if (std::string(model->output->name) != std::string(model->tok_embd->name)) { + add_tensor(model->tok_embd); // some models use the same tensor for tok_embd and output } - add_tensor(model.type_embd); - add_tensor(model.pos_embd); - add_tensor(model.tok_norm); - add_tensor(model.tok_norm_b); - add_tensor(model.output_norm); - add_tensor(model.output_norm_b); - add_tensor(model.output); - add_tensor(model.output_b); - add_tensor(model.output_norm_enc); - add_tensor(model.cls); - add_tensor(model.cls_b); - add_tensor(model.cls_out); - add_tensor(model.cls_out_b); + add_tensor(model->type_embd); + add_tensor(model->pos_embd); + add_tensor(model->tok_norm); + add_tensor(model->tok_norm_b); + add_tensor(model->output_norm); + add_tensor(model->output_norm_b); + add_tensor(model->output); + add_tensor(model->output_b); + add_tensor(model->output_norm_enc); + add_tensor(model->cls); + add_tensor(model->cls_b); + add_tensor(model->cls_out); + add_tensor(model->cls_out_b); + add_tensor(model->cls_norm); - for (const struct llama_layer & layer : model.layers) { + for (const struct llama_layer & layer : 
model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { add_tensor(reinterpret_cast(&layer)[i]); } diff --git a/examples/talk-llama/llama-model-saver.h b/examples/talk-llama/llama-model-saver.h index a5a434c3..2b3541ce 100644 --- a/examples/talk-llama/llama-model-saver.h +++ b/examples/talk-llama/llama-model-saver.h @@ -1,5 +1,6 @@ #pragma once +#include "gguf.h" #include "llama.h" #include "llama-arch.h" @@ -7,10 +8,12 @@ struct llama_model_saver { struct gguf_context * gguf_ctx = nullptr; - const struct llama_model & model; + const bool gguf_ctx_owned; + const struct llama_model * model; const struct LLM_KV llm_kv; - llama_model_saver(const struct llama_model & model); + llama_model_saver(const struct llama_model * model); + llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx); ~llama_model_saver(); void add_kv(enum llm_kv key, uint32_t value); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index f6cea8f8..e8e1bbf1 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1,5 +1,6 @@ #include "llama-model.h" +#include "ggml.h" #include "llama-impl.h" #include "llama-mmap.h" #include "llama-cparams.h" @@ -8,6 +9,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include "ggml-cpp.h" @@ -17,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -60,6 +63,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; + case LLM_TYPE_0_8B: return "0.8B"; case LLM_TYPE_1B: return "1B"; case LLM_TYPE_1_2B: return "1.2B"; case LLM_TYPE_1_3B: return "1.3B"; @@ -122,17 +126,25 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; + case LLM_TYPE_24B_A2B: return "24B.A2B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_35B_A3B: return "35B.A3B"; + case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_120B_A12B: return "120B.A12B"; + case LLM_TYPE_122B_A10B: return "122B.A10B"; + case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_397B_A17B: return "397B.A17B"; + case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -168,160 +180,6 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } -// checks if the weight tensor can be used with the specified buffer type and device -static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { - GGML_ASSERT(w != nullptr); - - if (op == GGML_OP_NONE) { - return true; - } - - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead()*8, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ 
true, - }; - ggml_context_ptr ctx_ptr { ggml_init(params) }; - if (!ctx_ptr) { - throw std::runtime_error(format("failed to create ggml context")); - } - ggml_context * ctx = ctx_ptr.get(); - - ggml_tensor * op_tensor = nullptr; - - switch (op) { - case GGML_OP_GET_ROWS: - { - ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); - op_tensor = ggml_get_rows(ctx, w, b); - } break; - case GGML_OP_MUL_MAT: - { - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); - op_tensor = ggml_mul_mat(ctx, w, b); - } break; - case GGML_OP_MUL_MAT_ID: - { - int n_expert_used = hparams.n_expert_used; - ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); - ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); - op_tensor = ggml_mul_mat_id(ctx, w, b, ids); - } break; - case GGML_OP_ADD: - { - ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); - op_tensor = ggml_add(ctx, a, w); - } break; - case GGML_OP_ADD_ID: - { - int n_expert_used = hparams.n_expert_used; - ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); - ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); - op_tensor = ggml_add_id(ctx, a, w, c); - } break; - case GGML_OP_MUL: - { - ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); - op_tensor = ggml_mul(ctx, a, w); - } break; - case GGML_OP_DIV: - { - ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); - op_tensor = ggml_div(ctx, a, w); - } break; - case GGML_OP_ROPE: - { - int n_embd_head = hparams.n_embd_head_v; - int n_head = hparams.n_head(); - ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); - ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); - op_tensor = ggml_rope_ext( - ctx, a, b, w, - 0, 0, 0, 0, 0, - 0, 0, 0, 0 - ); - - } break; - case GGML_OP_SSM_CONV: - { - const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 3; - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); - op_tensor = ggml_ssm_conv(ctx, conv_x, w); - } break; - case GGML_OP_SSM_SCAN: - { - // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 - const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; - const int64_t n_head = w->ne[1]; - const int64_t head_dim = hparams.ssm_d_inner / n_head; - const int64_t n_group = hparams.ssm_n_group ? 
hparams.ssm_n_group : 1; - const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 3; - ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); - ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); - } break; - case GGML_OP_RWKV_WKV6: - { - // FIXME - const int64_t S = 123; - const int64_t H = 123; - const int64_t n_tokens = 123; - const int64_t n_seqs = 123; - ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * tf = w; - ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); - op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); - } break; - case GGML_OP_IM2COL: - { - const int n_embd_inp = hparams.n_embd_inp(); - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1); - op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); - } break; - case GGML_OP_SCALE: - { - op_tensor = ggml_scale(ctx, w, 1.0f); - } break; - default: - GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); - } - - // create a temporary dummy buffer for the weight so that supports_op can check the buffer type - GGML_ASSERT(w->buffer == nullptr); - w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); - bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - ggml_backend_buffer_free(w->buffer); - w->buffer = nullptr; - - return op_supported; -} - -// lists of buffer types used for each layer -using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>; - -// find the first buffer type in the list that can use the tensor -static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { - GGML_ASSERT(!buft_list.empty()); - for (const auto & cur : buft_list) { - ggml_backend_dev_t cur_dev = cur.first; - ggml_backend_buffer_type_t cur_buft = cur.second; - if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { - return cur_buft; - } - } - - return nullptr; -} - // CPU: ACCEL -> GPU host -> CPU extra -> CPU static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; @@ -446,7 +304,7 @@ struct llama_model::impl { llama_mlocks mlock_bufs; llama_mlocks mlock_mmaps; - // contexts where the model tensors metadata is stored as well ass the corresponding buffers: + // contexts where the model tensors metadata is stored as well as the corresponding buffers: std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs; buft_list_t cpu_buft_list; @@ -468,7 +326,11 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() = default; +llama_model::~llama_model() { + for (auto * lora : loras) { +
delete lora; + } +} void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; @@ -483,7 +345,7 @@ void llama_model::load_arch(llama_model_loader & ml) { } void llama_model::load_hparams(llama_model_loader & ml) { - const gguf_context * ctx = ml.meta.get(); + const gguf_context * ctx = ml.metadata; // get metadata as string for (int i = 0; i < gguf_get_n_kv(ctx); i++) { @@ -507,7 +369,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out, false); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); @@ -515,7 +377,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { - ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); @@ -554,6 +417,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); + std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -595,26 +460,37 @@ void llama_model::load_hparams(llama_model_loader & ml) { // gpt-neox n_rot = rotary_pct * (n_embd / n_head) // gpt-j n_rot = rotary_dim - hparams.n_embd_head_k = hparams.n_embd / hparams.n_head(); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + hparams.n_embd_head_k_full = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); - hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + hparams.n_embd_head_v_full = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); // sanity check for n_rot (optional) - hparams.n_rot = hparams.n_embd_head_k; + hparams.n_rot_full = hparams.n_embd_head_k_full; - ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full, false); if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) { - if (hparams.n_rot != hparams.n_embd_head_k) { - throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); + if (hparams.n_rot_full != hparams.n_embd_head_k_full) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot_full, hparams.n_embd_head_k_full)); } } } else { - hparams.n_rot = 0; - hparams.n_embd_head_k = 
0; - hparams.n_embd_head_v = 0; + hparams.n_rot_full = 0; + hparams.n_embd_head_k_full = 0; + hparams.n_embd_head_v_full = 0; + } + + // head size and n_rot for SWA layers + { + hparams.n_embd_head_k_swa = hparams.n_embd_head_k_full; + hparams.n_embd_head_v_swa = hparams.n_embd_head_v_full; + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa, false); + + hparams.n_rot_swa = hparams.n_rot_full; + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false); } // for differentiating model types @@ -674,7 +550,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_attn_temp_floor_scale = 8192; hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; - hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -711,7 +589,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_AFMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); @@ -723,7 +601,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) if (hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -868,7 +748,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { @@ -891,18 +771,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { - uint32_t swa_period = 3; hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 3; ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); + hparams.set_swa_pattern(swa_period, true); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch 
(hparams.n_layer) { @@ -918,7 +797,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; @@ -931,7 +810,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { @@ -944,8 +823,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); if (hparams.n_layer == 12 && hparams.n_embd == 768) { @@ -959,13 +838,23 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NEO_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 28) { type = LLM_TYPE_250M; } } break; + case LLM_ARCH_EUROBERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + if (hparams.n_layer == 12) { + type = LLM_TYPE_SMALL; // 0.2B + } + } break; case LLM_ARCH_BLOOM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -988,7 +877,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); switch (hparams.n_layer) { case 32: type = LLM_TYPE_7B; break; @@ -1237,19 +1126,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { break; default: type = LLM_TYPE_UNKNOWN; } - - // Load attention parameters - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; case LLM_ARCH_PLAMO3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { - uint32_t swa_period = 8; hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 8; ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); 
hparams.set_swa_pattern(swa_period); } else { @@ -1312,7 +1197,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; // default value of gemma 2 - hparams.set_swa_pattern(2); + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.attn_soft_cap = true; hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -1333,14 +1220,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 hparams.f_attention_scale = type == LLM_TYPE_27B ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); } break; case LLM_ARCH_GEMMA3: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(6); + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { @@ -1364,12 +1253,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 hparams.f_attention_scale = type == LLM_TYPE_27B ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); } break; case LLM_ARCH_GEMMA3N: { + uint32_t swa_period = 5; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(5); + hparams.set_swa_pattern(swa_period); hparams.n_layer_kv_from_start = 20; hparams.f_attention_scale = 1.0f; @@ -1387,14 +1278,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_GEMMA_EMBEDDING: { hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - hparams.set_swa_pattern(6); + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.causal_attn = false; // embeddings do not use causal attention ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); //applied only if model converted with --sentence-transformers-dense-modules ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); @@ -1409,7 +1302,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 24: type = LLM_TYPE_0_3B; break; default: type = LLM_TYPE_UNKNOWN; } - hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k())); } break; case LLM_ARCH_STARCODER2: @@ -1501,7 +1394,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - // TODO: Jamba layers are 
a bit heterogenous, so naming this is hard. + // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. case 12: // 900M 8x???M case 32: // 51B 16x?B default: type = LLM_TYPE_UNKNOWN; @@ -1519,7 +1412,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_COMMAND_R: { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { case 40: type = LLM_TYPE_35B; break; @@ -1529,7 +1422,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_COHERE2: { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -1571,7 +1466,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp @@ -1678,10 +1575,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); switch (hparams.n_ff_exp) { case 1408: type = LLM_TYPE_16B; break; @@ -1691,16 +1588,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, 
hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); @@ -1709,7 +1607,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } } if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { @@ -1726,6 +1629,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; + case 47: type = LLM_TYPE_30B_A3B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; default: type = LLM_TYPE_UNKNOWN; @@ -1765,7 +1669,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // NextN/MTP parameters (GLM-OCR) + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + switch (hparams.n_layer) { + case 17: type = LLM_TYPE_1B; break; // GLM-OCR case 40: type = LLM_TYPE_9B; break; case 61: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -1782,7 +1694,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); // Expert gating function (GLM-4.5 uses sigmoid) @@ -1804,6 +1716,50 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM_DSA: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + 
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 79: type = LLM_TYPE_744B_A40B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1857,7 +1813,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JAIS: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); switch (hparams.n_layer) { case 24: type = LLM_TYPE_1_3B; break; @@ -1866,6 +1822,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JAIS2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + case 68: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_NEMOTRON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1896,10 +1862,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); switch (hparams.n_layer) { case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; + case 88: type = LLM_TYPE_120B_A12B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -1917,7 +1885,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.n_layer == 64) { // 32B hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -1933,6 +1903,36 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EXAONE_MOE: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 128; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + 
hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_30B_A3B; break; + case 48: + case 49: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: { @@ -2006,9 +2006,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); // Granite uses rope_finetuned as a switch for rope, so default to true bool rope_finetuned = true; @@ -2066,7 +2066,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default - ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); switch (hparams.n_layer) { case 32: type = LLM_TYPE_7B; break; @@ -2079,15 +2079,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); } break; case LLM_ARCH_BAILINGMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); switch (hparams.n_layer) { @@ -2099,11 +2099,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_BAILINGMOE2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, 
hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); @@ -2122,10 +2122,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DOTS1: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); switch (hparams.n_layer) { @@ -2135,13 +2135,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_PADDLEOCR: { + // paddleocr need mrope_section + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (arch == LLM_ARCH_ERNIE4_5_MOE) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); } switch (hparams.n_layer) { @@ -2186,7 +2190,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); switch (hparams.n_layer) { case 32: type = LLM_TYPE_A13B; break; @@ -2222,7 +2226,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(2); + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -2249,12 +2255,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 10752: type = LLM_TYPE_2_6B; break; default: type = LLM_TYPE_UNKNOWN; } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, 
hparams.n_swa, false); is_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + } + } } break; case LLM_ARCH_LFM2MOE: { ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); @@ -2262,16 +2274,22 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } - type = LLM_TYPE_8B_A1B; + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_8B_A1B; break; + case 40: type = LLM_TYPE_24B_A2B; break; + default: type = LLM_TYPE_UNKNOWN; + } } break; case LLM_ARCH_SMALLTHINKER: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - hparams.set_swa_pattern(4, true); + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -2294,7 +2312,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_GROVEMOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp); + ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp, false); ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2359,8 +2377,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } } switch (hparams.n_layer) { @@ -2368,6 +2390,65 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; 
+ ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; + case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; + case 64: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN35MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_122B_A10B; break; + case 60: type = LLM_TYPE_397B_A17B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_MISTRAL3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2402,7 +2483,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); switch (hparams.n_layer) { @@ -2410,7 +2491,69 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; - default: throw std::runtime_error("unsupported model architecture"); + case LLM_ARCH_KIMI_LINEAR: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + 
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_STEP35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // full_attention layer only use half of the RoPE dimensions + hparams.n_rot_full = hparams.n_rot_full / 2; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + + switch (hparams.n_layer) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + default: throw std::runtime_error("unsupported model architecture: " + arch_name()); } pimpl->n_bytes = ml.n_bytes; @@ -2517,44 +2660,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // assign the output layer pimpl->dev_output = get_layer_buft_list(n_layer); - // one ggml context per buffer type - int max_n_tensors = ml.n_tensors; - max_n_tensors += 1; // duplicated output tensor - max_n_tensors += n_layer*2; // duplicated rope freq tensors - const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; - - // define a comparator for the buft -> ctx map to ensure that the order is well-defined: - struct ggml_backend_buft_comparator { - bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { - return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; - } - }; - std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map; - - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - throw std::runtime_error(format("failed to create ggml context")); - } - - ctx_map.emplace(buft, ctx); - - return ctx; - } - return it->second.get(); - }; - - const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; - const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; - const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; + const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + const auto
TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; + const auto TENSOR_SKIP_IF_VIRTUAL = llama_model_loader::TENSOR_SKIP_IF_VIRTUAL; // create tensors for the weights { @@ -2564,13 +2673,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd = hparams.n_embd; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_head_k = hparams.n_embd_head_k(); + const int64_t n_embd_head_v = hparams.n_embd_head_v(); const int64_t n_ff = hparams.n_ff(); const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = vocab.n_tokens(); const int64_t n_token_types = vocab.n_token_types(); - const int64_t n_rot = hparams.n_rot; + const int64_t n_rot = hparams.n_rot(); const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; @@ -2579,153 +2688,26 @@ bool llama_model::load_tensors(llama_model_loader & ml) { throw std::runtime_error("model has expert layers but no expert layers are used"); } - int n_moved_tensors = 0; - ggml_tensor * first_moved_tensor = nullptr; - ggml_backend_buffer_type_t first_moved_from_buft = nullptr; - ggml_backend_buffer_type_t first_moved_to_buft = nullptr; - auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { - ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); - - if (!t_meta) { - if (flags & TENSOR_NOT_REQUIRED) { - return nullptr; - } - throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); - } - - // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops - // the tensor is duplicated - // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor - llm_tensor tn_tensor = tn.tensor; - if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED) { - tn_tensor = LLM_TENSOR_OUTPUT; - } - - llm_tensor_info info; - try { - info = llm_tensor_info_for(tn_tensor); - } catch (const std::out_of_range & e) { - throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); - } - - // skip unused tensors - if (info.op == GGML_OP_NONE || flags & TENSOR_SKIP) { - const size_t nbytes = ggml_nbytes(t_meta); - LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); - - ml.size_data -= nbytes; - ml.n_created++; - - return nullptr; - } - - // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID - ggml_op op; - bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; - if (bias) { - if (info.op == GGML_OP_MUL_MAT_ID) { - op = GGML_OP_ADD_ID; - } else { - op = GGML_OP_ADD; - } - } else { - op = info.op; - } - - // sanity checks - if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { - if (tn.bid != -1) { - GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); - } - } else { - if (tn.bid == -1) { - GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); - } - } - - // select the buffer type for this tensor - buft_list_t * buft_list; - switch (info.layer) { - case LLM_TENSOR_LAYER_INPUT: - buft_list = pimpl->dev_input.buft_list; - break; - case LLM_TENSOR_LAYER_OUTPUT: - buft_list = 
pimpl->dev_output.buft_list; - break; - case LLM_TENSOR_LAYER_REPEATING: - buft_list = pimpl->dev_layer.at(tn.bid).buft_list; - break; - default: - GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); - } - - ggml_backend_buffer_type_t buft = nullptr; - - // check overrides - if (ml.tensor_buft_overrides) { - std::string tensor_name = tn.str(); - for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { - std::regex pattern(overrides->pattern); - if (std::regex_search(tensor_name, pattern)) { - if (overrides->buft == ggml_backend_cpu_buffer_type()) { - // when overriding to a CPU buffer, consider the extra buffer types - buft = select_weight_buft(hparams, t_meta, op, pimpl->cpu_buft_list); - } else { - buft = overrides->buft; - } - - LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", - tensor_name.c_str(), - ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), - ggml_backend_buft_name(buft)); - break; - } - } - } - - if (!buft) { - buft = select_weight_buft(hparams, t_meta, op, *buft_list); - if (!buft) { - throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); - } - } - - // avoid using a host buffer when using mmap - auto * buft_dev = ggml_backend_buft_get_device(buft); - if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (!cpu_dev) { - throw std::runtime_error("no CPU backend found"); - } - buft = ggml_backend_dev_buffer_type(cpu_dev); - } - - if (buft != buft_list->front().second) { - n_moved_tensors++; - if (!first_moved_tensor) { - first_moved_tensor = t_meta; - first_moved_from_buft = buft_list->front().second; - first_moved_to_buft = buft; - } - } - - ggml_context * ctx = ctx_for_buft(buft); - - // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one - if (flags & TENSOR_DUPLICATED) { - ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); - if (t) { - return t; - } - } - return ml.create_tensor(ctx, tn, ne, flags); + const buft_list_t * buft_list_layer = tn.bid == -1 ? 
nullptr : pimpl->dev_layer.at(tn.bid).buft_list; + return ml.create_tensor( + hparams, &pimpl->cpu_buft_list, pimpl->dev_input.buft_list, pimpl->dev_output.buft_list, buft_list_layer, + tn, ne, flags); }; layers.resize(n_layer); // TODO: move to a separate function const auto tn = LLM_TN(arch); + + // helper: try merged gate_up_exps first, fall back to separate gate and up + auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + } + }; switch (arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: @@ -2879,6 +2861,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_LLAMA4: { + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -2891,7 +2876,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { - bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; + const bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; auto & layer = layers[i]; @@ -2907,7 +2892,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); if (is_moe_layer) { - int n_ff_exp = hparams.n_ff_exp; + const int64_t n_ff_exp = hparams.n_ff_exp; layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); @@ -2994,8 +2979,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_MINICPM3: { - const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); const int64_t q_lora_rank = hparams.n_lora_q; const int64_t kv_lora_rank = hparams.n_lora_kv; @@ -3038,7 +3023,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_GROK: { if (n_expert == 0) { - throw std::runtime_error("Grok model cannot have zero experts"); + throw std::runtime_error(arch_name() + " model cannot have zero experts"); } tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3210,6 +3195,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_JINA_BERT_V3: { + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); @@ -3294,9 +3282,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); } - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_NEO_BERT: @@ -3325,6 +3314,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); } } break; + case LLM_ARCH_EUROBERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = 
create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; case LLM_ARCH_JINA_BERT_V2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings @@ -3452,8 +3464,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + // FIXME test-llama-archs crashes if q_norm is created + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); @@ -3839,8 +3852,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); // attention parameters - const uint32_t qk_dim = hparams.n_embd_head_k; - const uint32_t v_dim = hparams.n_embd_head_v; + const uint32_t qk_dim = hparams.n_embd_head_k(); + const uint32_t v_dim = hparams.n_embd_head_v(); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3900,8 +3913,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLAMO3: { - const int64_t head_dim_q = hparams.n_embd_head_k; - const int64_t head_dim_v = hparams.n_embd_head_v; + const int64_t head_dim_q = hparams.n_embd_head_k(); + const int64_t head_dim_v = hparams.n_embd_head_v(); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4648,7 +4661,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_SEED_OSS: { - const uint32_t head_dim = hparams.n_embd_head_k; + const uint32_t head_dim = hparams.n_embd_head_k(); const int64_t n_qo_dim = n_head * head_dim; const int64_t n_kv_dim = n_head_kv * head_dim; @@ -4871,17 +4884,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - - const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v_mla = is_mla ? 
hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + GGML_ASSERT(n_embd_head_qk_nope >= 1); const int64_t q_lora_rank = hparams.n_lora_q; const int64_t kv_lora_rank = hparams.n_lora_kv; @@ -4903,13 +4914,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); } layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { @@ -4946,9 +4957,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared expert branch layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); @@ -4959,8 +4969,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLM: { - const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); const int64_t kv_lora_rank = hparams.n_lora_kv; tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5000,23 +5010,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wo_s = 
create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_scale = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_T5: @@ -5074,7 +5084,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); // this tensor seems to be unused in HF transformers implementation - layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + layer.attn_rel_b_cross = create_tensor( + tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); @@ -5152,6 +5163,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); } } break; + case LLM_ARCH_JAIS2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // attention biases - all have shape n_embd (output dimension of projections) + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.bo = 
create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + // Jais-2 uses simple MLP (no gate) with biases + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } + } break; case LLM_ARCH_CHATGLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5202,30 +5252,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; } - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + auto & layer = layers[i]; - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + } - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", 
i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } } } break; case LLM_ARCH_GLM4_MOE: @@ -5329,6 +5397,108 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_GLM_DSA: + { + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = 
create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, 
flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5377,6 +5547,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_ssm_head = hparams.ssm_dt_rank; const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd; // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5436,8 +5607,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); // MoE branch - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0); // Shared expert branch layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); @@ -5516,6 +5690,84 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_EXAONE_MOE: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? 
hparams.n_ff_shexp : n_ff_exp; + const int64_t head_dim = hparams.n_embd_head_k(); + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end + if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers)) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - 
hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_RWKV6: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5806,9 +6058,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_WAVTOKENIZER_DEC: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); // posnet @@ -5904,8 +6156,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); } - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); } break; case LLM_ARCH_BAILINGMOE: { @@ -6161,6 +6413,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_PADDLEOCR: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6303,6 +6556,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? 
hparams.n_ff_shexp : hparams.n_ff(i); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -6321,9 +6575,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); } } break; case LLM_ARCH_HUNYUAN_DENSE: @@ -6481,7 +6735,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -6637,6 +6892,141 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; + case LLM_ARCH_KIMI_LINEAR: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; + const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; + const int64_t ssm_d_conv = hparams.ssm_d_conv; + + if (hparams.is_recurrent(i)) { + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = 
create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a) { + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) + layer.wkv_b = 
create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), + {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (!layer.wkv_b) { // MLA KV cache enabled + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } + } break; case LLM_ARCH_COGVLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6718,6 +7108,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3NEXT: { + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -6746,6 +7140,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? 
hparams.n_ff_shexp : hparams.n_ff(i); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); @@ -6776,15 +7171,138 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared experts layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } + } break; + case LLM_ARCH_QWEN35MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + const int64_t n_ff_shexp = hparams.n_ff_shexp ? 
hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } + } break; + case LLM_ARCH_QWEN35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), 
{ n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; case LLM_ARCH_MIMO2: @@ -6825,6 +7343,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_STEP35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. + uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot(i)); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_MAINCODER: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6860,10 +7444,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) { throw std::runtime_error("unknown architecture"); } - if (n_moved_tensors > 0) { - LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n", - __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1, - ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); + // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. 
NVFP4 scale2) + // this avoids having to add scale loading to every architecture + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // attention weight scales (per-tensor, shape {1}) + if (!layer.wq_s && layer.wq) { + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_s && layer.wk) { + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_s && layer.wv) { + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_s && layer.wo) { + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_s && layer.wqkv) { + layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_s && layer.wqkv_gate) { + layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + + // dense FFN weight scales (per-tensor, shape {1}) + if (!layer.ffn_gate_s && layer.ffn_gate) { + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_s && layer.ffn_down) { + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_s && layer.ffn_up) { + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + + // MoE expert weight scales (per-expert, shape {n_expert}) + if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_s && layer.ffn_down_exps) { + layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { + layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + + // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_out_s && layer.ssm_out) { + layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_s && layer.ssm_alpha) { + layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_s && layer.ssm_beta) { + layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } } } @@ -6874,13 +7520,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // create the backend buffers std::vector> ctx_buf_maps; - ctx_buf_maps.reserve(ctx_map.size()); + ctx_buf_maps.reserve(ml.ctx_map.size()); // Ensure we have enough capacity for the maximum backend buffer we will potentially create - const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size(); + const size_t n_max_backend_buffer = ml.ctx_map.size() * 
ml.files.size(); pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - for (auto & [buft, ctx_ptr] : ctx_map) { + for (auto & [buft, ctx_ptr] : ml.ctx_map) { ggml_context * ctx = ctx_ptr.get(); // skip contexts without tensors @@ -7101,59 +7747,62 @@ void llama_model::print_info() const { }; // hparams - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); - LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); - LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); if (!hparams.vocab_only) { - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); - LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); - LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); - LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); - LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); - LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); - LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); - LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); - LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); - LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); - LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); - LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); - LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); - LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); - LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, 
hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); - LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); + LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa); + LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa); + LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa); } - LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); - LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul); - LLAMA_LOG_INFO("%s: 
rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { - LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); + LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); } if (!classifier_labels.empty()) { - LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; for (auto label : classifier_labels) { - LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } } @@ -7165,57 +7814,59 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); } // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - 
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } if (arch == LLM_ARCH_MINICPM || @@ -7223,41 +7874,41 @@ void llama_model::print_info() const { arch == LLM_ARCH_GRANITE_MOE || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: f_embedding_scale 
= %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); } if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, 
hparams.expert_group_scale); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); } vocab.print_info(); @@ -7372,6 +8023,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: + case LLM_ARCH_EUROBERT: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_MODERN_BERT: case LLM_ARCH_GEMMA_EMBEDDING: @@ -7396,7 +8048,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, nullptr); } else if (llm_arch_is_hybrid(arch)) { - // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -7413,23 +8064,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; @@ -7549,6 +8221,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_EUROBERT: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BLOOM: { llm = std::make_unique(*this, params); @@ -7748,6 +8424,7 @@ 
ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: { llm = std::make_unique(*this, params); } break; @@ -7790,6 +8467,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_JAIS2: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEMOTRON: { llm = std::make_unique(*this, params); @@ -7811,6 +8492,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique>(*this, params); } } break; + case LLM_ARCH_EXAONE_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_RWKV6: { llm = std::make_unique(*this, params); @@ -7881,6 +8566,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_PADDLEOCR: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_HUNYUAN_MOE: { llm = std::make_unique(*this, params); @@ -7904,7 +8593,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_SMALLTHINKER: { @@ -7938,6 +8631,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_QWEN35: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_QWEN35MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_MISTRAL3: { llm = std::make_unique(*this, params); @@ -7946,12 +8647,20 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_KIMI_LINEAR: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_STEP35: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm); // add backend sampling layers (if any) llm->build_sampling(); @@ -7960,7 +8669,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize - llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); llm->res->set_outputs(); @@ -7985,7 +8694,7 @@ llama_model_params llama_model_default_params() { /*.kv_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, - /*.use_direct_io =*/ true, + /*.use_direct_io =*/ false, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, @@ -8021,7 +8730,7 @@ int32_t llama_model_n_embd_inp(const llama_model * model) { } int32_t llama_model_n_embd_out(const llama_model * model) { - return model->hparams.get_n_embd_out(); + return model->hparams.n_embd_out(); } int32_t llama_model_n_layer(const llama_model * model) { @@ -8095,6 +8804,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case 
LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_KIMI_LINEAR: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -8128,6 +8838,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MISTRAL3: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: + case LLM_ARCH_GLM_DSA: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -8140,6 +8851,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_EUROBERT: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: @@ -8171,10 +8883,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE4: + case LLM_ARCH_EXAONE_MOE: case LLM_ARCH_MINICPM3: case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_JAIS2: case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: @@ -8189,12 +8903,16 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: + case LLM_ARCH_PADDLEOCR: return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 79200a0d..25bf892e 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -11,6 +11,7 @@ #include #include #include +#include #include struct llama_cparams; @@ -53,6 +54,7 @@ enum llm_type { LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, + LLM_TYPE_0_8B, LLM_TYPE_1B, LLM_TYPE_1_2B, LLM_TYPE_1_3B, @@ -115,17 +117,25 @@ enum llm_type { LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small + LLM_TYPE_24B_A2B, // lfm2moe LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_35B_A3B, // Qwen3.5 + LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_120B_A12B, // Nemotron 3 Super + LLM_TYPE_122B_A10B, // Qwen3.5 + LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_397B_A17B, // Qwen3.5 + LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -274,14 +284,25 @@ struct llama_layer { struct ggml_tensor * ffn_up_enc = nullptr; // ff MoE - struct ggml_tensor * ffn_gate_inp = nullptr; - struct ggml_tensor * ffn_gate_exps = nullptr; - struct ggml_tensor * ffn_down_exps = nullptr; - struct ggml_tensor * ffn_up_exps = nullptr; - struct ggml_tensor * ffn_gate_inp_b = nullptr; - struct ggml_tensor * ffn_gate_exps_b = nullptr; - struct ggml_tensor * ffn_down_exps_b = nullptr; - struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp_b = nullptr; + struct ggml_tensor * 
ffn_gate_exps_b = nullptr; + struct ggml_tensor * ffn_down_exps_b = nullptr; + struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_up_exps_b = nullptr; + + // ff MoE per-expert scales (NVFP4 per-tensor scale2) + struct ggml_tensor * ffn_gate_exps_s = nullptr; + struct ggml_tensor * ffn_down_exps_s = nullptr; + struct ggml_tensor * ffn_up_exps_s = nullptr; + + // ff MoE latent proj + struct ggml_tensor * ffn_latent_down = nullptr; + struct ggml_tensor * ffn_latent_up = nullptr; // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; @@ -319,6 +340,9 @@ struct llama_layer { // qwen3next struct ggml_tensor * ssm_beta_alpha = nullptr; + // qwen3.5 + struct ggml_tensor * ssm_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; @@ -373,13 +397,21 @@ struct llama_layer { struct ggml_tensor * rope_freqs = nullptr; // bitnet scale - struct ggml_tensor * wq_scale = nullptr; - struct ggml_tensor * wk_scale = nullptr; - struct ggml_tensor * wv_scale = nullptr; - struct ggml_tensor * wo_scale = nullptr; - struct ggml_tensor * ffn_gate_scale = nullptr; - struct ggml_tensor * ffn_up_scale = nullptr; - struct ggml_tensor * ffn_down_scale = nullptr; + struct ggml_tensor * wq_s = nullptr; + struct ggml_tensor * wk_s = nullptr; + struct ggml_tensor * wv_s = nullptr; + struct ggml_tensor * wo_s = nullptr; + struct ggml_tensor * wqkv_s = nullptr; + struct ggml_tensor * wqkv_gate_s = nullptr; + struct ggml_tensor * ffn_gate_s = nullptr; + struct ggml_tensor * ffn_up_s = nullptr; + struct ggml_tensor * ffn_down_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_s = nullptr; + struct ggml_tensor * ffn_up_shexp_s = nullptr; + struct ggml_tensor * ffn_down_shexp_s = nullptr; + struct ggml_tensor * ssm_out_s = nullptr; + struct ggml_tensor * ssm_alpha_s = nullptr; + struct ggml_tensor * ssm_beta_s = nullptr; // altup & laurel struct ggml_tensor * per_layer_inp_gate = nullptr; @@ -410,6 +442,25 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // Kimi Linear KDA (using ssm_ prefix for consistency) + // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias + struct ggml_tensor * ssm_q_conv = nullptr; + struct ggml_tensor * ssm_k_conv = nullptr; + struct ggml_tensor * ssm_v_conv = nullptr; + struct ggml_tensor * ssm_f_a = nullptr; + struct ggml_tensor * ssm_f_b = nullptr; + struct ggml_tensor * ssm_beta = nullptr; + struct ggml_tensor * ssm_g_a = nullptr; + struct ggml_tensor * ssm_g_b = nullptr; + struct ggml_tensor * ssm_o_norm = nullptr; + + // DSA (deepseek sparse attention) + struct ggml_tensor * indexer_k_norm = nullptr; + struct ggml_tensor * indexer_k_norm_b = nullptr; + struct ggml_tensor * indexer_proj = nullptr; + struct ggml_tensor * indexer_attn_k = nullptr; + struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -448,6 +499,7 @@ struct llama_model { struct ggml_tensor * cls_b = nullptr; struct ggml_tensor * cls_out = nullptr; struct ggml_tensor * cls_out_b = nullptr; + struct ggml_tensor * cls_norm = nullptr; struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; @@ -464,8 +516,9 @@ struct llama_model { //Dense linear projections for SentenceTransformers models like embeddinggemma // For Sentence Transformers models structure see // 
https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models - struct ggml_tensor * dense_2_out_layers = nullptr; - struct ggml_tensor * dense_3_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers_b = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; // gguf metadata std::unordered_map gguf_kv; @@ -476,8 +529,8 @@ struct llama_model { // for quantize-stats only std::vector> tensors_by_name; - // for keeping track of extra nodes used by lora adapters - uint32_t n_lora_nodes = 0; + // for keeping track of associated LoRA adapters + std::unordered_set loras; int64_t t_load_us = 0; int64_t t_start_us = 0; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 048d65a7..8e8ce231 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -1,11 +1,11 @@ -#include "llama-quant.h" +#include "llama.h" #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" -#include #include #include +#include #include #include #include @@ -13,10 +13,28 @@ #include #include -// Quantization types. Changes to this struct must be replicated in quantize.cpp -struct tensor_quantization { +// result of parsing --tensor-type option +// (changes to this struct must be reflected in tools/quantize/quantize.cpp) +struct tensor_type_option { std::string name; - ggml_type quant = GGML_TYPE_COUNT; + ggml_type type = GGML_TYPE_COUNT; +}; + +// tensor categorization - used to avoid repeated string matching in quantization logic. +// this is different from LLM_TN - we want broad categories, not specific tensor names per arch. +enum class tensor_category { + TOKEN_EMBD, + ATTENTION_Q, + ATTENTION_V, + ATTENTION_K, + ATTENTION_QKV, + ATTENTION_KV_B, + ATTENTION_OUTPUT, + FFN_UP, + FFN_GATE, + FFN_DOWN, + OUTPUT, + OTHER }; static void zeros(std::ofstream & file, size_t n) { @@ -54,7 +72,7 @@ static std::string remap_layer(const std::string & orig_name, const std::vector< return orig_name; } -static std::string remap_imatrix (const std::string & orig_name, const std::map & mapped) { +static std::string remap_imatrix(const std::string & orig_name, const std::map & mapped) { if (mapped.empty()) { return orig_name; } @@ -76,6 +94,73 @@ static std::string remap_imatrix (const std::string & orig_name, const std::map< return orig_name; } +// +// helper functions for tensor name matching +// + +static bool tensor_name_match_token_embd(const char * tensor_name) { + return std::strcmp(tensor_name, "token_embd.weight") == 0 || + std::strcmp(tensor_name, "per_layer_token_embd.weight") == 0; +} + +static bool tensor_name_match_output_weight(const char * tensor_name) { + return std::strcmp(tensor_name, "output.weight") == 0; +} + +// +// tensor categorization for quantization +// +// (this is different from LLM_TN - we want broad categories, not specific tensor names per arch) +// + +static tensor_category tensor_get_category(const std::string & tensor_name) { + if (tensor_name_match_output_weight(tensor_name.c_str())) { + return tensor_category::OUTPUT; + } + if (tensor_name_match_token_embd(tensor_name.c_str())) { + return tensor_category::TOKEN_EMBD; + } + if (tensor_name.find("attn_qkv.weight") != std::string::npos) { + return tensor_category::ATTENTION_QKV; + } + if (tensor_name.find("attn_kv_b.weight") != std::string::npos) { + return tensor_category::ATTENTION_KV_B; + } + if (tensor_name.find("attn_v.weight") != 
std::string::npos) { + return tensor_category::ATTENTION_V; + } + if (tensor_name.find("attn_k.weight") != std::string::npos) { + return tensor_category::ATTENTION_K; + } + if (tensor_name.find("attn_q.weight") != std::string::npos) { + return tensor_category::ATTENTION_Q; + } + if (tensor_name.find("attn_output.weight") != std::string::npos) { + return tensor_category::ATTENTION_OUTPUT; + } + if (tensor_name.find("ffn_up") != std::string::npos) { + return tensor_category::FFN_UP; + } + if (tensor_name.find("ffn_gate") != std::string::npos) { + return tensor_category::FFN_GATE; + } + if (tensor_name.find("ffn_down") != std::string::npos) { + return tensor_category::FFN_DOWN; + } + return tensor_category::OTHER; +} + +// check if category is for attention-v-like tensors (more sensitive to quantization) +static bool category_is_attn_v(tensor_category cat) { + return cat == tensor_category::ATTENTION_V || + cat == tensor_category::ATTENTION_QKV || + cat == tensor_category::ATTENTION_KV_B; +} + +// +// quantization state +// + struct quantize_state_impl { const llama_model & model; const llama_model_quantize_params * params; @@ -89,20 +174,42 @@ struct quantize_state_impl { int i_ffn_gate = 0; int i_ffn_up = 0; - int n_k_quantized = 0; int n_fallback = 0; bool has_imatrix = false; - // used to figure out if a model shares tok_embd with the output weight - bool has_output = false; + // used to figure out if a model has tied embeddings (tok_embd shares weights with output) + bool has_tied_embeddings = true; // assume tied until we see output.weight - quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) - : model(model) - , params(params) - {} + // tensor type override patterns (compiled once, used twice) + std::vector> tensor_type_patterns; + + quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params): + model(model), params(params) + { + // compile regex patterns once - they are expensive + if (params->tensor_types) { + const auto & tensor_types = *static_cast *>(params->tensor_types); + for (const auto & [tname, qtype] : tensor_types) { + tensor_type_patterns.emplace_back(std::regex(tname), qtype); + } + } + } }; +// per-tensor metadata, computed in the preliminary loop and used in the main loop +struct tensor_metadata { + ggml_type target_type; + tensor_category category; + std::string remapped_imatrix_name; + bool allows_quantization; + bool requires_imatrix; +}; + +// +// dequantization +// + static void llama_tensor_dequantize_impl( ggml_tensor * tensor, std::vector> & output, std::vector & workers, const size_t nelements, const int nthread @@ -175,12 +282,132 @@ static void llama_tensor_dequantize_impl( workers.clear(); } -static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { +// +// do we allow this tensor to be quantized? +// + +static bool tensor_allows_quantization(const llama_model_quantize_params * params, llm_arch arch, const ggml_tensor * tensor) { + // trivial checks first -- no string ops needed + if (params->only_copy) return false; + + // quantize only 2D and 3D tensors (experts) + if (ggml_n_dims(tensor) < 2) return false; + + const std::string name = ggml_get_name(tensor); + + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? 
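A minimal standalone sketch of the suffix test above (hypothetical helper, not part of the patch; the patch keeps the inline rfind() form to avoid the regex that the comment mentions):

    // returns true only when `name` ends with the literal 6-character suffix "weight",
    // e.g. "blk.0.attn_q.weight" -> true, "blk.0.attn_norm.bias" -> false
    static bool ends_with_weight(const std::string & name) {
        return name.size() >= 6 && name.compare(name.size() - 6, 6, "weight") == 0;
    }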
+ + // do not quantize norm tensors + quantize &= name.find("_norm.weight") == std::string::npos; + + quantize &= params->quantize_output_tensor || name != "output.weight"; + + // do not quantize expert gating tensors + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + + // these are very small (e.g. 4x4) + quantize &= name.find("altup") == std::string::npos; + quantize &= name.find("laurel") == std::string::npos; + + // these are not too big so keep them as it is + quantize &= name.find("per_layer_model_proj") == std::string::npos; + + // do not quantize positional embeddings and token types (BERT) + quantize &= name != LLM_TN(arch)(LLM_TENSOR_POS_EMBD, "weight"); + quantize &= name != LLM_TN(arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); + + // do not quantize Mamba/Kimi's small conv1d weights + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ssm_conv1d") == std::string::npos; + quantize &= name.find("shortconv.conv.weight") == std::string::npos; + + // do not quantize RWKV's small yet 2D weights + quantize &= name.find("time_mix_first.weight") == std::string::npos; + quantize &= name.find("time_mix_w0.weight") == std::string::npos; + quantize &= name.find("time_mix_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_v0.weight") == std::string::npos; + quantize &= name.find("time_mix_v1.weight") == std::string::npos; + quantize &= name.find("time_mix_v2.weight") == std::string::npos; + quantize &= name.find("time_mix_a0.weight") == std::string::npos; + quantize &= name.find("time_mix_a1.weight") == std::string::npos; + quantize &= name.find("time_mix_a2.weight") == std::string::npos; + quantize &= name.find("time_mix_g1.weight") == std::string::npos; + quantize &= name.find("time_mix_g2.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; + + // do not quantize relative position bias (T5) + quantize &= name.find("attn_rel_b.weight") == std::string::npos; + + // do not quantize specific multimodal tensors + quantize &= name.find(".position_embd.") == std::string::npos; + + return quantize; +} + +// +// tensor type selection +// + +// incompatible tensor shapes are handled here - fallback to a compatible type +static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tensor * t, const ggml_type target_type) { + ggml_type return_type = target_type; + + const int64_t ncols = t->ne[0]; + const int64_t qk_k = ggml_blck_size(target_type); + + if (ncols % qk_k != 0) { // this tensor's shape is incompatible with this quant + LLAMA_LOG_WARN("warning: %-36s - ncols %6" PRId64 " not divisible by %3" PRId64 " (required for type %7s) ", + t->name, ncols, qk_k, ggml_type_name(target_type)); + ++qs.n_fallback; + + switch (target_type) { + // types on the left: block size 256 + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: // types on the right: block size 32 + case GGML_TYPE_IQ4_XS: return_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: return_type = GGML_TYPE_Q4_0; break; + case 
GGML_TYPE_Q4_K: return_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: return_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: return_type = GGML_TYPE_Q8_0; break; + default: + throw std::runtime_error(format("no tensor type fallback is defined for type %s", + ggml_type_name(target_type))); + } + if (ncols % ggml_blck_size(return_type) != 0) { + // + // the fallback return type is still not compatible for this tensor! + // + // most likely, this tensor's first dimension is not divisible by 32. + // this is very rare. we can either abort the quantization, or + // fallback to F16 / F32. + // + LLAMA_LOG_WARN("(WARNING: must use F16 due to unusual shape) "); + return_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN("-> falling back to %7s\n", ggml_type_name(return_type)); + } + return return_type; +} + +// internal standard logic for selecting the target tensor type based on tensor category, ftype, and model arch +static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype, tensor_category category) { const std::string name = ggml_get_name(tensor); // TODO: avoid hardcoded tensor names - use the TN_* constants const llm_arch arch = qs.model.arch; - const auto tn = LLM_TN(arch); auto use_more_bits = [](int i_layer, int n_layers) -> bool { return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2; @@ -204,7 +431,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings // with the quantization of the output tensor - if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { + if (category == tensor_category::OUTPUT || (qs.has_tied_embeddings && category == tensor_category::TOKEN_EMBD)) { if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { new_type = qs.params->output_tensor_type; } else { @@ -234,7 +461,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } else { new_type = GGML_TYPE_Q8_0; } - } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") { + } else if (category == tensor_category::TOKEN_EMBD) { if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { new_type = qs.params->token_embedding_type; } else { @@ -254,21 +481,21 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { - if (name.find("attn_v.weight") != std::string::npos) { + if (category_is_attn_v(category)) { if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K; else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; ++qs.i_attention_wv; } - else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { + else if (qs.model.hparams.n_expert == 8 && category == tensor_category::ATTENTION_K) { new_type = GGML_TYPE_Q4_K; } - else if (name.find("ffn_down") != std::string::npos) { + else if (category == tensor_category::FFN_DOWN) { if (qs.i_ffn_down < qs.n_ffn_down/8) { new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? 
GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; } ++qs.i_ffn_down; } - else if (name.find("attn_output.weight") != std::string::npos) { + else if (category == tensor_category::ATTENTION_OUTPUT) { if (qs.model.hparams.n_expert == 8) { new_type = GGML_TYPE_Q5_K; } else { @@ -276,7 +503,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S; } } - } else if (name.find("attn_v.weight") != std::string::npos) { + } else if (category_is_attn_v(category)) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } @@ -314,7 +541,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = GGML_TYPE_Q8_0; } ++qs.i_attention_wv; - } else if (name.find("attn_k.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_K) { if (qs.model.hparams.n_expert == 8) { // for the 8-expert model, bumping this to Q8_0 trades just ~128MB // TODO: explore better strategies @@ -326,14 +553,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ2_S; } - } else if (name.find("attn_q.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_Q) { if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ2_S; } - } else if (name.find("ffn_down") != std::string::npos) { + } else if (category == tensor_category::FFN_DOWN) { auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; @@ -378,7 +605,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? 
GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; } ++qs.i_ffn_down; - } else if (name.find("attn_output.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_OUTPUT) { if (arch != LLM_ARCH_FALCON) { if (qs.model.hparams.n_expert == 8) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || @@ -398,14 +625,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; } } - else if (name.find("attn_qkv.weight") != std::string::npos) { + else if (category == tensor_category::ATTENTION_QKV) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { new_type = GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; } - else if (name.find("ffn_gate") != std::string::npos) { + else if (category == tensor_category::FFN_GATE) { auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { @@ -413,7 +640,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } ++qs.i_ffn_gate; } - else if (name.find("ffn_up") != std::string::npos) { + else if (category == tensor_category::FFN_UP) { auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { @@ -422,60 +649,58 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t ++qs.i_ffn_up; } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. 
- // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - { - const int64_t nx = tensor->ne[0]; - const int64_t ny = tensor->ne[1]; - const int64_t qk_k = ggml_blck_size(new_type); + return new_type; +} - if (nx % qk_k != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } +// outer wrapper: determine the ggml_type that this tensor should be quantized to +static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_model_quantize_params * params, const ggml_tensor * tensor, ggml_type default_type, const tensor_metadata & tm) { + if (!tensor_allows_quantization(params, qs.model.arch, tensor)) { + return tensor->type; + } + if (params->token_embedding_type < GGML_TYPE_COUNT && tm.category == tensor_category::TOKEN_EMBD) { + return params->token_embedding_type; + } + if (params->output_tensor_type < GGML_TYPE_COUNT && tm.category == tensor_category::OUTPUT) { + return params->output_tensor_type; } - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + ggml_type new_type = default_type; + + // get more optimal quantization type based on the tensor shape, layer, etc. 
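    // (illustration, not part of the original patch: user overrides win first. each
    //  --tensor-type pattern was compiled into a std::regex once in the
    //  quantize_state_impl constructor; if std::regex_search() matches the tensor
    //  name, that type is applied verbatim and the ftype-based mixture logic is
    //  skipped. in both paths, tensor_type_fallback() then substitutes a compatible
    //  type whenever the chosen block size does not divide the tensor's first
    //  dimension ne[0].)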
+ if (!params->pure && ggml_is_quantized(default_type)) { + // if the user provided tensor types - use those + bool manual = false; + if (!qs.tensor_type_patterns.empty()) { + const std::string tensor_name(tensor->name); + for (const auto & [pattern, qtype] : qs.tensor_type_patterns) { + if (std::regex_search(tensor_name, pattern)) { + if (qtype != new_type) { + LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n", + __func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype)); + new_type = qtype; + manual = true; + break; + } + } + } } - if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { - new_type = GGML_TYPE_F16; + + // if not manual - use the standard logic for choosing the quantization type based on the selected mixture + if (!manual) { + new_type = llama_tensor_get_type_impl(qs, new_type, tensor, params->ftype, tm.category); } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); - ++qs.n_fallback; + + // incompatible tensor shapes are handled here - fallback to a compatible type + new_type = tensor_type_fallback(qs, tensor, new_type); } return new_type; } +// +// quantization implementation +// + static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { if (nthread < 2) { // single-thread @@ -530,50 +755,85 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * return new_size; } -static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { - ggml_type default_type; - llama_ftype ftype = params->ftype; +// +// imatrix requirement check +// - switch (params->ftype) { - case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; - case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; - case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; - case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; - case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; - case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; - case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; - case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; +static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type dst_type, const llama_ftype ftype) { + if (tensor_name_match_token_embd(tensor_name) || tensor_name_match_output_weight(tensor_name)) { + return false; + } + switch (dst_type) { + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_S: + return true; + case GGML_TYPE_Q2_K: + // as a general rule, the k-type quantizations don't require imatrix data. + // the only exception is Q2_K tensors that are part of a Q2_K_S file. 
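    // (worked illustration, not part of the original patch: a target type of IQ2_XS
    //  or IQ1_S always requires an imatrix; Q4_K, Q6_K and the other k-quants do not;
    //  Q2_K requires one only when the overall file type is LLAMA_FTYPE_MOSTLY_Q2_K_S,
    //  which is exactly what the return below checks. token_embd and output weights
    //  were already exempted by the early return at the top of this function.)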
+ return ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S; + default: + return false; + } +} - case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break; +// +// given a file type, get the default tensor type +// + +static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { + switch (ftype) { + case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0; + case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1; + case LLAMA_FTYPE_MOSTLY_Q5_0: return GGML_TYPE_Q5_0; + case LLAMA_FTYPE_MOSTLY_Q5_1: return GGML_TYPE_Q5_1; + case LLAMA_FTYPE_MOSTLY_Q8_0: return GGML_TYPE_Q8_0; + case LLAMA_FTYPE_MOSTLY_F16: return GGML_TYPE_F16; + case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; + case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; + + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: - case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break; - case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q2_K: return GGML_TYPE_Q2_K; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return GGML_TYPE_IQ3_S; case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_M: - case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return GGML_TYPE_Q3_K; case LLAMA_FTYPE_MOSTLY_Q4_K_S: - case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return GGML_TYPE_Q4_K; case LLAMA_FTYPE_MOSTLY_Q5_K_S: - case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break; - case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; - case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break; - case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break; - case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break; - case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; - case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; - case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; - case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; - case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; - case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return GGML_TYPE_Q5_K; + case LLAMA_FTYPE_MOSTLY_Q6_K: return GGML_TYPE_Q6_K; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return GGML_TYPE_TQ1_0; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return GGML_TYPE_TQ2_0; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return GGML_TYPE_IQ2_XXS; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return GGML_TYPE_IQ2_XS; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return GGML_TYPE_IQ2_XS; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return GGML_TYPE_IQ2_S; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return GGML_TYPE_IQ3_XXS; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return GGML_TYPE_IQ1_S; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return GGML_TYPE_IQ1_M; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return GGML_TYPE_IQ4_NL; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return GGML_TYPE_IQ4_XS; + case LLAMA_FTYPE_MOSTLY_IQ3_S: + case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } +} + +// +// main quantization driver +// + 
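The driver below first resolves the requested file type to a base tensor type via llama_ftype_get_default_type() above; a few illustrative mappings read directly off that switch (per-tensor overrides and the category logic may later deviate from this base type):

    // LLAMA_FTYPE_MOSTLY_Q4_K_M -> GGML_TYPE_Q4_K
    // LLAMA_FTYPE_MOSTLY_IQ2_S  -> GGML_TYPE_IQ2_XS   (the IQ2_S file type builds on the IQ2_XS base type)
    // LLAMA_FTYPE_ALL_F32       -> GGML_TYPE_F32      (no quantization, plain conversion)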
+static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { + ggml_type default_type; + llama_ftype ftype = params->ftype; int nthread = params->nthread; @@ -581,6 +841,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } + default_type = llama_ftype_get_default_type(ftype); + // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. #if defined(__linux__) || defined(_WIN32) @@ -596,7 +858,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr, + fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -614,7 +877,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (params->imatrix) { imatrix_data = static_cast>*>(params->imatrix); if (imatrix_data) { - LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); + LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n", + __func__, (int)imatrix_data->size()); qs.has_imatrix = true; // check imatrix for nans or infs for (const auto & kv : *imatrix_data) { @@ -636,7 +900,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } // copy the KV pairs from the input file - gguf_set_kv (ctx_out.get(), ml.meta.get()); + gguf_set_kv (ctx_out.get(), ml.metadata); gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV @@ -697,35 +961,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } - for (const auto * it : tensors) { - const struct ggml_tensor * tensor = it->tensor; - - const std::string name = ggml_get_name(tensor); - - // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos || - name.find("attn_qkv.weight") != std::string::npos || - name.find("attn_kv_b.weight")!= std::string::npos) { - ++qs.n_attention_wv; - } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { - qs.has_output = true; - } - } - - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; - - size_t total_size_org = 0; - size_t total_size_new = 0; - - std::vector workers; - workers.reserve(nthread); - int idx = 0; - - std::vector> read_data; - std::vector> work; - std::vector> f32_conv_buf; - uint16_t n_split = 1; // Assume split index is continuous @@ -737,14 +973,68 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector ctx_outs(n_split); ctx_outs[0] = std::move(ctx_out); - // populate the original tensors so we get an initial meta data - for (const auto * it : tensors) { + // compute tensor metadata once and cache it + std::vector metadata(tensors.size()); + + // initialize quantization state before preliminary 
loop (counters for use_more_bits) + { + for (size_t i = 0; i < tensors.size(); ++i) { + const auto cat = tensor_get_category(tensors[i]->tensor->name); + if (category_is_attn_v(cat)) { + ++qs.n_attention_wv; + } + if (cat == tensor_category::OUTPUT) { + qs.has_tied_embeddings = false; + } + metadata[i].category = cat; // save and re-use the category while we're at it + } + // these also need to be set to n_layer by default + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; + } + + // flag for --dry-run + bool will_require_imatrix = false; + + // + // preliminary iteration over all weights + // + + for (size_t i = 0; i < tensors.size(); ++i) { + const auto * it = tensors[i]; + const struct ggml_tensor * tensor = it->tensor; + const std::string name = ggml_get_name(tensor); + uint16_t i_split = params->keep_split ? it->idx : 0; - ggml_tensor * tensor = it->tensor; if (!ctx_outs[i_split]) { ctx_outs[i_split].reset(gguf_init_empty()); } gguf_add_tensor(ctx_outs[i_split].get(), tensor); + + metadata[i].allows_quantization = tensor_allows_quantization(params, model.arch, tensor); + + if (metadata[i].allows_quantization) { + metadata[i].target_type = llama_tensor_get_type(qs, params, tensor, default_type, metadata[i]); + } else { + metadata[i].target_type = tensor->type; + } + + metadata[i].requires_imatrix = tensor_requires_imatrix(tensor->name, metadata[i].target_type, ftype); + + if (params->imatrix) { + metadata[i].remapped_imatrix_name = remap_imatrix(tensor->name, mapped); + } else if (metadata[i].allows_quantization && metadata[i].requires_imatrix) { + if (params->dry_run) { + will_require_imatrix = true; + } else { + LLAMA_LOG_ERROR("\n============================================================================\n" + " ERROR: this quantization requires an importance matrix!\n" + " - offending tensor: %s\n" + " - target type: %s\n" + "============================================================================\n\n", + name.c_str(), ggml_type_name(metadata[i].target_type)); + throw std::runtime_error("this quantization requires an imatrix!"); + } + } } // Set split info if needed @@ -756,6 +1046,16 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector workers; + workers.reserve(nthread); + + std::vector> read_data; + std::vector> work; + std::vector> f32_conv_buf; + int cur_split = -1; std::ofstream fout; auto close_ofstream = [&]() { @@ -785,251 +1085,182 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: ::zeros(fout, meta_size); }; - const auto tn = LLM_TN(model.arch); - new_ofstream(0); - for (const auto * it : tensors) { - const auto & weight = *it; + // no output file for --dry-run + if (!params->dry_run) { + new_ofstream(0); + } + + // + // main loop: iterate over all weights + // + + for (size_t i = 0; i < tensors.size(); ++i) { + const auto & weight = *tensors[i]; + const auto & tm = metadata[i]; ggml_tensor * tensor = weight.tensor; - if (weight.idx != cur_split && params->keep_split) { + + if (!params->dry_run && (weight.idx != cur_split && params->keep_split)) { close_ofstream(); new_ofstream(weight.idx); } const std::string name = ggml_get_name(tensor); + const size_t tensor_size = ggml_nbytes(tensor); - if (!ml.use_mmap) { - if (read_data.size() < ggml_nbytes(tensor)) { - read_data.resize(ggml_nbytes(tensor)); + if (!params->dry_run) { + if (!ml.use_mmap) { + if (read_data.size() < tensor_size) { + 
read_data.resize(tensor_size); + } + tensor->data = read_data.data(); } - tensor->data = read_data.data(); + ml.load_data_for(tensor); } - ml.load_data_for(tensor); - LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", + LLAMA_LOG_INFO("[%4d/%4d] %-36s - [%s], type = %6s, ", ++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type)); - // This used to be a regex, but has an extreme cost to compile times. - bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + const ggml_type cur_type = tensor->type; + const ggml_type new_type = tm.target_type; - // quantize only 2D and 3D tensors (experts) - quantize &= (ggml_n_dims(tensor) >= 2); + // If we've decided to quantize to the same type the tensor is already + // in then there's nothing to do. + bool quantize = cur_type != new_type; - // do not quantize norm tensors - quantize &= name.find("_norm.weight") == std::string::npos; - - quantize &= params->quantize_output_tensor || name != "output.weight"; - quantize &= !params->only_copy; - - // do not quantize expert gating tensors - // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; - - // these are very small (e.g. 4x4) - quantize &= name.find("altup") == std::string::npos; - quantize &= name.find("laurel") == std::string::npos; - - // these are not too big so keep them as it is - quantize &= name.find("per_layer_model_proj") == std::string::npos; - - // do not quantize positional embeddings and token types (BERT) - quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); - quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - - // do not quantize Mamba's small yet 2D weights - // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; - quantize &= name.find("shortconv.conv.weight") == std::string::npos; - - // do not quantize RWKV's small yet 2D weights - quantize &= name.find("time_mix_first.weight") == std::string::npos; - quantize &= name.find("time_mix_w0.weight") == std::string::npos; - quantize &= name.find("time_mix_w1.weight") == std::string::npos; - quantize &= name.find("time_mix_w2.weight") == std::string::npos; - quantize &= name.find("time_mix_v0.weight") == std::string::npos; - quantize &= name.find("time_mix_v1.weight") == std::string::npos; - quantize &= name.find("time_mix_v2.weight") == std::string::npos; - quantize &= name.find("time_mix_a0.weight") == std::string::npos; - quantize &= name.find("time_mix_a1.weight") == std::string::npos; - quantize &= name.find("time_mix_a2.weight") == std::string::npos; - quantize &= name.find("time_mix_g1.weight") == std::string::npos; - quantize &= name.find("time_mix_g2.weight") == std::string::npos; - quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; - quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; - quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; - - // do not quantize relative position bias (T5) - quantize &= name.find("attn_rel_b.weight") == std::string::npos; - - // do not quantize specific multimodal tensors - quantize &= name.find(".position_embd.") == std::string::npos; - - ggml_type new_type; void * new_data; size_t new_size; - if (quantize) { - new_type = default_type; + if (params->dry_run) { + // the --dry-run option calculates the final quantization size without quantizing 
+ if (quantize) { + new_size = ggml_nrows(tensor) * ggml_row_size(new_type, tensor->ne[0]); + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB (%s)\n", + tensor_size/1024.0/1024.0, + new_size/1024.0/1024.0, + ggml_type_name(new_type)); + if (!will_require_imatrix && tm.requires_imatrix) { + will_require_imatrix = true; + } + } else { + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", new_size/1024.0/1024.0); + } + total_size_org += tensor_size; + total_size_new += new_size; + continue; + } else { + // no --dry-run, perform quantization + if (!quantize) { + new_data = tensor->data; + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", tensor_size/1024.0/1024.0); + } else { + const int64_t nelements = ggml_nelements(tensor); - // get more optimal quantization type based on the tensor shape, layer, etc. - if (!params->pure && ggml_is_quantized(default_type)) { - int fallback = qs.n_fallback; - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type, and the tensor geometry will not require fallback quantisation - if (params->tensor_types && qs.n_fallback - fallback == 0) { - const std::vector & tensor_types = *static_cast *>(params->tensor_types); - const std::string tensor_name(tensor->name); - for (const auto & [tname, qtype] : tensor_types) { - if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { - if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); - new_type = qtype; // if two or more types are specified for the same tensor, the last match wins + const float * imatrix = nullptr; + if (imatrix_data) { + auto it = imatrix_data->find(tm.remapped_imatrix_name); + if (it == imatrix_data->end()) { + LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); + } else { + if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { + imatrix = it->second.data(); + } else { + LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); + + // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix + // this is a significant error and it may be good idea to abort the process if this happens, + // since many people will miss the error and not realize that most of the model is being quantized without an imatrix + // tok_embd should be ignored in this case, since it always causes this warning + if (!tensor_name_match_token_embd(tensor->name)) { + throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); } } } } - } - if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { - new_type = params->token_embedding_type; - } - if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { - new_type = params->output_tensor_type; - } + if (!imatrix && tm.requires_imatrix) { + LLAMA_LOG_ERROR("\n\n============================================================\n"); + LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); + LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); + LLAMA_LOG_ERROR("============================================================\n\n"); + throw std::runtime_error(format("Missing importance matrix for tensor %s in a very 
low-bit quantization", tensor->name)); + } - // If we've decided to quantize to the same type the tensor is already - // in then there's nothing to do. - quantize = tensor->type != new_type; - } + float * f32_data; - if (!quantize) { - new_type = tensor->type; - new_data = tensor->data; - new_size = ggml_nbytes(tensor); - LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0); - } else { - const int64_t nelements = ggml_nelements(tensor); - - const float * imatrix = nullptr; - if (imatrix_data) { - auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); - if (it == imatrix_data->end()) { - LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); } else { - if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { - imatrix = it->second.data(); - } else { - LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); - - // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix - // this is a significant error and it may be good idea to abort the process if this happens, - // since many people will miss the error and not realize that most of the model is being quantized without an imatrix - // tok_embd should be ignored in this case, since it always causes this warning - if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { - throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); - } - } + llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); } - } - if ((new_type == GGML_TYPE_IQ2_XXS || - new_type == GGML_TYPE_IQ2_XS || - new_type == GGML_TYPE_IQ2_S || - new_type == GGML_TYPE_IQ1_S || - (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || - (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { - LLAMA_LOG_ERROR("\n\n============================================================\n"); - LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); - LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); - LLAMA_LOG_ERROR("============================================================\n\n"); - throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); - } - float * f32_data; + LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); + fflush(stdout); - if (tensor->type == GGML_TYPE_F32) { - f32_data = (float *) tensor->data; - } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { - throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); - } else { - llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); - f32_data = (float *) f32_conv_buf.data(); - } - - LLAMA_LOG_INFO("converting to %s .. 
", ggml_type_name(new_type)); - fflush(stdout); - - if (work.size() < (size_t)nelements * 4) { - work.resize(nelements * 4); // upper bound on size - } - new_data = work.data(); - - const int64_t n_per_row = tensor->ne[0]; - const int64_t nrows = tensor->ne[1]; - - static const int64_t min_chunk_size = 32 * 512; - const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); - - const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; - const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; - const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; - - // quantize each expert separately since they have different importance matrices - new_size = 0; - for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { - const float * f32_data_03 = f32_data + i03 * nelements_matrix; - void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; - const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; - - new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); - - // TODO: temporary sanity check that the F16 -> MXFP4 is lossless -#if 0 - if (new_type == GGML_TYPE_MXFP4) { - auto * x = f32_data_03; - - //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); - std::vector deq(nrows*n_per_row); - const ggml_type_traits * qtype = ggml_get_type_traits(new_type); - qtype->to_float(new_data_03, deq.data(), deq.size()); - - double err = 0.0f; - for (int i = 0; i < (int) deq.size(); ++i) { - err += fabsf(deq[i] - x[i]); - //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { - if (deq[i] != x[i]) { - LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); - } - } - //LLAMA_LOG_INFO("err = %f\n", err); - GGML_ASSERT(err == 0.00000); + if (work.size() < (size_t)nelements * 4) { + work.resize(nelements * 4); // upper bound on size } -#endif + new_data = work.data(); + + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; + + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); + + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; + + // quantize each expert separately since they have different importance matrices + new_size = 0; + for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { + const float * f32_data_03 = f32_data + i03 * nelements_matrix; + void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; + const float * imatrix_03 = imatrix ? 
imatrix + i03 * n_per_row : nullptr; + + new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + } + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", tensor_size/1024.0/1024.0, new_size/1024.0/1024.0); } - LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); - } - total_size_org += ggml_nbytes(tensor); - total_size_new += new_size; + total_size_org += tensor_size; + total_size_new += new_size; - // update the gguf meta data as we go - gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); - gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); - // write tensor data + padding - fout.write((const char *) new_data, new_size); - zeros(fout, GGML_PAD(new_size, align) - new_size); + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } // no --dry-run + } // main loop + + if (!params->dry_run) { + close_ofstream(); } - close_ofstream(); - LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0); - LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0); + LLAMA_LOG_INFO("%s: model size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_org/1024.0/1024.0, total_size_org*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: quant size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_new/1024.0/1024.0, total_size_new*8.0/ml.n_elements); + + if (!params->imatrix && params->dry_run && will_require_imatrix) { + LLAMA_LOG_WARN("%s: WARNING: dry run completed successfully, but actually completing this quantization will require an imatrix!\n", + __func__ + ); + } if (qs.n_fallback > 0) { LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", - __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback); + __func__, qs.n_fallback, ml.n_tensors); } } @@ -1048,6 +1279,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.only_copy =*/ false, /*.pure =*/ false, /*.keep_split =*/ false, + /*.dry_run =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampler.cpp similarity index 94% rename from examples/talk-llama/llama-sampling.cpp rename to examples/talk-llama/llama-sampler.cpp index 11f0394c..9bbc5dbd 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampler.cpp @@ -1,4 +1,4 @@ -#include "llama-sampling.h" +#include "llama-sampler.h" #include "llama-impl.h" #include "llama-vocab.h" @@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend { std::mt19937 rng; - // backend input - struct ggml_tensor * inp_uniform; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; + ggml_tensor * inp_uniform; }; static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { @@ -1138,37 +1134,10 @@ static bool 
llama_sampler_dist_backend_init( ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_dist *) smpl->ctx; - // allocate inputs - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_dist_backend_set_input after this graph is built. - sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); - ggml_set_name (sctx->inp_uniform, "uniform"); - ggml_set_input(sctx->inp_uniform); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - } - const bool res = llama_sampler_backend_support(smpl, buft); sctx->init(res); - if (!res) { - sctx->inp_ctx.reset(nullptr); - sctx->inp_buf.reset(nullptr); - } - return res; } @@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply( struct ggml_cgraph * gf, struct llama_sampler_data * data) { GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); ggml_set_name(probs, "dist_probs"); @@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply( static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); // We sample in double precision and cast to float to match rnd numbers of @@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), /* .inp_uniform = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } @@ -1513,12 +1486,9 @@ static void llama_sampler_top_p_backend_apply( mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // top_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). - const float large_val = 1e9f; - struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * top_p_bias = ggml_log(ctx, mask); ggml_set_name(top_p_bias, "top_p_bias"); data->logits = ggml_add(ctx, sorted_logits, top_p_bias); @@ -1673,15 +1643,11 @@ static void llama_sampler_min_p_backend_apply( struct ggml_tensor * mask = ggml_step(ctx, sub); ggml_set_name(mask, "min_p_mask"); - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // min_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). 
- const float large_val = 1e9f; - struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * min_p_bias = ggml_log(ctx, mask); ggml_set_name(min_p_bias, "min_p_bias"); - // Add the min_p bias to the logits. data->logits = ggml_add(ctx, data->logits, min_p_bias); ggml_set_name(data->logits, "min_p_logits"); @@ -3293,6 +3259,170 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa return result; } +// adaptive-p sampler state +// +// maintains an exponential moving average of the *ORIGINAL* probabilities +// of selected tokens, used to compute an adapted target at each sampling step. +// +// see llama.h for a full description of the sampler +// +// ref: https://github.com/ggml-org/llama.cpp/pull/17927 +// +struct llama_sampler_adaptive_p { + const float target; // target probability (0.0 - 1.0; negative = disabled) + const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99) + const uint32_t seed; // original RNG seed + uint32_t seed_cur; // actual RNG seed + std::mt19937 rng; // RNG state + float weighted_sum; // sum(p_i * decay^i) + float total_weight; // sum(decay^i), converges to 1/(1-decay) + std::vector original_probs; // pre-transform probs, cached for EMA update + llama_token pending_token_id; // token ID of selected token + int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs +}; + +// adaptive probability transformation constants +static constexpr float DISTRIBUTION_WIDTH = 0.3f; +static constexpr float PEAK_LOGIT_VALUE = 5.0f; +static constexpr float SHARPNESS = 10.0f; +static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH; + +static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) { + return "adaptive-p"; +} + +static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, false); + + if (ctx->target < 0.0f) { + // at negative target values, adaptive-p is no-op + // we simply sample from the existing distribution + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); + return; + } + + // store the original probabilities + ctx->original_probs.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->original_probs[i] = cur_p->data[i].p; + } + + // using the EMA, compute the adapted target probability for the current sampling step + auto target = std::clamp(ctx->target, 0.0f, 1.0f); + float adapted_target = std::clamp( + ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight), + 0.0f, 1.0f + ); + + // adaptive probability transform + // + // quadratic near target for fine differentiation, transitioning to linear decay in the + // tails. unbounded negative logits ensure proper suppression of far-from-target tokens + // after the softmax. + // + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit == -INFINITY) { + // don't transform logits that are -INFINITY + // (as masked out by e.g. 
min-p and top-p when using backend sampling) + continue; + } + float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH); + cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist); + } + + // softmax and sample from the transformed distribution + llama_sampler_softmax_impl(cur_p, false); + const int idx = llama_sample_dist(cur_p, ctx->rng); + cur_p->selected = idx; + + // store the selected token ID for acceptance later + ctx->pending_token_id = cur_p->data[idx].id; + ctx->pending_token_idx = idx; +} + +static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + if (ctx->pending_token_id == token) { + GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(ctx->pending_token_idx != -1); + // update EMA with the original probability of the selected token + ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum; + ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; + } + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; +} + +static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + // ctx->target and ctx->decay never change after init, so it's safe to keep them as is. + // original_probs is completely overwritten on every call to _apply. + // so we only need to reset the EMA state and pending token. + ctx->weighted_sum = ctx->target / (1.0f - ctx->decay); + ctx->total_weight = 1.0f / (1.0f - ctx->decay); + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx; + auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed); + auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx; + + // copy everything (target, decay, seed, and RNG are already set) + result_ctx->weighted_sum = ctx->weighted_sum; + result_ctx->total_weight = ctx->total_weight; + result_ctx->pending_token_id = ctx->pending_token_id; + result_ctx->pending_token_idx = ctx->pending_token_idx; + + return result; +} + +static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_adaptive_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_adaptive_p_i = { + /* .name = */ llama_sampler_adaptive_p_name, + /* .accept = */ llama_sampler_adaptive_p_accept, + /* .apply = */ llama_sampler_adaptive_p_apply, + /* .reset = */ llama_sampler_adaptive_p_reset, + /* .clone = */ llama_sampler_adaptive_p_clone, + /* .free = */ llama_sampler_adaptive_p_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed +) { + auto seed_cur = get_rng_seed(seed); + float clamped_decay = std::clamp(decay, 0.0f, 0.99f); + return llama_sampler_init( + /* .iface = */ &llama_sampler_adaptive_p_i, + /* .ctx = */ new llama_sampler_adaptive_p { + /* .target = */ target, + /* .decay = */ clamped_decay, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .weighted_sum = */ target / (1.0f - clamped_decay), + /* 
.total_weight = */ 1.0f / (1.0f - clamped_decay), + /* .original_probs = */ {}, + /* .pending_token_id = */ LLAMA_TOKEN_NULL, + /* .pending_token_idx = */ -1 + } + ); +} + // logit-bias struct llama_sampler_logit_bias : public llama_sampler_backend { @@ -3304,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend { struct ggml_tensor * inp_logit_bias; struct ggml_tensor * inp_logit_idxs; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; }; static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { @@ -3369,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply( return; } + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); @@ -3405,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm static bool llama_sampler_logit_bias_backend_init( struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; sctx->init(true); @@ -3413,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init( return true; } - ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - const size_t n = sctx->logit_bias.size(); - - sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); - ggml_set_name(sctx->inp_logit_bias, "logit_bias"); - ggml_set_input(sctx->inp_logit_bias); - - sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); - ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); - ggml_set_input(sctx->inp_logit_idxs); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - return true; } @@ -3471,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias( /* .to_search = */ {}, /* .inp_logit_bias = */ nullptr, /* .inp_logit_idxs = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } diff --git a/examples/talk-llama/llama-sampling.h b/examples/talk-llama/llama-sampler.h similarity index 92% rename from examples/talk-llama/llama-sampling.h rename to examples/talk-llama/llama-sampler.h index 6a963c0b..b9bfc20d 100644 --- a/examples/talk-llama/llama-sampling.h +++ b/examples/talk-llama/llama-sampler.h @@ -1,7 +1,5 @@ #pragma once -// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? 
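Editor's aside on the adaptive-p sampler introduced above — a worked example derived from the constants in the patch (the target and probabilities are made-up values, not new behavior):

    // with target = 0.3 and a fresh EMA: adapted_target = clamp(2*0.3 - 0.3, 0, 1) = 0.3
    // p = 0.30 -> dist = |0.30 - 0.30|/0.3 = 0.0 -> logit = 5.0                         (on target, strongly favored)
    // p = 0.60 -> dist = 1.0                     -> logit = 5.0 - 10*1.0/(1 + 1.0) = 0.0
    // p = 0.03 -> dist = 0.9                     -> logit = 5.0 - 10*0.81/(1 + 0.9) ≈ 0.74
    // tokens whose ORIGINAL probability lies near the adapted target keep the largest logits, the tails fall off
    // roughly linearly, and -INFINITY entries (already masked by top-p/min-p) are left untouched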
- #include "llama.h" #include diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index a20c6525..68ba292d 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -90,7 +90,7 @@ static_assert(std::is_trivially_copyable::value, "llm_symbol is not // // SPM tokenizer // original implementation: -// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +// https://github.com/ggml-org/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 // struct llm_bigram_spm { @@ -285,10 +285,19 @@ struct llm_tokenizer_bpe : llm_tokenizer { // original regex from tokenizer.json //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", - // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + // adapted: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2080233989 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_JAIS2: + regex_exprs = { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + + // adapted: same as llama3 but with cascading whitespace pattern + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DBRX: case LLAMA_VOCAB_PRE_TYPE_SMAUG: regex_exprs = { @@ -308,6 +317,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: + case LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -368,6 +378,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_QWEN35: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_PORO: case LLAMA_VOCAB_PRE_TYPE_BLOOM: case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: @@ -415,6 +432,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: + regex_exprs = { + // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" + 
"\\d{1,3}(?=(?:\\d{3})*\\b)", + // original regex from tokenizer.json: "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: regex_exprs = { // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp @@ -461,6 +486,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1687,7 +1719,7 @@ private: }; void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { - struct gguf_context * ctx = ml.meta.get(); + struct gguf_context * ctx = ml.metadata; // determine vocab type { @@ -1745,26 +1777,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + // Kimi-K2 uses custom tokenization without traditional BPE merges + const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); + if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } - - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - - std::string first; - std::string second; - - const size_t pos = word.find(' ', 1); - - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); + if (!is_kimi_k2) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); } + // Kimi-K2 doesn't need merges, skip + LLAMA_LOG_INFO("%s: Kimi-K2 tokenizer detected, skipping BPE merges\n", __func__); + } else { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - bpe_ranks.emplace(std::make_pair(first, second), i); + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } } // 
default special tokens @@ -1794,7 +1833,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - // correct endiannes of data in precompiled_charsmap binary blob + // correct endianness of data in precompiled_charsmap binary blob uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0]; *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); @@ -1851,7 +1890,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral" || tokenizer_pre == "midm-2.0" || - tokenizer_pre == "lfm2") { + tokenizer_pre == "lfm2" || + tokenizer_pre == "jina-v5-nano") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; add_bos = true; @@ -1891,8 +1931,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || tokenizer_pre == "mellum" || - tokenizer_pre == "modern-bert" ) { + tokenizer_pre == "modern-bert") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "jais-2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || @@ -1912,6 +1955,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "kormo") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; + } else if ( + tokenizer_pre == "qwen35") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN35; + clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; @@ -1965,6 +2012,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "exaone4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "exaone-moe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE; } else if ( tokenizer_pre == "chameleon") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; @@ -1977,10 +2027,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( - tokenizer_pre == "gpt-4o" || - tokenizer_pre == "llama4") { + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4" || + tokenizer_pre == "kanana2") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "tiny_aya") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; + clean_spaces = false; } else if ( tokenizer_pre == "superbpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; @@ -2011,6 +2066,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan-dense") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; clean_spaces = false; + } else if ( + tokenizer_pre == "joyai-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM; + clean_spaces = false; } else if ( tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; @@ -2216,6 +2275,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" + || t.first == "[EOT]" // Kimi-K2 || t.first == "<|end▁of▁sentence|>" // DeepSeek || t.first == "" // 
smoldocling ) { @@ -2252,6 +2312,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "
"
                         || t.first == "▁<PRE>"          // CodeLlama
                         || t.first == "<|code_prefix|>" // GLM-4.5
+                        || t.first == "<|prefix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_pre_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2272,6 +2333,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<SUF>"
                         || t.first == "▁<SUF>"         // CodeLlama
                         || t.first == "<|code_suffix|>" // GLM-4.5
+                        || t.first == "<|suffix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_suf_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2292,6 +2354,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<MID>"
                         || t.first == "▁<MID>"         // CodeLlama
                         || t.first == "<|code_middle|>" // GLM-4.5
+                        || t.first == "<|middle|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_mid_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2309,6 +2372,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<fim-pad>"
                         || t.first == "<fim_pad>"   // Granite
                         || t.first == "<PAD>"
+                        || t.first == "[PAD]" // Kimi-K2
                         ) {
                     special_fim_pad_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
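Editor's aside — a minimal sketch (assuming the standard llama.h accessors; `vocab` and the surrounding code are not part of this patch) of how the FIM ids collected in the hunks above are typically assembled into an infill prompt:

    std::vector<llama_token> prompt;
    prompt.push_back(llama_vocab_fim_pre(vocab));   // e.g. "<|prefix|>" for Falcon-H1-Tiny-Coder
    // ... tokens of the code before the cursor ...
    prompt.push_back(llama_vocab_fim_suf(vocab));   // e.g. "<|suffix|>"
    // ... tokens of the code after the cursor ...
    prompt.push_back(llama_vocab_fim_mid(vocab));   // e.g. "<|middle|>"; the model generates the infill from here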
@@ -2380,7 +2444,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
         // maintain a list of tokens that cause end-of-generation
         // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        // ref: https://github.com/ggml-org/llama.cpp/issues/9606
         special_eog_ids.clear();
 
         if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
@@ -2408,9 +2472,12 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|calls|>"  // solar-open
                     || t.first == "<end_of_turn>"
                     || t.first == "<|endoftext|>"
+                    || t.first == ""      // paddleocr
                     || t.first == "<|eom_id|>"
                     || t.first == "<EOT>"
                     || t.first == "_<EOT>"
+                    || t.first == "[EOT]" // Kimi-K2
+                    || t.first == "[EOS]" // Kimi-K2
                     || t.first == "<|end_of_text|>"
                     || t.first == "<end_of_utterance>" // smoldocling
                ) {
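Editor's aside — a minimal sketch (the generation-loop names `smpl`, `ctx`, `vocab` and `n_max` are assumed, not part of this patch) of how the EOG set built above is consumed by callers; any id gathered here, including the new Kimi-K2 "[EOT]"/"[EOS]" entries, ends generation:

    for (int i = 0; i < n_max; ++i) {
        const llama_token tok = llama_sampler_sample(smpl, ctx, -1);
        if (llama_vocab_is_eog(vocab, tok)) {
            break; // tok is in special_eog_ids -> end of generation
        }
        // ... accept the token, decode it, and continue ...
    }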
@@ -2436,7 +2503,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             auto & attr = id_to_token[t.second].attr;
 
             if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") {
-                attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED);
+                LLAMA_LOG_WARN("%s: setting token '%s' (%d) attribute to USER_DEFINED (%u), old attributes: %u\n",
+                        __func__, t.first.c_str(), t.second, LLAMA_TOKEN_ATTR_USER_DEFINED, attr);
+
+                attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
             }
         }
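Editor's aside — a minimal sketch of what changes in this hunk (values are illustrative):

    llama_token_attr attr = LLAMA_TOKEN_ATTR_CONTROL;                 // e.g. <|channel|> as exported by the converter
    attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); // old behavior: CONTROL | USER_DEFINED
    attr = LLAMA_TOKEN_ATTR_USER_DEFINED;                             // new behavior: the CONTROL bit is dropped

Since token_to_piece() filters on UNKNOWN|CONTROL (see the attr_special constant further down), dropping the CONTROL bit means these tokens are now rendered even when special == false.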
 
@@ -2489,7 +2559,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 special_eog_ids.erase(end_id);
 
                 auto & attr = id_to_token[end_id].attr;
-                attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED);
+                attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
 
                 LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
             }
@@ -3066,7 +3136,7 @@ std::vector llama_vocab::impl::tokenize(
 }
 
 int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    // ref: https://github.com/ggml-org/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
     const llama_token_attr attr = token_get_attr(token);
     if (!special && (attr & attr_special)) {
@@ -3289,34 +3359,34 @@ int32_t llama_vocab::impl::detokenize(
 }
 
 void llama_vocab::impl::print_info() const {
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: vocab type            = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab               = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges              = %u\n",     __func__, (uint32_t) bpe_ranks.size());
 
     // special tokens
-    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
-    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
-    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
-    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
-    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
-    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
-    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
-    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token             = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token             = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token             = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token             = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token             = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token             = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token             = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token            = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
 
-    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token              = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
 
-    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
-    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
-    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
-    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
-    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
-    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token         = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token         = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token         = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token         = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token         = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token         = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
 
     for (const auto & id : special_eog_ids) {
-        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
+        LLAMA_LOG_INFO( "%s: EOG token             = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
     }
 
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+    LLAMA_LOG_INFO("%s: max token length      = %d\n", __func__, max_token_len);
 }
 
 llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h
index 2b240a54..be5b0801 100644
--- a/examples/talk-llama/llama-vocab.h
+++ b/examples/talk-llama/llama-vocab.h
@@ -53,6 +53,11 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_AFMOE           = 42,
     LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN      = 43,
     LLAMA_VOCAB_PRE_TYPE_YOUTU           = 44,
+    LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE      = 45,
+    LLAMA_VOCAB_PRE_TYPE_QWEN35          = 46,
+    LLAMA_VOCAB_PRE_TYPE_TINY_AYA        = 47,
+    LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM       = 48,
+    LLAMA_VOCAB_PRE_TYPE_JAIS2           = 49,
 };
 
 struct LLM_KV;
diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp
index f1096d96..872e659e 100644
--- a/examples/talk-llama/llama.cpp
+++ b/examples/talk-llama/llama.cpp
@@ -1,5 +1,6 @@
 #include "llama.h"
 
+#include "ggml-cpp.h"
 #include "llama-impl.h"
 
 #include "llama-chat.h"
@@ -12,6 +13,7 @@
 
 #include "ggml.h"
 #include "ggml-backend.h"
+#include "gguf.h"
 
 #include 
 #include 
@@ -311,8 +313,12 @@ static void llama_params_fit_impl(
                             __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
                     }
                 } else {
-                    LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
-                        __func__, hp_nct, n_ctx_min);
+                    if (n_ctx_min == UINT32_MAX) {
+                        LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct);
+                    } else {
+                        LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n",
+                            __func__, hp_nct, n_ctx_min);
+                    }
                 }
             } else {
                 LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx);
@@ -821,7 +827,8 @@ int64_t llama_time_us(void) {
 }
 
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
+static int llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud,
+        const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) {
     // loading time will be recalculated after the first eval, so
     // we take page faults deferred by mmap() into consideration
     model.t_load_us = 0;
@@ -830,7 +837,8 @@ static int llama_model_load(const std::string & fname, std::vector
     model.t_start_us = tm.t_start_us;
 
     try {
-        llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
+        llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, params.use_mmap, params.use_direct_io,
+            params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides);
 
         ml.print_info();
 
@@ -876,9 +884,13 @@ static int llama_model_load(const std::string & fname, std::vector
 }
 
 static struct llama_model * llama_model_load_from_file_impl(
+        struct gguf_context * metadata,
+        llama_model_set_tensor_data_t set_tensor_data,
+        void * set_tensor_data_ud,
         const std::string & path_model,
         std::vector<std::string> & splits,
         struct llama_model_params params) {
+    GGML_ASSERT((metadata == nullptr) != path_model.empty() && "exactly one of metadata and path_model must be provided");
     ggml_time_init();
 
     if (!params.vocab_only && ggml_backend_reg_count() == 0) {
@@ -999,7 +1011,7 @@ static struct llama_model * llama_model_load_from_file_impl(
                 props.memory_free/1024/1024);
     }
 
-    const int status = llama_model_load(path_model, splits, *model, params);
+    const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, *model, params);
     GGML_ASSERT(status <= 0);
     if (status < 0) {
         if (status == -1) {
@@ -1015,6 +1027,18 @@ static struct llama_model * llama_model_load_from_file_impl(
     return model;
 }
 
+struct llama_model * llama_model_init_from_user(
+        struct gguf_context * metadata,
+        llama_model_set_tensor_data_t set_tensor_data,
+        void * set_tensor_data_ud,
+        struct llama_model_params params) {
+    GGML_ASSERT(metadata != nullptr);
+    std::string path_model;
+    std::vector<std::string> splits = {};
+    params.use_mmap = false;
+    params.use_extra_bufts = false;
+    return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, params);
+}
 // deprecated
 struct llama_model * llama_load_model_from_file(
         const char * path_model,
@@ -1026,7 +1050,7 @@ struct llama_model * llama_model_load_from_file(
         const char * path_model,
         struct llama_model_params params) {
     std::vector<std::string> splits = {};
-    return llama_model_load_from_file_impl(path_model, splits, params);
+    return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, params);
 }
 
 struct llama_model * llama_model_load_from_splits(
@@ -1042,11 +1066,11 @@ struct llama_model * llama_model_load_from_splits(
     for (size_t i = 0; i < n_paths; ++i) {
         splits.push_back(paths[i]);
     }
-    return llama_model_load_from_file_impl(splits.front(), splits, params);
+    return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, params);
 }
 
 void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
-    llama_model_saver ms(*model);
+    llama_model_saver ms(model);
     ms.add_kv_from_model();
     ms.add_tensors_from_model();
     ms.save(path_model);
@@ -1091,25 +1115,55 @@ int32_t llama_chat_apply_template(
 // model split
 //
 
-int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
+int32_t llama_split_path(
+    char * split_path,
+    size_t maxlen,
+    const char * path_prefix,
+    int32_t split_no,
+    int32_t split_count) {
+
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
-    if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
-        return strlen(split_path);
+
+    const int written = snprintf(
+        split_path,
+        maxlen,
+        SPLIT_PATH_FORMAT,
+        path_prefix,
+        split_no + 1,
+        split_count
+    );
+
+    if (written < 0 || (size_t) written >= maxlen) {
+        return 0;
     }
-    return 0;
+
+    return (int32_t) written;
 }
 
-int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
-    std::string str_split_path(split_path);
-    char postfix[32];
-    snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
-    std::string str_postfix(postfix);
+int32_t llama_split_prefix(
+    char * split_prefix,
+    size_t maxlen,
+    const char * split_path,
+    int32_t split_no,
+    int32_t split_count) {
 
-    // check if split_prefix ends with postfix
-    int size_prefix = str_split_path.size() - str_postfix.size();
-    if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
-        return size_prefix;
+    const std::string str_split_path(split_path);
+
+    char postfix[32];
+    snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count);
+
+    const std::string str_postfix(postfix);
+    if (str_split_path.size() <= str_postfix.size()) {
+        return 0;
+    }
+
+    const size_t size_prefix = str_split_path.size() - str_postfix.size();
+
+    if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) {
+        const size_t copy_len = std::min(size_prefix + 1, maxlen);
+        snprintf(split_prefix, copy_len, "%s", split_path);
+
+        return (int32_t) size_prefix;
     }
 
     return 0;
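
As a usage sketch of the rewritten helpers (file names and buffer sizes here are illustrative): split_no is zero-based, the formatter appends a 1-based index, and both functions now report truncation or a suffix mismatch by returning 0.

    char split_path[128];
    char split_prefix[128];

    // "/models/ggml-model-q4_0" + zero-based split 1 of 4 -> ".../ggml-model-q4_0-00002-of-00004.gguf"
    if (llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 1, 4) == 0) {
        // buffer too small: the new implementation returns 0 instead of a truncated length
    }

    // recover the prefix only if the "-00002-of-00004.gguf" suffix matches split_no/split_count
    if (llama_split_prefix(split_prefix, sizeof(split_prefix), split_path, 1, 4) > 0) {
        // split_prefix == "/models/ggml-model-q4_0"
    }
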
diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h
index 1c17efb9..c6e102ab 100644
--- a/examples/talk-llama/llama.h
+++ b/examples/talk-llama/llama.h
@@ -5,6 +5,7 @@
 #include "ggml-cpu.h"
 #include "ggml-backend.h"
 #include "ggml-opt.h"
+#include "gguf.h"
 
 #include 
 #include 
@@ -152,6 +153,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_TQ1_0         = 36, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_TQ2_0         = 37, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_MXFP4_MOE     = 38, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_NVFP4         = 39, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -309,7 +311,7 @@ extern "C" {
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool vocab_only;      // only load the vocabulary, no weights
         bool use_mmap;        // use mmap if possible
-        bool use_direct_io;   // use direct io, takes precedence over use_mmap
+        bool use_direct_io;   // use direct io, takes precedence over use_mmap when supported
         bool use_mlock;       // force system to keep model in RAM
         bool check_tensors;   // validate model tensor data
         bool use_extra_bufts; // use extra buffer types (used for weight repacking)
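
A small sketch of the load-parameter interplay noted above (values illustrative): direct I/O is preferred when the platform supports it, and the loader presumably falls back to mmap otherwise if it is enabled.

    struct llama_model_params mp = llama_model_default_params();
    mp.use_direct_io = true;  // takes precedence over use_mmap when supported
    mp.use_mmap      = true;  // assumed fallback when direct I/O is unavailable
    struct llama_model * model = llama_model_load_from_file("model.gguf", mp);
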
@@ -389,6 +391,7 @@ extern "C" {
         bool only_copy;                       // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
         bool pure;                            // quantize all tensors to the default type
         bool keep_split;                      // quantize to the same number of shards
+        bool dry_run;                         // calculate and show the final quantization size without performing quantization
         void * imatrix;                       // pointer to importance matrix data
         void * kv_overrides;                  // pointer to vector containing overrides
         void * tensor_types;                  // pointer to vector containing tensor types
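
A hedged sketch of driving the new dry_run flag; the file names and ftype choice are illustrative, and llama_model_quantize / llama_model_quantize_default_params are the existing quantization entry points assumed here.

    llama_model_quantize_params qparams = llama_model_quantize_default_params();
    qparams.ftype   = LLAMA_FTYPE_MOSTLY_MXFP4_MOE;
    qparams.dry_run = true; // report the projected size without writing the quantized file
    llama_model_quantize("model-f16.gguf", "model-mxfp4.gguf", &qparams);
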
@@ -439,19 +442,30 @@ extern "C" {
 
     LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
 
+    typedef void (*llama_model_set_tensor_data_t)(struct ggml_tensor * tensor, void * userdata);
+
+    // Create a new model from GGUF metadata as well as a function to set the tensor data
+    //   - tensors are created as GGML_TYPE_F32 by default,
+    //     override by adding a tensor with the same name but a different type to the context
+    LLAMA_API struct llama_model * llama_model_init_from_user(
+                    struct gguf_context * metadata,
+          llama_model_set_tensor_data_t   set_tensor_data,    // function to initialize tensor data with
+                                   void * set_tensor_data_ud, // userdata for function
+              struct llama_model_params   params);
+
     DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
                              const char * path_model,
               struct llama_model_params   params),
             "use llama_model_load_from_file instead");
 
-    // Load the model from a file
+    // Load a model from a file
     // If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf
     // If the split file name does not follow this pattern, use llama_model_load_from_splits
     LLAMA_API struct llama_model * llama_model_load_from_file(
                              const char * path_model,
               struct llama_model_params   params);
 
-    // Load the model from multiple splits (support custom naming scheme)
+    // Load a model from multiple splits (supports custom naming schemes)
     // The paths must be in the correct order
     LLAMA_API struct llama_model * llama_model_load_from_splits(
                              const char ** paths,
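
A minimal sketch of driving the new callback-based constructor, assuming the callback is expected to fill tensor->data for each created tensor; the zero-fill body and the gguf file name are illustrative only (llama.h now pulls in gguf.h, see above).

    #include <string.h> // memset

    static void fill_tensor(struct ggml_tensor * tensor, void * userdata) {
        (void) userdata;
        memset(tensor->data, 0, ggml_nbytes(tensor)); // assumption: the callback writes the weights here
    }

    static struct llama_model * load_from_metadata(const char * gguf_path) {
        struct gguf_init_params gparams = { /*no_alloc =*/ true, /*ctx =*/ NULL };
        struct gguf_context * metadata  = gguf_init_from_file(gguf_path, gparams);
        return llama_model_init_from_user(metadata, fill_tensor, /*set_tensor_data_ud =*/ NULL,
                                          llama_model_default_params());
    }
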
@@ -482,13 +496,14 @@ extern "C" {
     enum llama_params_fit_status {
         LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
         LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
-        LLAMA_PARAMS_FIT_STATUS_ERROR   = 2, // a hard error occured, e.g. because no model could be found at the specified path
+        LLAMA_PARAMS_FIT_STATUS_ERROR   = 2, // a hard error occurred, e.g. because no model could be found at the specified path
     };
 
     // fits mparams and cparams to free device memory (assumes system memory is unlimited)
     //   - returns true if the parameters could be successfully modified to fit device memory
     //   - this function is NOT thread safe because it modifies the global llama logger state
     //   - only parameters that have the same value as in llama_default_model_params are modified
+    //     with the exception of the context size which is modified if and only if equal to 0
     LLAMA_API enum llama_params_fit_status llama_params_fit(
                                    const char   * path_model,
                     struct llama_model_params   * mparams,
@@ -646,7 +661,8 @@ extern "C" {
 
     // Manually free a LoRA adapter
     // NOTE: loaded adapters will be free when the associated model is deleted
-    LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
+    LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter),
+            "adapters are now freed together with the associated model");
 
     // Get the invocation tokens if the current lora is an alora
     LLAMA_API uint64_t            llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
@@ -654,21 +670,12 @@ extern "C" {
 
     // The following functions operate on a llama_context, hence the naming: llama_verb_...
 
-    // Add a loaded LoRA adapter to given context
-    // This will not modify model's weight
-    LLAMA_API int32_t llama_set_adapter_lora(
+    // Set LoRA adapters on the context. Only modifies the context if the given adapters differ from those currently set.
+    LLAMA_API int32_t llama_set_adapters_lora(
             struct llama_context * ctx,
-            struct llama_adapter_lora * adapter,
-            float scale);
-
-    // Remove a specific LoRA adapter from given context
-    // Return -1 if the adapter is not present in the context
-    LLAMA_API int32_t llama_rm_adapter_lora(
-            struct llama_context * ctx,
-            struct llama_adapter_lora * adapter);
-
-    // Remove all LoRA adapters from given context
-    LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
+            struct llama_adapter_lora ** adapters,
+            size_t n_adapters,
+            float * scales);
 
     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
     // the currently loaded vector.
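
A migration sketch for the consolidated call, assuming ctx and two previously loaded adapters (adapter_a, adapter_b) exist; whether passing n_adapters == 0 clears everything is an assumption based on the removed llama_clear_adapter_lora.

    struct llama_adapter_lora * adapters[2] = { adapter_a, adapter_b }; // loaded via llama_adapter_lora_init
    float scales[2] = { 1.0f, 0.5f };

    // replaces the old llama_set_adapter_lora / llama_rm_adapter_lora sequence;
    // per the header comment, a no-op when the requested set matches what is already applied
    if (llama_set_adapters_lora(ctx, adapters, 2, scales) != 0) {
        // handle error
    }
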
@@ -676,7 +683,7 @@ extern "C" {
     // to an n_embd x n_layers buffer starting from layer 1.
     // il_start and il_end are the layer range the vector should apply to (both inclusive)
     // See llama_control_vector_load in common to load a control vector.
-    LLAMA_API int32_t llama_apply_adapter_cvec(
+    LLAMA_API int32_t llama_set_adapter_cvec(
             struct llama_context * ctx,
                      const float * data,
                           size_t   len,
@@ -979,7 +986,7 @@ extern "C" {
 
     // Logits for the ith token. For positive indices, Equivalent to:
     // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
-    // Negative indicies can be used to access logits in reverse order, -1 is the last logit.
+    // Negative indices can be used to access logits in reverse order, -1 is the last logit.
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
@@ -994,7 +1001,7 @@ extern "C" {
 
     // Get the embeddings for the ith token. For positive indices, Equivalent to:
     // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
-    // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding.
+    // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding.
     // shape: [n_embd] (1-dimensional)
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
@@ -1014,9 +1021,9 @@ extern "C" {
     // Returns LLAMA_TOKEN_NULL if no token was sampled.
     LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
 
-    // Get the backend sampled probabilites for the ith token
+    // Get the backend sampled probabilities for the ith token
     // The index matches llama_get_sampled_token_ith().
-    // Returns NULL if no probabilites were generated.
+    // Returns NULL if no probabilities were generated.
     LLAMA_API float *  llama_get_sampled_probs_ith      (struct llama_context * ctx, int32_t i);
     LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
 
@@ -1148,9 +1155,9 @@ extern "C" {
     //
 
     /// Apply chat template. Inspired by hf apply_chat_template() on python.
-    /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
+    ///
     /// NOTE: This function does not use a jinja parser. It only supports a pre-defined list of templates. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
-    /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
+    /// @param tmpl A Jinja template to use for this chat.
     /// @param chat Pointer to a list of multiple llama_chat_message
     /// @param n_msg Number of llama_chat_message in this chat
     /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
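
A hedged example of the tightened contract (tmpl can no longer be NULL); the template name, messages, and buffer size are illustrative, and the signature taking tmpl, messages, add_ass, and an output buffer is assumed from the current API.

    llama_chat_message msgs[2] = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!"                       },
    };

    char buf[2048];
    int32_t n = llama_chat_apply_template("chatml", msgs, 2, /*add_ass =*/ true, buf, sizeof(buf));
    if (n < 0 || (size_t) n > sizeof(buf)) {
        // template unknown, or buffer too small for the formatted prompt
    }
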
@@ -1255,7 +1262,6 @@ extern "C" {
     // [EXPERIMENTAL]
     // attach a sampler to the context
     // note: prefer initializing the context with llama_context_params.samplers when possible
-    // note: changing the samplers of a context can cause graph reallocations and degraded performance
     LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
 
     // mirror of llama_sampler_i:
@@ -1344,7 +1350,7 @@ extern "C" {
                                float   tau,
                                float   eta);
 
-    /// @details Intializes a GBNF grammar, see grammars/README.md for details.
+    /// @details Initializes a GBNF grammar, see grammars/README.md for details.
     /// @param vocab The vocabulary that this grammar will be used with.
     /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
     /// @param grammar_root The name of the start symbol for the grammar.
@@ -1395,6 +1401,33 @@ extern "C" {
                           const char ** seq_breakers,
                               size_t    num_breakers);
 
+    /// adaptive-p: select tokens near a configurable target probability over time.
+    ///
+    /// the adaptive-p sampler transforms the token probability distribution to favor tokens
+    /// that fall near a user-configurable probability target.
+    ///
+    /// internally, the sampler maintains an exponential moving average of the *ORIGINAL*
+    /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an
+    /// adapted target probability at each sampling step, thus maintaining the desired target
+    /// probability over time.
+    ///
+    /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last
+    /// in the sampler chain (like mirostat, dist, greedy).
+    ///
+    /// only mild truncation before this sampler is recommended. we suggest applying min-p
+    /// before adaptive-p as the only other active sampler in the chain.
+    ///
+    /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled)
+    /// @param decay  EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99)
+    /// @param seed   RNG seed
+    ///
+    /// ref: https://github.com/ggml-org/llama.cpp/pull/17927
+    ///
+    LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p(
+                               float   target,
+                               float   decay,
+                            uint32_t   seed);
+
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
                              int32_t   n_vocab,
                              int32_t   n_logit_bias,
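
The EMA bookkeeping described above, reduced to an illustrative helper; the exact correction applied to the per-step target in the actual sampler may differ, so treat this purely as a sketch of the mechanism.

    // keep a running average of the ORIGINAL probability of each selected token and
    // nudge the per-step target so the long-run average drifts toward `target`
    static float adaptive_p_next_target(float target, float decay, float * ema, float p_selected) {
        *ema = decay * (*ema) + (1.0f - decay) * p_selected; // history of ~ 1/(1-decay) tokens
        return target + (target - *ema);                     // assumed form of the correction
    }
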
@@ -1448,12 +1481,12 @@ extern "C" {
     /// @details Build a split GGUF final path for this chunk.
     ///          llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
     //  Returns the split_path length.
-    LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count);
+    LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count);
 
     /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match.
     ///          llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0"
     //  Returns the split_prefix length.
-    LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
+    LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count);
 
     // Print system information
     LLAMA_API const char * llama_print_system_info(void);
diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp
index 6a752a40..9aabe25c 100644
--- a/examples/talk-llama/models/afmoe.cpp
+++ b/examples/talk-llama/models/afmoe.cpp
@@ -1,8 +1,8 @@
 #include "models.h"
 
 llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -127,7 +127,6 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para
                     n_expert, n_expert_used,
                     LLM_FFN_SILU,
                     hparams.expert_weights_norm,           // norm_w (route_norm=True)
-                    hparams.expert_weights_scale,          // scale_w
                     hparams.expert_weights_scale,          // w_scale (route_scale=2.826)
                     (llama_expert_gating_func_type) hparams.expert_gating_func,
                     il);
diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp
index 9af19c1b..4d65614e 100644
--- a/examples/talk-llama/models/apertus.cpp
+++ b/examples/talk-llama/models/apertus.cpp
@@ -3,10 +3,10 @@
 
 
 llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp
index aa6167db..20b9ffd4 100644
--- a/examples/talk-llama/models/arcee.cpp
+++ b/examples/talk-llama/models/arcee.cpp
@@ -2,10 +2,10 @@
 
 
 llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp
index e8f028a7..b712e08c 100644
--- a/examples/talk-llama/models/arctic.cpp
+++ b/examples/talk-llama/models/arctic.cpp
@@ -1,11 +1,10 @@
 #include "models.h"
 
-
 llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -104,7 +103,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, true,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
         cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp
index c04b0c98..abd03cd0 100644
--- a/examples/talk-llama/models/baichuan.cpp
+++ b/examples/talk-llama/models/baichuan.cpp
@@ -2,10 +2,10 @@
 
 
 llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -56,6 +56,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap
                             );
                     break;
                 case LLM_TYPE_13B:
+                case LLM_TYPE_UNKNOWN:
                     break;
                 default:
                     GGML_ABORT("fatal error");
diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp
index ed56b9c4..25e3369c 100644
--- a/examples/talk-llama/models/bailingmoe.cpp
+++ b/examples/talk-llama/models/bailingmoe.cpp
@@ -1,6 +1,5 @@
 #include "models.h"
 
-
 llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -97,7 +96,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, hparams.expert_weights_norm,
-                    false, hparams.expert_weights_scale,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
         cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp
index fbf7b210..42098624 100644
--- a/examples/talk-llama/models/bailingmoe2.cpp
+++ b/examples/talk-llama/models/bailingmoe2.cpp
@@ -1,13 +1,11 @@
 #include "models.h"
 
-
-
 llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -90,7 +88,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll
                 model.layers[il].ffn_exp_probs_b,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, hparams.expert_weights_norm,
-                true, hparams.expert_weights_scale,
+                hparams.expert_weights_scale,
                 (llama_expert_gating_func_type) hparams.expert_gating_func,
                 il);
             cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp
index bca0e254..87331791 100644
--- a/examples/talk-llama/models/bert.cpp
+++ b/examples/talk-llama/models/bert.cpp
@@ -1,12 +1,10 @@
 #include "models.h"
 
-
-
 llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -129,9 +127,17 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params
         // feed-forward network
         if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) {
             // MoE branch
-            cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, nullptr,
-                                model.layers[il].ffn_down_exps, nullptr, hparams.n_expert, hparams.n_expert_used,
-                                LLM_FFN_GELU, false, false, 0.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    nullptr,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    hparams.n_expert, hparams.n_expert_used,
+                    LLM_FFN_GELU, false,
+                    hparams.expert_weights_scale,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    il);
             cb(cur, "ffn_moe_out", il);
         } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE ||
                    model.arch == LLM_ARCH_JINA_BERT_V3) {
diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp
index 331a3f11..ccf5bc8e 100644
--- a/examples/talk-llama/models/bitnet.cpp
+++ b/examples/talk-llama/models/bitnet.cpp
@@ -2,9 +2,9 @@
 
 
 llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -29,10 +29,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
         // self-attention
         {
             // compute Q and K and RoPE them
-            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
-            if (model.layers[il].wq_scale) {
-                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
-            }
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s);
             cb(Qcur, "Qcur", il);
             if (model.layers[il].bq) {
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -40,10 +37,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
             }
 
             // B1.K
-            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-            if (model.layers[il].wk_scale) {
-                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
-            }
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
             cb(Kcur, "Kcur", il);
             if (model.layers[il].bk) {
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -51,10 +45,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
             }
 
             // B1.V
-            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
-            if (model.layers[il].wv_scale) {
-                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
-            }
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
             cb(Vcur, "Vcur", il);
             if (model.layers[il].bv) {
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -90,10 +81,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
                     LLM_NORM_RMS, il);
             cb(cur, "attn_sub_norm", il);
 
-            cur = build_lora_mm(model.layers[il].wo, cur);
-            if (model.layers[il].wo_scale) {
-                cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
-            }
+            cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
             if (model.layers[il].bo) {
                 cur = ggml_add(ctx0, cur, model.layers[il].bo);
             }
@@ -115,8 +103,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
         cb(cur, "ffn_norm", il);
 
         cur = build_ffn(cur,
-                model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_scale,
-                model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
+                model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_s,
+                model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s,
                 NULL,                      NULL, NULL,
                 NULL,
                 LLM_FFN_SILU, LLM_FFN_PAR, il);
@@ -127,10 +115,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa
                 LLM_NORM_RMS, il);
         cb(cur, "ffn_sub_norm", il);
 
-        cur = build_lora_mm(model.layers[il].ffn_down, cur);
-        if (model.layers[il].ffn_down_scale) {
-            cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
-        }
+        cur = build_lora_mm(model.layers[il].ffn_down, cur, model.layers[il].ffn_down_s);
         cb(cur, "ffn_down", il);
 
         cur = ggml_add(ctx0, cur, ffn_inp);
diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp
index 2c552d1d..b1c19bb5 100644
--- a/examples/talk-llama/models/bloom.cpp
+++ b/examples/talk-llama/models/bloom.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp
index 184511ae..2f24105f 100644
--- a/examples/talk-llama/models/chameleon.cpp
+++ b/examples/talk-llama/models/chameleon.cpp
@@ -3,10 +3,10 @@
 #include 
 
 llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp
index 2685d4fb..5887ed22 100644
--- a/examples/talk-llama/models/chatglm.cpp
+++ b/examples/talk-llama/models/chatglm.cpp
@@ -2,10 +2,10 @@
 
 
 llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp
index 0b3bdbff..e8e13e14 100644
--- a/examples/talk-llama/models/codeshell.cpp
+++ b/examples/talk-llama/models/codeshell.cpp
@@ -1,11 +1,11 @@
 #include "models.h"
 
 llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp
index 0ceae3aa..2ef2b6e3 100644
--- a/examples/talk-llama/models/cogvlm.cpp
+++ b/examples/talk-llama/models/cogvlm.cpp
@@ -2,11 +2,11 @@
 
 llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const float   kq_scale    = 1.0f / sqrtf(float(n_embd_head));
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * inpL;
     ggml_tensor * cur;
diff --git a/examples/talk-llama/models/cohere2-iswa.cpp b/examples/talk-llama/models/cohere2-iswa.cpp
index 9334b5e4..7c71a59a 100644
--- a/examples/talk-llama/models/cohere2-iswa.cpp
+++ b/examples/talk-llama/models/cohere2-iswa.cpp
@@ -1,9 +1,9 @@
 #include "models.h"
 
 llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     const float f_logit_scale = hparams.f_logit_scale;
 
diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp
index 4d3b643b..ba1230f0 100644
--- a/examples/talk-llama/models/command-r.cpp
+++ b/examples/talk-llama/models/command-r.cpp
@@ -4,9 +4,9 @@
 
 llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     const float f_logit_scale = hparams.f_logit_scale;
 
diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp
index 6d2a0ebf..73eb5cd2 100644
--- a/examples/talk-llama/models/dbrx.cpp
+++ b/examples/talk-llama/models/dbrx.cpp
@@ -1,12 +1,11 @@
 #include "models.h"
 
-
 llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -89,7 +88,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, true,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
         cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp
index 7410a3a4..ac448bfc 100644
--- a/examples/talk-llama/models/deci.cpp
+++ b/examples/talk-llama/models/deci.cpp
@@ -3,10 +3,10 @@
 
 
 llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp
index 17866c0d..3432359e 100644
--- a/examples/talk-llama/models/deepseek.cpp
+++ b/examples/talk-llama/models/deepseek.cpp
@@ -1,13 +1,11 @@
 #include "models.h"
 
-
-
 llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -100,7 +98,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, false,
-                false, hparams.expert_weights_scale,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
             cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp
index ca63a62a..d437fe29 100644
--- a/examples/talk-llama/models/deepseek2.cpp
+++ b/examples/talk-llama/models/deepseek2.cpp
@@ -2,22 +2,19 @@
 
 llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
-    bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
-
-    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+    const bool is_mla = hparams.is_mla();
 
     // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
-    const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
-    const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
+    const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
 
-    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+    const int64_t n_embd_head_qk_rope = hparams.n_rot();
     const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
 
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation.
     // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
 
     // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
@@ -43,11 +40,13 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
-    auto * inp_attn = build_attn_inp_kv();
+    auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr;
+    auto * inp_attn_k  =  is_mla ? build_attn_inp_k()  : nullptr;
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    for (int il = 0; il < n_layer; ++il) {
+    int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers;
+    for (int il = 0; il < effective_n_layers; ++il) {
         ggml_tensor * inpSA = inpL;
 
         // norm
@@ -57,6 +56,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
         // self_attention
         {
             ggml_tensor * q = NULL;
+
+            const bool is_lite = model.layers[il].wq;
+
             if (!is_lite) {
                 q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
                 cb(q, "q", il);
@@ -124,14 +126,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
 
                 // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
                 // note: rope must go first for in-place context shifting in build_rope_shift()
-                ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
                 cb(Qcur, "Qcur", il);
 
                 kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
                 cb(kv_cmpr, "kv_cmpr_reshape", il);
 
                 // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
                 cb(Kcur, "Kcur", il);
 
                 // {kv_lora_rank, 1, n_tokens}
@@ -144,8 +146,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                     cb(Qcur, "Qcur_attn_temp_scaled", il);
                 }
 
-                // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
-                cur = build_attn(inp_attn,
+                // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
+                cur = build_attn(inp_attn_k,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
             } else {
@@ -169,11 +171,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 Vcur = ggml_cont(ctx0, Vcur);
                 cb(Vcur, "Vcur_cont", il);
 
-                // note: rope must go first for in-place context shifting in build_rope_shift()
-                ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(Qcur, "Qcur", il);
 
-                ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(Kcur, "Kcur", il);
 
                 if (inp_attn_scale) {
@@ -183,12 +184,12 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 }
 
                 // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
-                cur = build_attn(inp_attn,
+                cur = build_attn(inp_attn_kv,
                             model.layers[il].wo, NULL,
                             Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
             }
         }
-        if (il == n_layer - 1 && inp_out_ids) {
+        if (il == effective_n_layers - 1 && inp_out_ids) {
             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
@@ -215,9 +216,11 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 model.layers[il].ffn_exp_probs_b,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, hparams.expert_weights_norm,
-                hparams.expert_weights_scale, hparams.expert_weights_scale,
+                hparams.expert_weights_scale,
                 (llama_expert_gating_func_type) hparams.expert_gating_func,
-                il);
+                il,
+                nullptr,
+                model.layers[il].ffn_gate_up_exps);
             cb(moe_out, "ffn_moe_out", il);
 
             // FFN shared expert
diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp
new file mode 100644
index 00000000..6bc989c9
--- /dev/null
+++ b/examples/talk-llama/models/delta-net-base.cpp
@@ -0,0 +1,445 @@
+#include "models.h"
+
+#include "llama-impl.h"
+
+// utility to get one slice from the third dimension
+// input dim:  [x, y, c, b]
+// output dim: [x, y, 1, b]
+static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
+    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
+        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
+}
+
+llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b,
+        ggml_tensor * s,
+        int           il) {
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+    const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k);
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
+    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
+
+    const float scale = 1.0f / sqrtf(S_k);
+
+    q = ggml_scale(ctx0, q, scale);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(b, "b_in", il);
+    cb(g, "g_in", il);
+
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+    g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs]
+    b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [  1, n_tokens, H_v, n_seqs]
+
+    const int CS = kda ? 16 : 64; // chunk size
+
+    const int pad = (CS - n_tokens % CS) % CS;
+    const int n_chunks = (n_tokens + pad) / CS;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
+    b = ggml_pad(ctx0, b, 0, pad, 0, 0);
+
+    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
+    ggml_tensor * k_b = ggml_mul(ctx0, k, b);
+
+    cb(v_b, "v_b", il);
+    cb(k_b, "k_b", il);
+
+    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
+    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
+    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
+    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
+    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
+
+    g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1,        CS, n_chunks, H_v * n_seqs);
+
+    // [CS, g_0, n_chunks, H_v * n_seqs]
+    // TODO: extend ggml_cumsum with axis parameter to avoid transpose
+    ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g)));
+    cb(g_cs, "g_cs", il);
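+    // g_cs holds the cumulative (log-space) decay within each chunk; pairwise differences of its
+    // entries produce the decay factors used in the masks below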
+
+    ggml_tensor * kb = nullptr;
+    ggml_tensor * kq = nullptr;
+    if (kda) {
+        const int64_t CHB = n_chunks * H_k * n_seqs;
+
+        ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
+        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB]
+
+        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
+
+        // decay_mask [chunk_size,chunk_size,S_k,CHB]
+        ggml_tensor * decay_mask;
+        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        cb(decay_mask, "decay_mask", il);
+
+        // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
+        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB);
+
+        ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS,  1, CHB);
+        ggml_tensor * k_j   = ggml_reshape_4d(ctx0, k,   S_k,  1, CS, CHB);
+        ggml_tensor * q_i   = ggml_reshape_4d(ctx0, q,   S_k, CS,  1, CHB);
+
+        ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i);
+        ggml_tensor * decay_q_i   = ggml_mul(ctx0, decay_mask, q_i);
+
+        // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
+        kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j);
+        kq = ggml_mul_mat(ctx0, decay_q_i,   k_j);
+
+        kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs)));
+        kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs)));
+    } else {
+        ggml_tensor * g_cs_i = g_cs;
+        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
+
+        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
+
+        // [CS, CS, n_chunks, H_v * n_seqs]
+        ggml_tensor * decay_mask;
+        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        cb(decay_mask, "decay_mask", il);
+
+        // [CS, CS, n_chunks, H_k * n_seqs]
+        kb = ggml_mul_mat(ctx0, k,  k_b);
+        kb = ggml_mul    (ctx0, kb, decay_mask);
+
+        // [CS, CS, n_chunks, H_k * n_seqs]
+        kq = ggml_mul_mat(ctx0, k, q);
+        kq = ggml_mul(ctx0, kq, decay_mask);
+    }
+
+    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
+    cb(kq, "kq", il);
+
+    // [CS, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * attn;
+    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
+    cb(attn, "attn", il);
+
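+    // build a CS x CS identity matrix: a length-CS vector filled with 1.0, placed on the diagonal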
+    ggml_tensor * identity;
+    identity = ggml_view_1d(ctx0, attn, CS, 0);
+    identity = ggml_fill   (ctx0, identity, 1.0f);
+    identity = ggml_diag   (ctx0, identity);
+
+    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
+    cb(lhs, "dnet_add_ch_lhs", il);
+
+    attn = ggml_neg(ctx0, attn);
+    cb(attn, "attn_pre_solve", il);
+
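+    // invert the unit lower-triangular system: solving (I + A) X = -A and adding I back gives
+    // X + I = (I + A)^{-1}, assuming ggml_solve_tri solves lhs * X = rhs for X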
+    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+    attn = ggml_add(ctx0, lin_solve, identity);
+    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
+
+    // [S_v, CS, n_chunks, H_v * n_seqs]
+    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
+
+    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
+
+    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
+
+    // [CS, S_k, n_chunks, H_k * n_seqs]
+    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
+    cb(kbg, "k_beta_g_exp", il);
+
+    // [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
+    cb(k_cd, "k_cumdecay", il);
+
+    // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp));
+    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
+
+    // vectorized calculation of key_gdiff
+    // improved from the chunked version:
+    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
+    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
+    //   key_gdiff = key * g_diff.unsqueeze(-1)
+    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+
+    // get last element in g_cumsum along CS dimension (ne0)
+    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
+    // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3],
+            g_cs->nb[1],
+            g_cs->nb[2],
+            g_cs->nb[3],
+            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
+    cb(g_last, "g_last", il);
+
+    // TODO: remove this cont when CUDA supports non-cont unary ops
+    g_last = ggml_cont(ctx0, g_last);
+
+    // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last));
+    cb(g_last_exp_t, "g_last_exp_t", il);
+
+    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
+    cb(g_diff, "g_diff", il);
+
+    ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff)));
+
+    // [S_k, CS, n_chunks, H_v * n_seqs]
+    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
+    cb(kg, "key_gdiff", il);
+
+    // [CS, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
+    cb(kg_t, "key_gdiff_t", il);
+
+    s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs);
+    cb(s, "dnet_add_ch_state", il);
+
+    // [CS, S_v, n_chunks, H_v * n_seqs]
+    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+
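+    // sequential scan over the chunks: each chunk combines the contribution of the running state s
+    // with the intra-chunk attention, then decays s and updates it with the chunk's keys/values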
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]
+
+        // [CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s);
+        cb(v_t_p, "v_prime", il);
+
+        // [CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
+        cb(v_t_new, "v_t_new", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
+        cb(v_attn, "v_attn", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp);
+        cb(attn_inter, "attn_inter", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
+        cb(o_ch, "dnet_add_ch_attn_out", il);
+
+        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
+
+        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+        // TODO: head broadcast might not work here - probably will need a transpose
+        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
+
+        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+        ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
+
+        s = ggml_mul(ctx0, s, ch_g_last_exp_t);
+        s = ggml_add(ctx0, s, kgv);
+        cb(s, "dnet_add_ch_state", il);
+    }
+
+    // truncate padded tokens
+    ggml_tensor * o = ggml_view_4d(ctx0, v,
+            S_v, n_tokens, H_v, n_seqs,
+            ggml_row_size(v->type, S_v),
+            ggml_row_size(v->type, S_v * CS * n_chunks),
+            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
+    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+    s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs);
+    cb(s, "output_state", il);
+
+    return {o, s};
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b, // beta
+        ggml_tensor * s, // state
+        int           il) {
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1);
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
+    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
+
+    const float scale = 1.0f / sqrtf(S_k);
+
+    q = ggml_scale(ctx0, q, scale);
+
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(b, "b_in", il);
+    cb(g, "g_in", il);
+
+    // GDA: [1,  1,  H_v, n_seqs]
+    // KDA: [1, S_k, H_v, n_seqs]
+    g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1,        1, H_v, n_seqs);
+
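+    // single-token gated delta rule (per head): S <- S * exp(g); d = beta * (v - S^T k); S <- S + k d^T; o = S^T q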
+    // [S_v, S_v, H_v, n_seqs]
+    g = ggml_exp(ctx0, g);
+    s = ggml_mul(ctx0, s, g);
+
+    // [1, S_v, H_v, n_seqs]
+    ggml_tensor * sk;
+    sk = ggml_mul     (ctx0, s, k);
+    sk = ggml_sum_rows(ctx0, sk);
+
+    // [S_v, 1, H_v, n_seqs]
+    ggml_tensor * d;
+    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
+    d = ggml_mul(ctx0, d, b);
+
+    // [1, S_v, H_v, n_seqs]
+    ggml_tensor * d_t;
+    d_t = ggml_transpose(ctx0, d);
+
+    // [S_v, S_v, H_v, n_seqs]
+    ggml_tensor * kd;
+    k  = ggml_repeat(ctx0, k, s);
+    kd = ggml_mul   (ctx0, k, d_t);
+
+    s = ggml_add(ctx0, s, kd);
+
+    cb(s, "dnet_add_ar_state", il);
+
+    ggml_tensor * s_q = ggml_mul     (ctx0, s, q);
+    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);
+
+    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+
+    return {o, s};
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_fused(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b,
+        ggml_tensor * s,
+        int           il) {
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
+    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
+
+    ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
+    if (n_tokens == 1) {
+        cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
+    } else {
+        cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);
+    }
+
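+    // the fused op packs the attention output first (S_v * H_v * n_tokens * n_seqs values),
+    // followed by the updated recurrent state; split them with two views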
+    ggml_tensor * output = ggml_view_4d(ctx0, result,
+            S_v, H_v, n_tokens, n_seqs,
+            ggml_row_size(result->type, S_v),
+            ggml_row_size(result->type, S_v * H_v),
+            ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
+
+    ggml_tensor * new_state = ggml_view_4d(ctx0, result,
+            S_v, S_v, H_v, n_seqs,
+            ggml_row_size(result->type, S_v),
+            ggml_row_size(result->type, S_v * S_v),
+            ggml_row_size(result->type, S_v * S_v * H_v),
+            ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
+
+    return {output, new_state};
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b,
+        ggml_tensor * s,
+        int           il) {
+    const int64_t n_seq_tokens = q->ne[2];
+
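+    // decode (single token) uses the recurrent path, prompt processing uses the chunked path;
+    // either one is replaced by the fused ggml_gated_delta_net kernel when the corresponding flag is set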
+    if (n_seq_tokens == 1) {
+        if (cparams.fused_gdn_ar) {
+            return build_delta_net_fused(q, k, v, g, b, s, il);
+        }
+        return build_delta_net_autoregressive(q, k, v, g, b, s, il);
+    }
+
+    if (cparams.fused_gdn_ch) {
+        return build_delta_net_fused(q, k, v, g, b, s, il);
+    }
+
+    return build_delta_net_chunking(q, k, v, g, b, s, il);
+}
diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp
index 09c36f82..07236dd2 100644
--- a/examples/talk-llama/models/dots1.cpp
+++ b/examples/talk-llama/models/dots1.cpp
@@ -1,13 +1,11 @@
 #include "models.h"
 
-
-
 llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -91,7 +89,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para
                 model.layers[il].ffn_exp_probs_b,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, hparams.expert_weights_norm,
-                true, hparams.expert_weights_scale,
+                hparams.expert_weights_scale,
                 (llama_expert_gating_func_type) hparams.expert_gating_func,
                 il);
             cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp
index 2aafbae1..4edc8530 100644
--- a/examples/talk-llama/models/dream.cpp
+++ b/examples/talk-llama/models/dream.cpp
@@ -5,10 +5,10 @@
 llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
     //copied from qwen2
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp
index 0d96d14e..63baf152 100644
--- a/examples/talk-llama/models/ernie4-5-moe.cpp
+++ b/examples/talk-llama/models/ernie4-5-moe.cpp
@@ -1,13 +1,11 @@
 #include "models.h"
 
-
-
 llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -103,7 +101,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const
                                         model.layers[il].ffn_exp_probs_b,
                                         n_expert, n_expert_used,
                                         LLM_FFN_SILU, true,
-                                        false, 0.0,
+                                        hparams.expert_weights_scale,
                                         LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                                         il);
             cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp
index 99aead53..d548de05 100644
--- a/examples/talk-llama/models/ernie4-5.cpp
+++ b/examples/talk-llama/models/ernie4-5.cpp
@@ -2,10 +2,10 @@
 
 llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp
new file mode 100644
index 00000000..e8628d16
--- /dev/null
+++ b/examples/talk-llama/models/eurobert.cpp
@@ -0,0 +1,97 @@
+#include "models.h"
+
+llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "inp_embd", -1);
+
+    auto * inp_attn = build_attn_inp_no_cache();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * cur = inpL;
+
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+
+        {
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            Qcur = build_lora_mm(model.layers[il].wq, cur);
+            Kcur = build_lora_mm(model.layers[il].wk, cur);
+            Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+        }
+
+        cur = ggml_add(ctx0, cur, inpL);
+
+        ggml_tensor * ffn_inp = cur;
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up, NULL, NULL,
+                model.layers[il].ffn_gate, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_embd", -1);
+    res->t_embd = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp
new file mode 100644
index 00000000..ea75701c
--- /dev/null
+++ b/examples/talk-llama/models/exaone-moe.cpp
@@ -0,0 +1,145 @@
+#include "models.h"
+
+llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_k();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
+    GGML_ASSERT(n_embd_head == n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn_iswa = build_attn_inp_kv_iswa();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+    for (int il = 0; il < n_transformer_layers; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // use RoPE for SWA layers
+        const bool is_local_layer = hparams.is_swa(il);
+
+        // norm
+        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
+            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+            Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+            cb(Qcur, "Qcur_normed", il);
+            cb(Kcur, "Kcur_normed", il);
+
+            if (is_local_layer) {
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
+                                     freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base,
+                                     freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+            }
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn_iswa,
+                model.layers[il].wo, NULL,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+            cb(cur, "attn_out", il);
+        }
+        if (il == n_transformer_layers - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // norm
+        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward network
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense branch
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL, NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE branch
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
+                model.layers[il].ffn_exp_probs_b,
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, hparams.expert_weights_norm,
+                hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // FFN shared expert
+            {
+                ggml_tensor * ffn_shexp =
+                    build_ffn(cur,
+                        model.layers[il].ffn_up_shexp, NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // final norm
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp
index 62602b28..d4eea58e 100644
--- a/examples/talk-llama/models/exaone.cpp
+++ b/examples/talk-llama/models/exaone.cpp
@@ -4,10 +4,10 @@
 
 llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp
index 8b7e3dc0..755af3b7 100644
--- a/examples/talk-llama/models/exaone4.cpp
+++ b/examples/talk-llama/models/exaone4.cpp
@@ -4,10 +4,10 @@
 template <bool iswa>
 llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_k;
+    const int64_t n_embd_head = hparams.n_embd_head_k();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_v());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp
index b641a094..ff842d93 100644
--- a/examples/talk-llama/models/falcon-h1.cpp
+++ b/examples/talk-llama/models/falcon-h1.cpp
@@ -1,10 +1,8 @@
 #include "models.h"
 
-
-
 llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    llm_build_mamba_base(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp
index db1ccdb5..9fcba508 100644
--- a/examples/talk-llama/models/falcon.cpp
+++ b/examples/talk-llama/models/falcon.cpp
@@ -2,11 +2,11 @@
 
 
 llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp
index 944c198b..98110d45 100644
--- a/examples/talk-llama/models/gemma-embedding.cpp
+++ b/examples/talk-llama/models/gemma-embedding.cpp
@@ -2,7 +2,7 @@
 
 llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_k;
+    const int64_t n_embd_head = hparams.n_embd_head_k();
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp
index 4893d9af..1869efd3 100644
--- a/examples/talk-llama/models/gemma.cpp
+++ b/examples/talk-llama/models/gemma.cpp
@@ -2,7 +2,7 @@
 
 
 llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/gemma2-iswa.cpp b/examples/talk-llama/models/gemma2-iswa.cpp
index 7a919819..3927ddd2 100644
--- a/examples/talk-llama/models/gemma2-iswa.cpp
+++ b/examples/talk-llama/models/gemma2-iswa.cpp
@@ -1,7 +1,7 @@
 #include "models.h"
 
 llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_k;
+    const int64_t n_embd_head = hparams.n_embd_head_k();
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp
index dec3fc4b..bbb4d9a8 100644
--- a/examples/talk-llama/models/gemma3.cpp
+++ b/examples/talk-llama/models/gemma3.cpp
@@ -2,7 +2,7 @@
 
 template <bool iswa>
 llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_k;
+    const int64_t n_embd_head = hparams.n_embd_head_k();
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n-iswa.cpp
index 93defbee..8ce2ae39 100644
--- a/examples/talk-llama/models/gemma3n-iswa.cpp
+++ b/examples/talk-llama/models/gemma3n-iswa.cpp
@@ -3,7 +3,7 @@
 llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model),
-    n_embd_head(model.hparams.n_embd_head_k),
+    n_embd_head(model.hparams.n_embd_head_k()),
     n_embd_altup(model.hparams.n_embd_altup),
     n_altup(model.hparams.n_altup),
     i_altup_act(model.hparams.i_altup_act) {
@@ -245,12 +245,12 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
 // equivalent to get_per_layer_inputs() in python code
 // output shape: [n_embd_altup, n_layer, n_tokens]
 ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
-    auto inp = std::make_unique<llm_graph_input_embd>();
+    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
     ggml_tensor * inp_per_layer;
     if (ubatch.token) {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
         ggml_set_input(inp->tokens);
-        res->t_tokens = inp->tokens;
+        res->t_inp_tokens = inp->tokens;
         inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
         inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
         inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup));
@@ -258,12 +258,12 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
         res->add_input(std::move(inp));
     } else {
         // Vision embedding path: use padding token (ID=0) embedding
+        // TODO: verify if this is the correct behavior in transformers implementation
         const int64_t embd_size = model.tok_embd_per_layer->ne[0];  // n_embd_altup * n_layer
 
-        // Extract and dequantize padding token embedding (column 0)
-        ggml_tensor * padding_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
-        ggml_tensor * padding_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size);
-        inp_per_layer = ggml_cpy(ctx0, padding_q, padding_f32);
+        // Extract and dequantize padding token embedding (row 0)
+        ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0);
+        inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32);
 
         // Reshape to [n_embd_altup, n_layer, 1]
         inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1);
diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp
index 003f70f7..7938545e 100644
--- a/examples/talk-llama/models/glm4-moe.cpp
+++ b/examples/talk-llama/models/glm4-moe.cpp
@@ -1,9 +1,9 @@
 #include "models.h"
 
 llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     int sections[4];
     std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -128,7 +128,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap
                     model.layers[il].ffn_exp_probs_b,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, hparams.expert_weights_norm,
-                    true, hparams.expert_weights_scale,
+                    hparams.expert_weights_scale,
                     (llama_expert_gating_func_type) hparams.expert_gating_func,
                     il);
             cb(routed_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp
index 204aa393..b6ad8feb 100644
--- a/examples/talk-llama/models/glm4.cpp
+++ b/examples/talk-llama/models/glm4.cpp
@@ -3,10 +3,10 @@
 
 
 llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     int sections[4];
     std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -29,7 +29,10 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    for (int il = 0; il < n_layer; ++il) {
+    // Only process the transformer layers and skip the trailing NextN prediction layers;
+    // their tensors are loaded but not used in the forward pass
+    const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
+    for (int il = 0; il < n_transformer_layers; ++il) {
         ggml_tensor * inpSA = inpL;
 
         // Pre-attention norm
@@ -100,7 +103,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
                     model.layers[il].wo, NULL,
                     Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
         }
-        if (il == n_layer - 1 && inp_out_ids) {
+        if (il == n_transformer_layers - 1 && inp_out_ids) {
             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
@@ -130,9 +133,13 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params
             cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il);
             cb(cur, "post_mlp_norm", il);
         }
-        // Add residual connection after post-MLP norm
-        inpL = ggml_add(ctx0, cur, ffn_inp);
-        cb(inpL, "l_out", il);
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
     }
     // Final norm
     cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp
index 60761c8e..cb1238f2 100644
--- a/examples/talk-llama/models/gpt2.cpp
+++ b/examples/talk-llama/models/gpt2.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * pos;
diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp
index 2151b14e..1c8fe6c8 100644
--- a/examples/talk-llama/models/gptneox.cpp
+++ b/examples/talk-llama/models/gptneox.cpp
@@ -2,10 +2,10 @@
 
 
 llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp
index f6ca4c17..9b54a38c 100644
--- a/examples/talk-llama/models/granite-hybrid.cpp
+++ b/examples/talk-llama/models/granite-hybrid.cpp
@@ -1,10 +1,9 @@
 #include "models.h"
 
-
 llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    llm_build_mamba_base(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -160,7 +159,7 @@ ggml_tensor * llm_build_granite_hybrid::build_layer_ffn(ggml_tensor *       cur,
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, true,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
         cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp
index 18748e9c..7a7e1664 100644
--- a/examples/talk-llama/models/granite.cpp
+++ b/examples/talk-llama/models/granite.cpp
@@ -1,15 +1,14 @@
 #include "models.h"
 
-
 llm_build_granite::llm_build_granite(
     const llama_model & model,
     const llm_graph_params & params)
     : llm_graph_context(params) {
 
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -175,7 +174,7 @@ ggml_tensor * llm_build_granite::build_layer_ffn(
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, true,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
         cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp
index 3c54dfee..580d63e3 100644
--- a/examples/talk-llama/models/grok.cpp
+++ b/examples/talk-llama/models/grok.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -99,7 +99,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_GELU, true,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
         cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp
index 56b6db9a..aa60d3e9 100644
--- a/examples/talk-llama/models/grovemoe.cpp
+++ b/examples/talk-llama/models/grovemoe.cpp
@@ -1,14 +1,12 @@
 #include "models.h"
 
-
-
 llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t n_embd_head    = hparams.n_embd_head_v;
+    const int64_t n_embd_head    = hparams.n_embd_head_v();
     const int64_t n_chunk_expert = n_expert / hparams.n_group_experts;
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -90,7 +88,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, true,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il,
                 probs);
@@ -106,7 +104,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap
                     nullptr,
                     n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used,
                     LLM_FFN_SILU, true,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il,
                     probs);
diff --git a/examples/talk-llama/models/hunyuan-dense.cpp b/examples/talk-llama/models/hunyuan-dense.cpp
index 7d5dcc78..6a51707c 100644
--- a/examples/talk-llama/models/hunyuan-dense.cpp
+++ b/examples/talk-llama/models/hunyuan-dense.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp
index 77e39de5..806c30b3 100644
--- a/examples/talk-llama/models/hunyuan-moe.cpp
+++ b/examples/talk-llama/models/hunyuan-moe.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -119,8 +119,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll
                 n_expert, n_expert_used,
                 LLM_FFN_SILU,
                 true, // norm_topk_prob
-                false,
-                0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
         cb(cur_moe, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp
index 387e8211..441d2502 100644
--- a/examples/talk-llama/models/internlm2.cpp
+++ b/examples/talk-llama/models/internlm2.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp
index 3e3376e6..135bf288 100644
--- a/examples/talk-llama/models/jais.cpp
+++ b/examples/talk-llama/models/jais.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp
new file mode 100644
index 00000000..2cfe484e
--- /dev/null
+++ b/examples/talk-llama/models/jais2.cpp
@@ -0,0 +1,123 @@
+#include "models.h"
+
+// JAIS-2 model graph builder
+// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings
+llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    // KV input for attention
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        // Pre-attention LayerNorm
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm,
+                model.layers[il].attn_norm_b,
+                LLM_NORM, il);
+        cb(cur, "attn_norm", il);
+
+        // Self-attention with separate Q, K, V projections
+        {
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            cb(Qcur, "Qcur_bias", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+            cb(Kcur, "Kcur_bias", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+            cb(Vcur, "Vcur_bias", il);
+
+            // Reshape for attention
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // Apply RoPE
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+            cb(Qcur, "Qcur_rope", il);
+            cb(Kcur, "Kcur_rope", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+        }
+
+        // Residual connection
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // Pre-FFN LayerNorm
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm,
+                model.layers[il].ffn_norm_b,
+                LLM_NORM, il);
+        cb(cur, "ffn_norm", il);
+
+        // FFN with relu2 activation (ReLU squared) - no gate projection
+        // up -> relu2 -> down
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                NULL, NULL, NULL,  // no gate
+                model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                NULL,
+                LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+        cb(cur, "ffn_out", il);
+
+        // Residual connection
+        inpL = ggml_add(ctx0, cur, ffn_inp);
+        inpL = build_cvec(inpL, il);
+        cb(inpL, "l_out", il);
+    }
+
+    // Final LayerNorm
+    cur = build_norm(inpL,
+            model.output_norm,
+            model.output_norm_b,
+            LLM_NORM, -1);
+    cb(cur, "result_norm", -1);
+
+    res->t_embd = cur;
+
+    // Output projection
+    cur = build_lora_mm(model.output, cur);
+    cb(cur, "result_output", -1);
+
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp
index a0187772..c0c89de1 100644
--- a/examples/talk-llama/models/jamba.cpp
+++ b/examples/talk-llama/models/jamba.cpp
@@ -1,7 +1,7 @@
 #include "models.h"
 
-llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -76,7 +76,7 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, false,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
             cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp
new file mode 100644
index 00000000..4d62f4e7
--- /dev/null
+++ b/examples/talk-llama/models/kimi-linear.cpp
@@ -0,0 +1,381 @@
+#include "models.h"
+
+#include "llama-memory-recurrent.h"
+
+// Causal Conv1d function for Q,K,V
+// When qkv is 0, it is Q, 1 is K, 2 is V
+static ggml_tensor * causal_conv1d(
+        ggml_cgraph * gf, ggml_context * ctx0,
+        ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv,
+        ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w,
+        int64_t d_conv, int64_t head_dim, int64_t n_head,
+        int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
+    const int64_t d_inner = head_dim * n_head;
+    const int64_t conv_state_size = (d_conv - 1) * d_inner;
+    const int64_t n_embd_r_total = 3 * conv_state_size;  // Q + K + V
+
+    // conv_state_all is [n_embd_r_total, n_seqs]: per sequence the Q state comes first
+    // (conv_state_size elements), followed by K and then V
+    // memory layout: state[i + seq * n_embd_r_total] with
+    //   i = conv_step + channel * (d_conv-1) + qkv * conv_state_size
+    // view the selected Q/K/V state as [d_conv-1, d_inner, n_seqs]:
+    //   nb1 = (d_conv-1) * element_size (stride between channels)
+    //   nb2 = n_embd_r_total * element_size (stride between seqs)
+    //   offset = qkv * conv_state_size * element_size
+    ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
+        (d_conv - 1) * ggml_element_size(conv_state_all),  // nb1: stride between channels
+        n_embd_r_total * ggml_element_size(conv_state_all),  // nb2: stride between seqs
+        qkv * conv_state_size * ggml_element_size(conv_state_all));
+
+    // Step 1: Q, K, V projections -> [d_inner, n_tokens]
+    ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x);
+
+    // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs);
+
+    // Concat the selected conv state and the current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs}
+    ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0);
+
+    // Save the last (d_conv-1) columns back to the corresponding Q/K/V conv state slot
+    ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
+        conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]);
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, last_conv_x,
+            ggml_view_3d(ctx0, conv_states_all,
+                d_conv - 1, d_inner, n_seqs,
+                (d_conv - 1) * ggml_element_size(conv_states_all),           // nb1: contiguous within one channel's conv taps
+                n_embd_r_total * ggml_element_size(conv_states_all),         // nb2: stride between sequences (skip over K,V states)
+                (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all))));  // offset to first seq's Q/K/V state
+    // Reshape conv weight from the GGUF layout [d_conv, 1, d_inner, 1] to the [d_conv, d_inner]
+    // shape ggml_ssm_conv expects; ggml_ssm_conv indexes c[conv_step + channel * d_conv], which
+    // matches the GGUF memory layout (vLLM stores the same data as [d_inner, d_conv])
+    ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner);
+
+    // Apply conv1d
+    // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight);
+    // Reshape to 2D for bias add: {d_inner, n_tokens}
+    Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens);
+    Xcur = ggml_silu(ctx0, Xcur);
+
+    return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs);
+}
+
+llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
+    llm_build_delta_net_base(params), model(model) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "model.embed_tokens", -1);
+
+    // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+    // So we don't need inp_pos
+
+    auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
+    auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
+    auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
+    auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
+    auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;
+
+    // Output ids for selecting which tokens to output
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    // Kimi dimension constants
+    const int64_t n_head = hparams.n_head();
+    const int64_t head_dim = hparams.n_embd_head_kda;
+    const int64_t d_conv = hparams.ssm_d_conv;
+    const int64_t d_inner = n_head * head_dim;  // 32 * 128 = 4096
+    const int64_t n_seqs = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    // Verify batch consistency for recurrent layers
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // MLA params
+    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
+    const int64_t kv_lora_rank = hparams.n_lora_kv;
+    // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
+    // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
+    const int64_t n_embd_head_qk_rope = hparams.n_rot();  // config.qk_rope_head_dim
+    const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;  // 192 - 64 = 128
+    // Attention scale for MLA
+    const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
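+    // e.g. with n_embd_head_k_mla = 192 (per the 192 - 64 = 128 split above), kq_scale_mla = 1/sqrt(192) ~= 0.0722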
+
+    for (int il = 0; il < n_layer; ++il) {
+        const auto & layer = model.layers[il];
+        ggml_tensor * inpSA = inpL;
+
+        // Attention Norm
+        cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        ggml_build_forward_expand(gf, cur);
+
+        if (hparams.is_recurrent(il)) {
+            // === KDA Layer (Kimi Delta Attention) with Recurrent State ===
+            // Reference: vLLM kda.py
+            const auto * mctx_cur = inp_rs->mctx;
+            const auto kv_head = mctx_cur->get_head();
+
+            // Get conv states from r_l tensor (Q, K, V each have separate state)
+            ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+            cb(conv_states_all, "conv_states_all", il);
+            ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
+            ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+
+            // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
+            ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
+            ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
+            cb(g1, "g1 f_b(f_a(cur))", il);
+            g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
+            g1 = ggml_softplus(ctx0, g1);
+            g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
+
+            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py
+            // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
+            ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1);
+            g1 = ggml_mul(ctx0, g1, A);
+            cb(g1, "kda_g1", il);
+
+            g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
+
+            // Compute beta (mixing coefficient)
+            ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
+            beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs);
+            cb(beta, "kda_beta", il);
+
+            beta = ggml_sigmoid(ctx0, beta);
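+            // beta in (0, 1) is the per-head write strength of the delta rule: values near 0 keep
+            // the previous state, values near 1 overwrite the slot addressed by the current key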
+
+            // Reshape for KDA recurrence
+            // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+            // Get SSM state and compute KDA recurrence using ggml_kda_scan
+            ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+            ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
+            state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
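+            // per-sequence SSM state is a head_dim x head_dim matrix per head,
+            // i.e. hparams.n_embd_s() == head_dim * head_dim * n_head elements (matching the reshape above)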
+
+            const float eps_norm = hparams.f_norm_rms_eps;
+
+            Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm);
+            Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);
+
+            // Choose between build_delta_net_chunking and build_delta_net_autoregressive based on n_tokens
+            auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il);
+
+            ggml_tensor * output = ggml_cont(ctx0, attn_out.first);
+            ggml_tensor * new_state = attn_out.second;
+            cb(output, "attn_output", il);
+            cb(new_state, "new_state", il);
+
+            // Update the recurrent states
+            ggml_build_forward_expand(gf,
+                                     ggml_cpy(ctx0, new_state,
+                                              ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                                           kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+            // Output gating g2 = g_b(g_a(x))
+            ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+            ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
+            ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a);
+            cb(g2, "g2 g_b(g_a(cur_2d))", il);
+            g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs);
+
+            // Apply o_norm with sigmoid gating
+            // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
+            // Formula: output = RMSNorm(x) * sigmoid(g)
+            ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head,  n_seq_tokens * n_seqs);
+            ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il);
+            cb(normed, "kda_normed", il);
+            ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
+            ggml_tensor * gated = ggml_mul(ctx0, normed, gate);
+
+            // Output projection
+            gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens);
+            cur = ggml_mul_mat(ctx0, layer.wo, gated);
+            cb(cur, "kda_out", il);
+
+        } else {
+            // === MLA Layer (Multi-head Latent Attention) ===
+            // Reference: vLLM mla.py
+            // Step 1: Q projection and reshape
+            // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
+            // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
+
+            // Step 2: KV compression
+            // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
+            ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
+
+            // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
+            ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+            ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
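+            // per the wkv_a_mqa shape noted above (576 = kv_lora_rank + 64), this splits each column
+            // into the first kv_lora_rank (= 512 here) rows (kv_cmpr) and the remaining 64 rows (k_pe)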
+            // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM)
+            // k_pe is used directly without RoPE
+            // Normalize kv_c
+            kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
+
+            if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
+                // extract q_nope
+                ggml_tensor * q_nope =
+                    ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                                 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_embd_head_qk_rope, n_head, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(
+                    ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                    ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd_head_qk_nope, n_tokens, n_head}
+                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope, "q_nope_perm", il);
+
+                // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
+                ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
+                cb(q_nope_absorbed, "q_nope_absorbed", il);
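+                // i.e. the K up-projection wk_b is absorbed into Q, so the attention below runs
+                // directly in the compressed kv_lora_rank space without re-expanding cached K per token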
+
+                // {kv_lora_rank, n_head, n_tokens}
+                q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+                cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+                // note: rope must go first for in-place context shifting in build_rope_shift()
+                Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
+                cb(Qcur, "Qcur", il);
+
+                kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
+                cb(kv_cmpr, "kv_cmpr_reshape", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
+                cb(Kcur, "Kcur", il);
+
+                // {kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Vcur = kv_cmpr;
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            } else { // MLA KV cache disabled. Fall back to MHA KV cache.
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
+                cb(Qcur, "mla_Q", il);
+                // KV decompression: kv = kv_b_proj(kv_c_normed)
+                ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
+                const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
+
+                // Split kv into k_nope and v
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head), 0);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope));
+                Vcur = ggml_cont(ctx0, Vcur);
+                cb(Vcur, "mla_V", il);
+
+                // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
+                // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
+                // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
+                // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
+                ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
+                ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, k_pe_repeated, 0);
+                cb(Kcur, "mla_K", il);
+
+                // Direct softmax attention (with MHA KV cache)
+                // Use build_attn with inp_attn for proper mask handling
+                cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            }
+        }
+
+        // On last layer, select only the output tokens
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // FFN Norm
+        cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        if ((uint32_t) il < hparams.n_layer_dense_lead) {
+            // Dense FFN layer
+            cur = build_ffn(cur,
+                layer.ffn_up, NULL, NULL,
+                layer.ffn_gate, NULL, NULL,
+                layer.ffn_down, NULL, NULL,
+                NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE layer
+            // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446
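+            // sketch: the selected expert weights are renormalized to sum to 1 (norm flag below is true)
+            // and the routed output is scaled by expert_weights_scale (~2.446) before adding the shared expert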
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                layer.ffn_gate_inp,
+                layer.ffn_up_exps,
+                layer.ffn_gate_exps,
+                layer.ffn_down_exps,
+                layer.ffn_exp_probs_b,
+                hparams.n_expert,
+                hparams.n_expert_used,
+                LLM_FFN_SILU, true,
+                hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // Shared expert
+            {
+                ggml_tensor * ffn_shexp = build_ffn(cur,
+                        layer.ffn_up_shexp, NULL, NULL,
+                        layer.ffn_gate_shexp, NULL, NULL,
+                        layer.ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+        // Residual
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final Norm
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // Output
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp
index 7f805d78..dfa32216 100644
--- a/examples/talk-llama/models/lfm2.cpp
+++ b/examples/talk-llama/models/lfm2.cpp
@@ -1,18 +1,155 @@
 #include "models.h"
 
+#include "../llama-memory-hybrid-iswa.h"
 #include "../llama-memory-hybrid.h"
 
+template <bool iswa>
+llm_build_lfm2<iswa>::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+    using inp_hybrid_type = std::conditional_t<iswa, llm_graph_input_mem_hybrid_iswa, llm_graph_input_mem_hybrid>;
+    using inp_attn_type   = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    using mem_hybrid_ctx  = std::conditional_t<iswa, llama_memory_hybrid_iswa_context, llama_memory_hybrid_context>;
 
-llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context(params),
-    model(model) {
+    // lambda helpers for readability
+    auto build_dense_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * {
+        GGML_ASSERT(!model.layers[il].ffn_up_b);
+        GGML_ASSERT(!model.layers[il].ffn_gate_b);
+        GGML_ASSERT(!model.layers[il].ffn_down_b);
+        return build_ffn(cur,
+            model.layers[il].ffn_up, NULL, NULL,
+            model.layers[il].ffn_gate, NULL, NULL,
+            model.layers[il].ffn_down, NULL, NULL,
+            NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+    };
+    auto build_moe_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * {
+        return build_moe_ffn(cur,
+                model.layers[il].ffn_gate_inp,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
+                model.layers[il].ffn_exp_probs_b,
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, true,
+                hparams.expert_weights_scale,
+                static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
+                il);
+    };
+    auto build_attn_block = [&model, this](ggml_tensor *   cur,
+                                           ggml_tensor *   inp_pos,
+                                           inp_attn_type * inp_attn,
+                                           int             il) -> ggml_tensor * {
+        GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
+        const auto n_embd_head = hparams.n_embd_head_v();
+        const auto n_head_kv   = hparams.n_head_kv(il);
+
+        auto * q = build_lora_mm(model.layers[il].wq, cur);
+        cb(q, "model.layers.{}.self_attn.q_proj", il);
+        auto * k = build_lora_mm(model.layers[il].wk, cur);
+        cb(k, "model.layers.{}.self_attn.k_proj", il);
+        auto * v = build_lora_mm(model.layers[il].wv, cur);
+        cb(v, "model.layers.{}.self_attn.v_proj", il);
+
+        q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
+        k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
+        v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);
+
+        // qk norm
+        q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+        cb(q, "model.layers.{}.self_attn.q_layernorm", il);
+        k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+        cb(k, "model.layers.{}.self_attn.k_layernorm", il);
+
+        // RoPE
+        q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                          attn_factor, beta_fast, beta_slow);
+        k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                          attn_factor, beta_fast, beta_slow);
+
+        cur = build_attn(inp_attn,
+                model.layers[il].wo, NULL,
+                q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+
+        cb(cur, "model.layers.{}.self_attn.out_proj", il);
+
+        return cur;
+    };
+    auto build_shortconv_block = [&model, this](ggml_tensor *        cur,
+                                                llm_graph_input_rs * inp_recr,
+                                                int                  il) -> ggml_tensor * {
+        const auto * mctx_cur = static_cast<const mem_hybrid_ctx *>(mctx)->get_recr();
+        const uint32_t kv_head      = mctx_cur->get_head();
+        const int64_t  n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t  n_seqs       = ubatch.n_seqs;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs());
+        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
+        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
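+        // e.g. if n_shortconv_l_cache = 4 (illustrative), the cache keeps the 3 most recent
+        // columns per channel, which is exactly the context ggml_ssm_conv needs for the next token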
+
+        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+        auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
+        cb(bcx, "model.layers.{}.conv.in_proj", il);
+
+        constexpr auto n_chunks = 3;
+        GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
+        const auto chunk_size = bcx->ne[0] / n_chunks;
+        auto *     b          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                                             0 * chunk_size * ggml_element_size(bcx));
+        auto *     c          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                                             1 * chunk_size * ggml_element_size(bcx));
+        auto *     x          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
+                                             2 * chunk_size * ggml_element_size(bcx));
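+        // in_proj packs B (input gate), C (output gate) and x along dim 0, so chunk_size is
+        // expected to equal n_embd, matching the conv-state reshape below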
+
+        auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
+
+        // read conv state
+        auto * conv_state = mctx_cur->get_r_l(il);
+        auto * conv_rs    = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
+        auto * conv       = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
+
+        bx = ggml_concat(ctx0, conv, bx, 0);
+        GGML_ASSERT(bx->ne[0] > conv->ne[0]);
+
+        // last d_conv columns is a new conv state
+        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2],
+                                       (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
+        GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
+
+        // write new conv conv state
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv,
+                                               ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv),
+                                                            kv_head * d_conv * n_embd * ggml_element_size(new_conv))));
+
+        auto * conv_kernel = model.layers[il].shortconv.conv;
+        auto * conv_out    = ggml_ssm_conv(ctx0, bx, conv_kernel);
+        cb(conv_out, "model.layers.{}.conv.conv", il);
+
+        auto * y = ggml_mul(ctx0, c, conv_out);
+        y        = build_lora_mm(model.layers[il].shortconv.out_proj, y);
+        cb(y, "model.layers.{}.conv.out_proj", il);
+        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
+
+        return y;
+    };
+
+    // actual graph construction starts here
     ggml_tensor * cur = build_inp_embd(model.tok_embd);
     cb(cur, "model.embed_tokens", -1);
 
     ggml_build_forward_expand(gf, cur);
 
+    inp_hybrid_type * inp_hybrid = nullptr;
+    if constexpr (iswa) {
+        inp_hybrid = build_inp_mem_hybrid_iswa();
+    } else {
+        inp_hybrid = build_inp_mem_hybrid();
+    }
+
     ggml_tensor * inp_pos     = build_inp_pos();
-    auto *        inp_hybrid  = build_inp_mem_hybrid();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
     for (int il = 0; il < n_layer; ++il) {
@@ -54,122 +191,6 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
     ggml_build_forward_expand(gf, cur);
 }
 
-ggml_tensor * llm_build_lfm2::build_moe_feed_forward(ggml_tensor * cur, int il) const {
-    return build_moe_ffn(cur,
-                        model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
-                        model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0,
-                        static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
-}
-
-ggml_tensor * llm_build_lfm2::build_dense_feed_forward(ggml_tensor * cur, int il) const {
-    GGML_ASSERT(!model.layers[il].ffn_up_b);
-    GGML_ASSERT(!model.layers[il].ffn_gate_b);
-    GGML_ASSERT(!model.layers[il].ffn_down_b);
-    return build_ffn(cur,
-        model.layers[il].ffn_up, NULL, NULL,
-        model.layers[il].ffn_gate, NULL, NULL,
-        model.layers[il].ffn_down, NULL, NULL,
-        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
-}
-
-ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor *             cur,
-                                               ggml_tensor *             inp_pos,
-                                               llm_graph_input_attn_kv * inp_attn,
-                                               int                       il) const {
-    GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
-    const auto n_embd_head = hparams.n_embd_head_v;
-    const auto n_head_kv   = hparams.n_head_kv(il);
-
-    auto * q = build_lora_mm(model.layers[il].wq, cur);
-    cb(q, "model.layers.{}.self_attn.q_proj", il);
-    auto * k = build_lora_mm(model.layers[il].wk, cur);
-    cb(k, "model.layers.{}.self_attn.k_proj", il);
-    auto * v = build_lora_mm(model.layers[il].wv, cur);
-    cb(v, "model.layers.{}.self_attn.v_proj", il);
-
-    q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
-    k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
-    v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);
-
-    // qk norm
-    q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
-    cb(q, "model.layers.{}.self_attn.q_layernorm", il);
-    k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
-    cb(k, "model.layers.{}.self_attn.k_layernorm", il);
-
-    // RoPE
-    q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                      attn_factor, beta_fast, beta_slow);
-    k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
-                      attn_factor, beta_fast, beta_slow);
-
-    cur = build_attn(inp_attn,
-            model.layers[il].wo, NULL,
-            q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
-
-    cb(cur, "model.layers.{}.self_attn.out_proj", il);
-
-    return cur;
-}
-
-ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) {
-    const auto *   mctx_cur     = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-    const uint32_t kv_head      = mctx_cur->get_head();
-    const int64_t  n_seq_tokens = ubatch.n_seq_tokens;
-    const int64_t  n_seqs       = ubatch.n_seqs;
-    GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(ubatch.equal_seqs());
-    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-    GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
-    const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
-
-    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-    auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
-    cb(bcx, "model.layers.{}.conv.in_proj", il);
-
-    constexpr auto n_chunks = 3;
-    GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
-    const auto chunk_size = bcx->ne[0] / n_chunks;
-    auto *     b          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
-                                         0 * chunk_size * ggml_element_size(bcx));
-    auto *     c          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
-                                         1 * chunk_size * ggml_element_size(bcx));
-    auto *     x          = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
-                                         2 * chunk_size * ggml_element_size(bcx));
-
-    auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));
-
-    // read conv state
-    auto * conv_state = mctx_cur->get_r_l(il);
-    auto * conv_rs    = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
-    auto * conv       = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
-
-    bx = ggml_concat(ctx0, conv, bx, 0);
-    GGML_ASSERT(bx->ne[0] > conv->ne[0]);
-
-    // last d_conv columns is a new conv state
-    auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2],
-                                   (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
-    GGML_ASSERT(ggml_are_same_shape(conv, new_conv));
-
-    // write new conv conv state
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv,
-                                           ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv),
-                                                        kv_head * d_conv * n_embd * ggml_element_size(new_conv))));
-
-    auto * conv_kernel = model.layers[il].shortconv.conv;
-    auto * conv_out    = ggml_ssm_conv(ctx0, bx, conv_kernel);
-    cb(conv_out, "model.layers.{}.conv.conv", il);
-
-    auto * y = ggml_mul(ctx0, c, conv_out);
-    y        = build_lora_mm(model.layers[il].shortconv.out_proj, y);
-    cb(y, "model.layers.{}.conv.out_proj", il);
-    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-    y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);
-
-    return y;
-}
+// Explicit template instantiations
+template struct llm_build_lfm2<false>;
+template struct llm_build_lfm2<true>;
diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp
index 5f64686f..18de88fd 100644
--- a/examples/talk-llama/models/llada-moe.cpp
+++ b/examples/talk-llama/models/llada-moe.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -90,7 +90,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, false,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
         cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp
index 85703366..0dac9d61 100644
--- a/examples/talk-llama/models/llada.cpp
+++ b/examples/talk-llama/models/llada.cpp
@@ -2,10 +2,10 @@
 
 llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     // LLaDA is similar to LLaMA but uses non-causal attention for diffusion
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/llama-iswa.cpp b/examples/talk-llama/models/llama-iswa.cpp
index 61dd2c17..67cb9a10 100644
--- a/examples/talk-llama/models/llama-iswa.cpp
+++ b/examples/talk-llama/models/llama-iswa.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -134,7 +134,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, false,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
                     il);
 
diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp
index 42b5fcdf..e08ae0c0 100644
--- a/examples/talk-llama/models/llama.cpp
+++ b/examples/talk-llama/models/llama.cpp
@@ -2,10 +2,10 @@
 
 template 
 llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -43,19 +43,19 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra
             ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
             // compute Q and K and RoPE them
-            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s);
             cb(Qcur, "Qcur", il);
             if (model.layers[il].bq) {
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
             }
-            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
             cb(Kcur, "Kcur", il);
             if (model.layers[il].bk) {
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
             }
-            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
             cb(Vcur, "Vcur", il);
             if (model.layers[il].bv) {
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -91,6 +91,9 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra
             cur = build_attn(inp_attn,
                     model.layers[il].wo, model.layers[il].bo,
                     Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+            if (model.layers[il].wo_s) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
+            }
             cb(cur, "attn_out", il);
         }
         if (il == n_layer - 1 && inp_out_ids) {
@@ -109,9 +112,9 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra
             cb(cur, "ffn_norm", il);
 
             cur = build_ffn(cur,
-                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   model.layers[il].ffn_up_s,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, il);
             cb(cur, "ffn_out", il);
@@ -130,9 +133,13 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    il);
+                    il,
+                    nullptr, nullptr,
+                    model.layers[il].ffn_up_exps_s,
+                    model.layers[il].ffn_gate_exps_s,
+                    model.layers[il].ffn_down_exps_s);
             cb(cur, "ffn_moe_out", il);
         }
         cur = ggml_add(ctx0, cur, ffn_inp);
diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp
index da573081..a72b7790 100644
--- a/examples/talk-llama/models/maincoder.cpp
+++ b/examples/talk-llama/models/maincoder.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/graph-context-mamba.cpp b/examples/talk-llama/models/mamba-base.cpp
similarity index 96%
rename from examples/talk-llama/models/graph-context-mamba.cpp
rename to examples/talk-llama/models/mamba-base.cpp
index b9a363b3..9de587db 100644
--- a/examples/talk-llama/models/graph-context-mamba.cpp
+++ b/examples/talk-llama/models/mamba-base.cpp
@@ -1,8 +1,10 @@
 #include "models.h"
 
-llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
+#include "llama-memory-recurrent.h"
 
-ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
+llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {}
+
+ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp,
                                                          ggml_tensor *        cur,
                                                          const llama_model &  model,
                                                          const llama_ubatch & ubatch,
@@ -28,6 +30,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in
     GGML_ASSERT(n_seqs != 0);
     GGML_ASSERT(ubatch.equal_seqs());
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(d_inner % n_head == 0);
 
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
@@ -143,7 +146,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in
     return cur;
 }
 
-ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp,
+ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
                                                           ggml_tensor *        cur,
                                                           const llama_model &  model,
                                                           const llama_ubatch & ubatch,
@@ -165,6 +168,9 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i
     GGML_ASSERT(n_seqs != 0);
     GGML_ASSERT(ubatch.equal_seqs());
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(d_inner % n_head  == 0);
+    GGML_ASSERT(d_inner % d_state == 0);
+    GGML_ASSERT(d_inner % n_group == 0);
 
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp
index 46819613..55fd2e055 100644
--- a/examples/talk-llama/models/mamba.cpp
+++ b/examples/talk-llama/models/mamba.cpp
@@ -1,7 +1,6 @@
 #include "models.h"
 
-
-llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
+llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
diff --git a/examples/talk-llama/models/mimo2-iswa.cpp b/examples/talk-llama/models/mimo2-iswa.cpp
index edc87cc9..06956915 100644
--- a/examples/talk-llama/models/mimo2-iswa.cpp
+++ b/examples/talk-llama/models/mimo2-iswa.cpp
@@ -1,4 +1,3 @@
-
 #include "models.h"
 
 llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
@@ -88,10 +87,17 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_
             cb(cur, "ffn_out", il);
         } else {
             // MoE branch
-            cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
-                                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
-                                model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false,
-                                0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il);
+            cur = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, true,
+                    hparams.expert_weights_scale,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
+                    il);
             cb(cur, "ffn_moe_out", il);
         }
 
diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp
index f374a9fd..89dd7105 100644
--- a/examples/talk-llama/models/minicpm3.cpp
+++ b/examples/talk-llama/models/minicpm3.cpp
@@ -5,10 +5,11 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
     const int64_t n_embd_base = 256;
     const float scale_embd  = 12.0f;
     const float scale_depth = 1.4f;
-    const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
+    const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k()));
+
+    const uint32_t n_embd_head_qk_rope = hparams.n_rot();
+    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
 
-    const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     ggml_tensor * cur;
@@ -50,21 +51,21 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
                     LLM_NORM_RMS, il);
             cb(q, "q", il);
 
-            // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
+            // {q_lora_rank, n_head * hparams.n_embd_head_k()} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k(), n_tokens}
             q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
             cb(q, "q", il);
 
             // split into {n_head * n_embd_head_qk_nope, n_tokens}
             ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
-                    ggml_row_size(q->type, hparams.n_embd_head_k),
-                    ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                    ggml_row_size(q->type, hparams.n_embd_head_k()),
+                    ggml_row_size(q->type, hparams.n_embd_head_k() * n_head),
                     0);
             cb(q_nope, "q_nope", il);
 
             // and {n_head * n_embd_head_qk_rope, n_tokens}
             ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
-                    ggml_row_size(q->type, hparams.n_embd_head_k),
-                    ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                    ggml_row_size(q->type, hparams.n_embd_head_k()),
+                    ggml_row_size(q->type, hparams.n_embd_head_k() * n_head),
                     ggml_row_size(q->type, n_embd_head_qk_nope));
             cb(q_pe, "q_pe", il);
 
@@ -96,15 +97,15 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
 
             // split into {n_head * n_embd_head_qk_nope, n_tokens}
             ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                    ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                    ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v()),
+                    ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v())),
                     0);
             cb(k_nope, "k_nope", il);
 
             // and {n_head * n_embd_head_v, n_tokens}
-            ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+            ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v(), n_head, n_tokens,
+                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())),
+                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())*n_head),
                     ggml_row_size(kv->type, (n_embd_head_qk_nope)));
             cb(v_states, "v_states", il);
 
diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp
index f7001bad..83d0916c 100644
--- a/examples/talk-llama/models/minimax-m2.cpp
+++ b/examples/talk-llama/models/minimax-m2.cpp
@@ -1,11 +1,10 @@
-
 #include "models.h"
 
 llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    // GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    // GGML_ASSERT(n_embd_head == n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -91,7 +90,7 @@ llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_
                 model.layers[il].ffn_exp_probs_b,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, true,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 (llama_expert_gating_func_type) hparams.expert_gating_func,
                 il);
         cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp
index 0b672235..42a5117f 100644
--- a/examples/talk-llama/models/mistral3.cpp
+++ b/examples/talk-llama/models/mistral3.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -127,7 +127,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
             cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h
index 6c40f480..a86b2b1e 100644
--- a/examples/talk-llama/models/models.h
+++ b/examples/talk-llama/models/models.h
@@ -1,23 +1,71 @@
 #pragma once
 
-#include "../llama-model.h"
-#include "../llama-graph.h"
+#include "llama-model.h"
+#include "llama-graph.h"
 
-// TODO: remove in follow-up PR - move to .cpp files
-#include "../llama-memory-recurrent.h"
+// note: almost all graphs require at least sqrtf, so include cmath globally
 #include <cmath>
 
-struct llm_graph_context_mamba : public llm_graph_context {
-    llm_graph_context_mamba(const llm_graph_params & params);
+//
+// base classes
+//
 
-    virtual ~llm_graph_context_mamba() = default;
+struct llm_build_mamba_base : public llm_graph_context {
+    llm_build_mamba_base(const llm_graph_params & params);
+
+    virtual ~llm_build_mamba_base() = default;
 
     ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
     ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const;
 
 };
 
-// Base class for RWKV-related models
+struct llm_build_delta_net_base : public llm_graph_context {
+    llm_build_delta_net_base(const llm_graph_params & params);
+
+    virtual ~llm_build_delta_net_base() = default;
+
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                        int   il);
+
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                int           il);
+
+    // use the ggml_gated_delta_net fused operator
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                        int   il);
+
+    // choose one of two implementations above based on the number of tokens
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                        int   il);
+};
+
 struct llm_build_rwkv6_base : public llm_graph_context {
     const llama_model & model;
 
@@ -58,6 +106,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                                        int                  il) const;
 };
 
+//
+// models
+//
+
 struct llm_build_afmoe : public llm_graph_context {
     llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
 };
@@ -158,6 +210,10 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
     llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_paddleocr : public llm_graph_context {
+    llm_build_paddleocr(const llama_model & model, const llm_graph_params & params);
+};
+
 template 
 struct llm_build_exaone4 : public llm_graph_context {
     llm_build_exaone4(const llama_model & model, const llm_graph_params & params);
@@ -167,11 +223,15 @@ struct llm_build_exaone : public llm_graph_context {
     llm_build_exaone(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_exaone_moe : public llm_graph_context {
+    llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_falcon : public llm_graph_context {
     llm_build_falcon(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_falcon_h1 : public llm_graph_context_mamba {
+struct llm_build_falcon_h1 : public llm_build_mamba_base {
     llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params);
 };
 
@@ -249,7 +309,7 @@ private:
         const int                 il);
 };
 
-struct llm_build_granite_hybrid : public llm_graph_context_mamba {
+struct llm_build_granite_hybrid : public llm_build_mamba_base {
     llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params);
     ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il);
     ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn,
@@ -280,19 +340,44 @@ struct llm_build_jais : public llm_graph_context {
     llm_build_jais(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_jamba : public llm_graph_context_mamba {
+struct llm_build_jais2 : public llm_graph_context {
+    llm_build_jais2(const llama_model & model, const llm_graph_params & params);
+};
+
+struct llm_build_jamba : public llm_build_mamba_base {
     llm_build_jamba(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_lfm2 : public llm_graph_context {
+struct llm_build_kimi_linear : public llm_build_delta_net_base {
+    llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
+
+    std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                        int   il);
+
+    std::pair<ggml_tensor *, ggml_tensor *> build_kda_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                ggml_tensor * diag_mask,
+                        int   il);
+
     const llama_model & model;
+};
 
+template 
+struct llm_build_lfm2 : public llm_graph_context {
     llm_build_lfm2(const llama_model & model, const llm_graph_params & params);
-    ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const;
-    ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const;
-    ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const;
-    ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il);
-
 };
 
 struct llm_build_llada : public llm_graph_context {
@@ -316,7 +401,7 @@ struct llm_build_maincoder : public llm_graph_context {
     llm_build_maincoder(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_mamba : public llm_graph_context_mamba {
+struct llm_build_mamba : public llm_build_mamba_base {
     llm_build_mamba(const llama_model & model, const llm_graph_params & params);
 };
 
@@ -348,17 +433,21 @@ struct llm_build_nemotron : public llm_graph_context {
     llm_build_nemotron(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_nemotron_h : public llm_graph_context_mamba {
+struct llm_build_nemotron_h : public llm_build_mamba_base {
     llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params);
-    ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il);
+    ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il);
     ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn,
-        const llama_model & model, const int64_t n_embd_head, const int il);
+        const llama_model & model, int64_t n_embd_head, int il);
 };
 
 struct llm_build_neo_bert : public llm_graph_context {
     llm_build_neo_bert(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_eurobert : public llm_graph_context {
+    llm_build_eurobert(const llama_model & model, const llm_graph_params & params);
+};
+
 template 
 struct llm_build_olmo2 : public llm_graph_context {
     llm_build_olmo2(const llama_model & model, const llm_graph_params & params);
@@ -397,7 +486,7 @@ struct llm_build_phi3 : public llm_graph_context {
     llm_build_phi3(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_plamo2 : public llm_graph_context_mamba {
+struct llm_build_plamo2 : public llm_build_mamba_base {
     llm_build_plamo2(const llama_model & model, const llm_graph_params & params);
     private:
         ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
@@ -445,7 +534,8 @@ struct llm_build_qwen3vl : public llm_graph_context {
 struct llm_build_qwen3vlmoe : public llm_graph_context {
     llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
 };
-struct llm_build_qwen3next : public llm_graph_context_mamba {
+
+struct llm_build_qwen3next : public llm_build_delta_net_base {
     llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
 private:
     ggml_tensor * build_layer_attn(
@@ -457,37 +547,78 @@ private:
     ggml_tensor * build_layer_attn_linear(
          llm_graph_input_rs * inp,
                 ggml_tensor * cur,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                ggml_tensor * diag_mask,
                         int   il);
 
     ggml_tensor * build_layer_ffn(
                 ggml_tensor * cur,
                         int   il);
 
-    // returns pair of output and new state
-    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
-                ggml_tensor * q,
-                ggml_tensor * k,
-                ggml_tensor * v,
-                ggml_tensor * g,
-                ggml_tensor * beta,
-                ggml_tensor * state,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                ggml_tensor * diag_mask,
+    ggml_tensor * build_norm_gated(
+                ggml_tensor * input,
+                ggml_tensor * weights,
+                ggml_tensor * gate,
+                        int   layer);
+
+    // returns pair of qkv, z
+    std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
+                ggml_tensor * input,
                         int   il);
 
-    // returns pair of output and new state
-    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
-                ggml_tensor * q,
-                ggml_tensor * k,
-                ggml_tensor * v,
-                ggml_tensor * g,
-                ggml_tensor * beta,
-                ggml_tensor * state,
-                int           il);
+    const llama_model & model;
+};
+
+struct llm_build_qwen35 : public llm_build_delta_net_base {
+    llm_build_qwen35(const llama_model & model, const llm_graph_params & params);
+private:
+    ggml_tensor * build_layer_attn(
+    llm_graph_input_attn_kv * inp_attn,
+                ggml_tensor * cur,
+                ggml_tensor * inp_pos,
+                        int * sections,
+                        int   il);
+
+    ggml_tensor * build_layer_attn_linear(
+         llm_graph_input_rs * inp,
+                ggml_tensor * cur,
+                        int   il);
+
+    ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il);
+
+    ggml_tensor * build_norm_gated(
+                ggml_tensor * input,
+                ggml_tensor * weights,
+                ggml_tensor * gate,
+                        int   layer);
+
+    // returns pair of qkv, z
+    std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
+                ggml_tensor * input,
+                        int   il);
+
+    const llama_model & model;
+};
+
+// TODO: derive llm_build_delta_net_base instead
+struct llm_build_qwen35moe : public llm_build_delta_net_base {
+    llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params);
+private:
+    ggml_tensor * build_layer_attn(
+    llm_graph_input_attn_kv * inp_attn,
+                ggml_tensor * cur,
+                ggml_tensor * inp_pos,
+                        int * sections,
+                        int   il);
+
+    ggml_tensor * build_layer_attn_linear(
+         llm_graph_input_rs * inp,
+                ggml_tensor * cur,
+                        int   il);
+
+    ggml_tensor * build_layer_ffn(
+                ggml_tensor * cur,
+                        int   il);
 
     ggml_tensor * build_norm_gated(
                 ggml_tensor * input,
@@ -552,6 +683,10 @@ struct llm_build_starcoder : public llm_graph_context {
     llm_build_starcoder(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_step35_iswa : public llm_graph_context {
+    llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_t5_dec : public llm_graph_context {
     llm_build_t5_dec(const llama_model & model, const llm_graph_params & params);
 };
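
Note on the new llm_build_delta_net_base helper above: it declares an unfused gated-delta-net path (its declaration is cut off at the top of this hunk), a fused path that wraps the ggml_gated_delta_net operator, and a build_delta_net dispatcher that picks between them based on the number of tokens. The selection policy itself is not part of this hunk; the sketch below is only an assumption of the idea, and the unfused method name and the single-token criterion are hypothetical:

    // sketch only -- not the actual llama.cpp implementation; the real policy may
    // also consult cparams.fused_gdn_ar / cparams.fused_gdn_ch (see qwen35.cpp)
    std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net(
            ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
            ggml_tensor * g, ggml_tensor * b, ggml_tensor * s, int il) {
        if (n_tokens == 1) {
            // single-token decode: recurrent update via the fused operator (assumption)
            return build_delta_net_fused(q, k, v, g, b, s, il);
        }
        // multi-token prompt processing: the unfused path declared above (hypothetical name)
        return build_delta_net_unfused(q, k, v, g, b, s, il);
    }
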
diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp
index bb12ed81..26020584 100644
--- a/examples/talk-llama/models/modern-bert.cpp
+++ b/examples/talk-llama/models/modern-bert.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -104,13 +104,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll
             LLM_NORM, -1);
     cb(cur, "final_norm_out", -1);
 
-    if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
-        // extracting cls token
-        cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0);
-        cb(cur, "cls_pooled_embd", -1);
-    }
-
-    cb(cur, "res_embd", -1);
     res->t_embd = cur;
     ggml_build_forward_expand(gf, cur);
 }
diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp
index 2328e027..ce44a805 100644
--- a/examples/talk-llama/models/mpt.cpp
+++ b/examples/talk-llama/models/mpt.cpp
@@ -3,10 +3,10 @@
 
 
 llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * pos;
diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp
index eb135e63..7af99174 100644
--- a/examples/talk-llama/models/nemotron-h.cpp
+++ b/examples/talk-llama/models/nemotron-h.cpp
@@ -1,11 +1,9 @@
 #include "models.h"
 
-
-
 llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    llm_build_mamba_base(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -65,9 +63,9 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_
 ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *             cur,
                                                           llm_graph_input_attn_kv * inp_attn,
                                                           const llama_model &       model,
-                                                          const int64_t             n_embd_head,
-                                                          const int                 il) {
-    // compute Q and K and (optionally) RoPE them
+                                                                int64_t             n_embd_head,
+                                                                int                 il) {
+    // compute Q and K
     ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
     cb(Qcur, "Qcur", il);
     if (model.layers[il].bq) {
@@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
     return cur;
 }
 
-ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) {
+ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) {
     if (model.layers[il].ffn_gate_inp == nullptr) {
         cur = build_ffn(cur,
                 model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
@@ -116,9 +114,18 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla
                 LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
         cb(cur, "ffn_out", il);
     } else {
-        ggml_tensor * ffn_inp = cur;
+        ggml_tensor * inp_emb    = cur;
+        ggml_tensor * inp_latent = cur;
+
+        if (model.layers[il].ffn_latent_down) {
+            inp_latent = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_down, cur);
+        }
+
+        ggml_tensor * router_logits = build_lora_mm(model.layers[il].ffn_gate_inp, cur);
+        cb(router_logits, "ffn_moe_logits", il);
+
         ggml_tensor * moe_out =
-            build_moe_ffn(ffn_inp,
+            build_moe_ffn(inp_latent,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     nullptr, // no gate
@@ -126,12 +133,17 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla
                     model.layers[il].ffn_exp_probs_b,
                     n_expert, n_expert_used,
                     LLM_FFN_RELU_SQR, hparams.expert_weights_norm,
-                    true, hparams.expert_weights_scale,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID,
-                    il);
+                    il,
+                    router_logits);
         cb(moe_out, "ffn_moe_out", il);
 
-        ggml_tensor * ffn_shexp = build_ffn(ffn_inp,
+        if (model.layers[il].ffn_latent_up) {
+            moe_out = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_up, moe_out);
+        }
+
+        ggml_tensor * ffn_shexp = build_ffn(inp_emb,
                     model.layers[il].ffn_up_shexp,  NULL, NULL,
                     NULL /* no gate */           ,  NULL, NULL,
                     model.layers[il].ffn_down_shexp, NULL, NULL,
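
The MoE branch above now runs the experts in a latent space: the router logits are computed from the full-width activations, the expert FFN consumes a down-projected input, and its output is projected back up before being combined with the shared-expert branch (the combination itself is outside this hunk; a sum is assumed below). Schematically, using the tensor names from the diff, where the latent projections are optional and fall back to the identity when absent:

    router_logits = ffn_gate_inp * x
    moe_out       = MoE(ffn_latent_down * x, router_logits)
    y             = ffn_latent_up * moe_out + shared_expert(x)
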
diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp
index fcead041..34aa6fa5 100644
--- a/examples/talk-llama/models/nemotron.cpp
+++ b/examples/talk-llama/models/nemotron.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    //GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    //GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp
index 7c32bfca..2fdf4a36 100644
--- a/examples/talk-llama/models/neo-bert.cpp
+++ b/examples/talk-llama/models/neo-bert.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp
index bbd623f1..26f4b6ee 100644
--- a/examples/talk-llama/models/olmo.cpp
+++ b/examples/talk-llama/models/olmo.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp
index 713552da..5076359e 100644
--- a/examples/talk-llama/models/olmo2.cpp
+++ b/examples/talk-llama/models/olmo2.cpp
@@ -2,10 +2,10 @@
 
 template 
 llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp
index b8b6988f..83a56a0b 100644
--- a/examples/talk-llama/models/olmoe.cpp
+++ b/examples/talk-llama/models/olmoe.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -92,7 +92,7 @@ llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_para
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SILU, false,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                 il);
         cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/openai-moe-iswa.cpp b/examples/talk-llama/models/openai-moe-iswa.cpp
index dbe3ca18..403f130b 100644
--- a/examples/talk-llama/models/openai-moe-iswa.cpp
+++ b/examples/talk-llama/models/openai-moe-iswa.cpp
@@ -95,7 +95,7 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model,
                 nullptr,
                 n_expert, n_expert_used,
                 LLM_FFN_SWIGLU_OAI_MOE, false,
-                false, 0.0,
+                hparams.expert_weights_scale,
                 LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
                 il);
         cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp
index ee46a337..5df6fe3e 100644
--- a/examples/talk-llama/models/openelm.cpp
+++ b/examples/talk-llama/models/openelm.cpp
@@ -1,9 +1,9 @@
 #include "models.h"
 
 llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -43,7 +43,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_
             ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head);
             cb(Kcur, "Kcur", il);
 
-            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv));
             cb(Vcur, "Vcur", il);
 
             Qcur = build_norm(Qcur,
diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp
index bb02273b..48c01efe 100644
--- a/examples/talk-llama/models/orion.cpp
+++ b/examples/talk-llama/models/orion.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp
new file mode 100644
index 00000000..340455c2
--- /dev/null
+++ b/examples/talk-llama/models/paddleocr.cpp
@@ -0,0 +1,122 @@
+#include "models.h"
+
+llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context(params) {
+
+    // NOTE: same as qwen2vl.cpp, but the bias tensors are optional
+
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        {
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+        }
+        // self-attention
+        {
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+            if (model.layers[il].bq) {
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+            }
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+            if (model.layers[il].bk) {
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+            }
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+            if (model.layers[il].bv) {
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+            }
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_multi(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_multi(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+        if (il == n_layer - 1) {
+            // skip computing output for unused tokens
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        {
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up, NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/examples/talk-llama/models/pangu-embedded.cpp b/examples/talk-llama/models/pangu-embedded.cpp
index 664572a5..1cf0938e 100644
--- a/examples/talk-llama/models/pangu-embedded.cpp
+++ b/examples/talk-llama/models/pangu-embedded.cpp
@@ -2,10 +2,10 @@
 
 
 llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp
index 22dbf610..32d40d71 100644
--- a/examples/talk-llama/models/phi2.cpp
+++ b/examples/talk-llama/models/phi2.cpp
@@ -2,10 +2,10 @@
 
 
 llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * attn_norm_output;
diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp
index c8e5da33..3d11a945 100644
--- a/examples/talk-llama/models/phi3.cpp
+++ b/examples/talk-llama/models/phi3.cpp
@@ -2,10 +2,10 @@
 
 template
 llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -114,7 +114,7 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
             cb(cur, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp
index 04ff709f..b7a71211 100644
--- a/examples/talk-llama/models/plamo.cpp
+++ b/examples/talk-llama/models/plamo.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp
index 31115a08..f02acbc1 100644
--- a/examples/talk-llama/models/plamo2.cpp
+++ b/examples/talk-llama/models/plamo2.cpp
@@ -1,7 +1,9 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -25,7 +27,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
 
         // check if this layer is Mamba or Attention
-        bool is_mamba_layer = hparams.is_recurrent(il);
+        const bool is_mamba_layer = hparams.is_recurrent(il);
 
         if (is_mamba_layer) {
             // PLaMo-2 Mamba layer
@@ -104,9 +106,9 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv
         cb(qkv, "wqkv", il);
 
         // split QKV tensor into Q, K, V
-        const int64_t n_embd_head_q = hparams.n_embd_head_k;
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-        const int64_t n_embd_head_v = hparams.n_embd_head_v;
+        const int64_t n_embd_head_q = hparams.n_embd_head_k();
+        const int64_t n_embd_head_k = hparams.n_embd_head_k();
+        const int64_t n_embd_head_v = hparams.n_embd_head_v();
         int32_t       n_head        = hparams.n_head(il);
         int32_t       n_head_kv     = hparams.n_head_kv(il);
 
@@ -169,6 +171,8 @@ ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * in
     GGML_ASSERT(n_seqs != 0);
     GGML_ASSERT(ubatch.equal_seqs());
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(d_inner % n_head == 0);
+    GGML_ASSERT(n_group == 0);
 
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp
index 55c80646..32af6e04 100644
--- a/examples/talk-llama/models/plamo3.cpp
+++ b/examples/talk-llama/models/plamo3.cpp
@@ -3,8 +3,8 @@
 template 
 llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    const int64_t head_dim_q = hparams.n_embd_head_k;
-    const int64_t head_dim_v = hparams.n_embd_head_v;
+    const int64_t head_dim_q = hparams.n_embd_head_k();
+    const int64_t head_dim_v = hparams.n_embd_head_v();
 
     ggml_tensor * cur;
     ggml_tensor * inpL = build_inp_embd(model.tok_embd);
diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp
index 481cbba6..bcb651ce 100644
--- a/examples/talk-llama/models/plm.cpp
+++ b/examples/talk-llama/models/plm.cpp
@@ -1,10 +1,11 @@
 #include "models.h"
 
 llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k));
+    const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k()));
+
+    const uint32_t n_embd_head_qk_rope = hparams.n_rot();
+    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot();
 
-    const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     ggml_tensor * cur;
@@ -37,15 +38,15 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &
 
             // split into {n_head * n_embd_head_qk_nope, n_tokens}
             ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
-                    ggml_row_size(q->type, hparams.n_embd_head_k),
-                    ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                    ggml_row_size(q->type, hparams.n_embd_head_k()),
+                    ggml_row_size(q->type, hparams.n_embd_head_k() * n_head),
                     0);
             cb(q_nope, "q_nope", il);
 
             // and {n_head * n_embd_head_qk_rope, n_tokens}
             ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
-                    ggml_row_size(q->type, hparams.n_embd_head_k),
-                    ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                    ggml_row_size(q->type, hparams.n_embd_head_k()),
+                    ggml_row_size(q->type, hparams.n_embd_head_k() * n_head),
                     ggml_row_size(q->type, n_embd_head_qk_nope));
             cb(q_pe, "q_pe", il);
 
@@ -77,23 +78,23 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &
 
             // split into {n_head * n_embd_head_qk_nope, n_tokens}
             ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
-                    ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
-                    ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v()),
+                    ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v())),
                     0);
             cb(k_nope, "k_nope", il);
 
             // and {n_head * n_embd_head_v, n_tokens}
-            ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
-                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
-                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+            ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v(), n_head, n_tokens,
+                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())),
+                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())*n_head),
                     ggml_row_size(kv->type, (n_embd_head_qk_nope)));
             cb(v_states, "v_states", il);
 
             v_states = ggml_cont(ctx0, v_states);
             cb(v_states, "v_states", il);
 
-            v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+            v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v() * n_head, n_tokens,
+                    ggml_row_size(kv->type, hparams.n_embd_head_v() * n_head),
                     0);
             cb(v_states, "v_states", il);
 
diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp
index 31fd9b73..7390f132 100644
--- a/examples/talk-llama/models/qwen.cpp
+++ b/examples/talk-llama/models/qwen.cpp
@@ -2,9 +2,9 @@
 
 
 llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp
index 3da4dea3..58c10622 100644
--- a/examples/talk-llama/models/qwen2.cpp
+++ b/examples/talk-llama/models/qwen2.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp
index 49142b71..60761789 100644
--- a/examples/talk-llama/models/qwen2moe.cpp
+++ b/examples/talk-llama/models/qwen2moe.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -94,7 +94,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, false,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
         cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp
index 9be38675..9004bab9 100644
--- a/examples/talk-llama/models/qwen2vl.cpp
+++ b/examples/talk-llama/models/qwen2vl.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp
index a5cfffa5..52081668 100644
--- a/examples/talk-llama/models/qwen3.cpp
+++ b/examples/talk-llama/models/qwen3.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -30,13 +30,13 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para
         // self-attention
         {
             // compute Q and K and RoPE them
-            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s);
             cb(Qcur, "Qcur", il);
 
-            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
             cb(Kcur, "Kcur", il);
 
-            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
             cb(Vcur, "Vcur", il);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
@@ -68,6 +68,9 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para
             cur = build_attn(inp_attn,
                     model.layers[il].wo, model.layers[il].bo,
                     Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            if (model.layers[il].wo_s) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
+            }
         }
         if (il == n_layer - 1 && inp_out_ids) {
             cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
@@ -83,9 +86,9 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para
         cb(cur, "ffn_norm", il);
 
         cur = build_ffn(cur,
-                model.layers[il].ffn_up,   NULL, NULL,
-                model.layers[il].ffn_gate, NULL, NULL,
-                model.layers[il].ffn_down, NULL, NULL,
+                model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_s,
+                model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s,
+                model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s,
                 NULL,
                 LLM_FFN_SILU, LLM_FFN_PAR, il);
         cb(cur, "ffn_out", il);
diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp
new file mode 100644
index 00000000..3108bf33
--- /dev/null
+++ b/examples/talk-llama/models/qwen35.cpp
@@ -0,0 +1,381 @@
+#include "models.h"
+
+#include "llama-memory-recurrent.h"
+
+llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) :
+    llm_build_delta_net_base(params), model(model) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    cb(inpL, "model.input_embed", -1);
+
+    auto * inp = build_inp_mem_hybrid();
+
+    ggml_tensor * inp_pos     = build_inp_pos();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        ggml_build_forward_expand(gf, cur);
+
+        // Determine layer type and build appropriate attention mechanism
+        if (hparams.is_recurrent(il)) {
+            // Linear attention layer (gated delta net)
+            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
+        } else {
+            // Full attention layer
+            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual connection
+        cur = ggml_add(ctx0, cur, inpSA);
+        cb(cur, "attn_residual", il);
+
+        // Save the tensor before post-attention norm for residual connection
+        ggml_tensor * ffn_residual = cur;
+
+        // Post-attention norm
+        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
+        cb(attn_post_norm, "attn_post_norm", il);
+
+        // Dense FFN layer - without residual connection
+        cur = build_layer_ffn(attn_post_norm, il);
+        cb(cur, "ffn_out", il);
+
+        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
+        cur = ggml_add(ctx0, cur, ffn_residual);
+        cb(cur, "post_ffn", il);
+
+        // Input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final norm
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // LM head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
+                ggml_tensor * input,
+                        int   il) {
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s);
+    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+    cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+    ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s);
+    cb(z, "z", il);
+
+    return { qkv_mixed, z };
+}
+
+ggml_tensor * llm_build_qwen35::build_norm_gated(
+        ggml_tensor * input,
+        ggml_tensor * weights,
+        ggml_tensor * gate,
+        int           layer) {
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
+ggml_tensor * llm_build_qwen35::build_layer_attn(
+        llm_graph_input_attn_kv * inp,
+        ggml_tensor *             cur,
+        ggml_tensor *             inp_pos,
+        int *                     sections,
+        int                       il) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+
+    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
+
+    // like Qwen3Next, a single Q projection outputs both query and gate
+    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ]
+    cb(Qcur_full, "Qcur_full", il);
+
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0);
+    cb(Qcur, "Qcur_reshaped", il);
+
+    // Apply Q normalization
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
+
+    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
+    cb(Kcur, "Kcur", il);
+
+    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
+    cb(Vcur, "Vcur", il);
+
+    // Apply K normalization
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
+
+    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+        ggml_element_size(Qcur_full) * n_embd_head);
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
+
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+    // Apply MRoPE
+    Qcur = ggml_rope_multi(
+            ctx0, Qcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+    Kcur = ggml_rope_multi(
+            ctx0, Kcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+    cb(Vcur, "Vcur", il);
+
+    // Attention computation
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp,
+                nullptr, nullptr,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "attn_pregate", il);
+
+    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
+    cb(gate_sigmoid, "gate_sigmoid", il);
+
+    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    cb(cur, "attn_gated", il);
+
+    cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
+    cb(cur, "attn_output", il);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
+        llm_graph_input_rs * inp,
+        ggml_tensor *        cur,
+        int                  il) {
+    const auto * mctx_cur = inp->mctx;
+
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // Input projections
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z         = qkvz.second;
+
+    ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s);
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
+    cb(beta, "beta", il);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s);
+    alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
+    cb(alpha, "alpha", il);
+
+    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
+    cb(alpha_softplus, "a_softplus", il);
+
+    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
+    cb(gate, "gate", il);
+
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
+
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
+    // Calculate convolution kernel size
+    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
+    const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
+
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+    cb(conv_states, "conv_states_reshaped", il);
+
+    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
+    cb(qkv_mixed, "qkv_mixed_transposed", il);
+
+    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
+    cb(conv_input, "conv_input", il);
+
+    // Update convolution state cache
+    // Extract the last (conv_kernel_size - 1) states from conv_input
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target =
+        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);
+
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
+    cb(conv_output_proper, "conv_output_raw", il);
+
+    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
+    cb(conv_output_silu, "conv_output_silu", il);
+
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    // Calculate the total conv dimension
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
+
+    // Extract the convolved Q, K, V from conv_output
+    ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            0);
+
+    ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+
+    ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_v_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
+
+    cb(q_conv, "q_conv", il);
+    cb(k_conv, "k_conv", il);
+    cb(v_conv, "v_conv", il);
+
+    const float eps_norm = hparams.f_norm_rms_eps;
+
+    q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
+    k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
+
+    //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // if the number of K heads and V heads differs, repeat Q/K to match the V head count
+    // note: need explicit repeat only if we are not using the fused GDN
+    if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
+        GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
+        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
+    }
+
+    cb(q_conv, "q_conv_predelta", il);
+    cb(k_conv, "k_conv_predelta", il);
+    cb(v_conv, "v_conv_predelta", il);
+
+    auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+
+    ggml_tensor * output    = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
+
+    // Update the recurrent states
+    ggml_build_forward_expand(gf,
+            ggml_cpy(ctx0, new_state,
+                ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    // z: [n_heads * head_dim, n_tokens] -> [head_dim, n_heads, n_tokens, n_seqs]
+    ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // Apply gated normalization: self.norm(core_attn_out, z)
+    ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
+
+    // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * head_dim, n_tokens, n_seqs]
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(final_output, "final_output", il);
+
+    // Output projection
+    cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s);
+    cb(cur, "linear_attn_out", il);
+
+    // Reshape back to original dimensions
+    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // the dense Qwen3.5 variant does not use a MoE FFN
+    GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr);
+
+    cur = build_ffn(cur,
+        model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s,
+        model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s,
+        model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s,
+        NULL,
+        LLM_FFN_SILU, LLM_FFN_PAR, il);
+    cb(cur, "ffn_out", il);
+
+    return cur;
+}
diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp
new file mode 100644
index 00000000..165e2412
--- /dev/null
+++ b/examples/talk-llama/models/qwen35moe.cpp
@@ -0,0 +1,422 @@
+#include "models.h"
+
+#include "llama-memory-recurrent.h"
+
+llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) :
+    llm_build_delta_net_base(params), model(model) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+
+    int sections[4];
+    std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    cb(inpL, "model.input_embed", -1);
+
+    auto * inp = build_inp_mem_hybrid();
+
+    ggml_tensor * inp_pos     = build_inp_pos();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        ggml_build_forward_expand(gf, cur);
+
+        // Determine layer type and build appropriate attention mechanism
+        if (hparams.is_recurrent(il)) {
+            // Linear attention layer (gated delta net)
+            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
+        } else {
+            // Full attention layer
+            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual connection
+        cur = ggml_add(ctx0, cur, inpSA);
+        cb(cur, "attn_residual", il);
+
+        // Save the tensor before post-attention norm for residual connection
+        ggml_tensor * ffn_residual = cur;
+
+        // Post-attention norm
+        ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
+        cb(attn_post_norm, "attn_post_norm", il);
+
+        // MOE FFN layer
+        cur = build_layer_ffn(attn_post_norm, il);
+        cb(cur, "ffn_out", il);
+
+        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
+        cur = ggml_add(ctx0, cur, ffn_residual);
+        cb(cur, "post_moe", il);
+
+        // Input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final norm
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // LM head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_qkvz(
+                ggml_tensor * input,
+                        int   il) {
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s);
+    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+    cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+    ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s);
+    cb(z, "z", il);
+
+    return { qkv_mixed, z };
+}
+
+ggml_tensor * llm_build_qwen35moe::build_norm_gated(
+        ggml_tensor * input,
+        ggml_tensor * weights,
+        ggml_tensor * gate,
+        int           layer) {
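+    // gated RMSNorm: RMSNorm(input) scaled elementwise by SiLU(gate)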
+    ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer);
+    ggml_tensor * gated_silu = ggml_silu(ctx0, gate);
+
+    return ggml_mul(ctx0, normalized, gated_silu);
+}
+
+ggml_tensor * llm_build_qwen35moe::build_layer_attn(
+        llm_graph_input_attn_kv * inp,
+        ggml_tensor *             cur,
+        ggml_tensor *             inp_pos,
+        int *                     sections,
+        int                       il) {
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+
+    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
+
+    // Qwen3Next uses a single Q projection that outputs query + gate
+    ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ]
+    cb(Qcur_full, "Qcur_full", il);
+
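+    // Qcur_full packs [query | gate] per head; the two strided views (byte offsets 0 and n_embd_head) split the halves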
+    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0);
+    cb(Qcur, "Qcur_reshaped", il);
+
+    // Apply Q normalization
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
+
+    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
+    cb(Kcur, "Kcur", il);
+
+    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
+    cb(Vcur, "Vcur", il);
+
+    // Apply K normalization
+    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
+
+    ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
+        ggml_element_size(Qcur_full) * n_embd_head * 2,
+        ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head,
+        ggml_element_size(Qcur_full) * n_embd_head);
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+    cb(gate, "gate_reshaped", il);
+
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+    // Apply IMRoPE
+    Qcur = ggml_rope_multi(
+            ctx0, Qcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+    Kcur = ggml_rope_multi(
+            ctx0, Kcur, inp_pos, nullptr,
+            n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+            ext_factor, attn_factor, beta_fast, beta_slow
+            );
+
+    cb(Qcur, "Qcur", il);
+    cb(Kcur, "Kcur", il);
+    cb(Vcur, "Vcur", il);
+
+    // Attention computation
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    cur = build_attn(inp,
+                nullptr, nullptr,
+                Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+    cb(cur, "attn_pregate", il);
+
+    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
+    cb(gate_sigmoid, "gate_sigmoid", il);
+
+    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    cb(cur, "attn_gated", il);
+
+    cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
+    cb(cur, "attn_output", il);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
+        llm_graph_input_rs * inp,
+        ggml_tensor *        cur,
+        int                  il) {
+    const auto * mctx_cur = inp->mctx;
+
+    const int64_t d_inner      = hparams.ssm_d_inner;
+    const int64_t n_seqs       = ubatch.n_seqs;
+    const int64_t head_k_dim   = hparams.ssm_d_state;
+    const int64_t num_k_heads  = hparams.ssm_n_group;
+    const int64_t num_v_heads  = hparams.ssm_dt_rank;
+    const int64_t head_v_dim   = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // Input projections
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z         = qkvz.second;
+
+    ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s);
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
+    cb(beta, "beta", il);
+
+    beta = ggml_sigmoid(ctx0, beta);
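+    // beta in (0, 1) acts as the per-token, per-head update strength of the delta rule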
+
+    ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s);
+    alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
+    cb(alpha, "alpha", il);
+
+    ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
+    ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
+    cb(alpha_softplus, "a_softplus", il);
+
+    ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
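+    // i.e. g = -exp(A_log) * softplus(alpha + dt_bias), the per-head log-decay applied to the recurrent state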
+    cb(gate, "gate", il);
+
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
+
+    // Get convolution states from cache
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    // Build the convolution states tensor
+    ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    cb(conv_states, "conv_states", il);
+
+    // Calculate convolution kernel size
+    ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
+    const int64_t conv_kernel_size = conv_kernel->ne[0];
+    const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
+
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+    cb(conv_states, "conv_states_reshaped", il);
+
+    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
+    cb(qkv_mixed, "qkv_mixed_transposed", il);
+
+    ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
+    cb(conv_input, "conv_input", il);
+
+    // Update convolution state cache
+    // Extract the last (conv_kernel_size - 1) states from conv_input
+    ggml_tensor * last_conv_states =
+        ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1],
+                     conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input));
+    cb(last_conv_states, "last_conv_states", il);
+
+    ggml_tensor * state_update_target =
+        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+                     kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
+    cb(state_update_target, "state_update_target", il);
+
+    ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
+
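+    // recurrent delta-net state: one [head_v_dim x head_v_dim] matrix per V head and per sequence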
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
+    ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
+    cb(conv_output_proper, "conv_output_raw", il);
+
+    ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
+    cb(conv_output_silu, "conv_output_silu", il);
+
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    // Calculate the total conv dimension
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
+
+    // Extract the convolved Q, K, V from conv_output
+    ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            0);
+
+    ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+
+    ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_v_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
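+
+    // each conv output row is laid out as [q (num_k_heads * head_k_dim) | k (num_k_heads * head_k_dim) | v (num_v_heads * head_v_dim)],
+    // so the three views above differ only in their byte offset into the row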
+
+    cb(q_conv, "q_conv", il);
+    cb(k_conv, "k_conv", il);
+    cb(v_conv, "v_conv", il);
+
+    const float eps_norm = hparams.f_norm_rms_eps;
+
+    q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
+    k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
+
+    //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // if the number of K heads and V heads differ, repeat to force the tensors into matching shapes
+    // note: need explicit repeat only if we are not using the fused GDN
+    if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
+        GGML_ASSERT(num_v_heads % num_k_heads == 0);
+        q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
+        k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
+    }
+
+    cb(q_conv, "q_conv_predelta", il);
+    cb(k_conv, "k_conv_predelta", il);
+    cb(v_conv, "v_conv_predelta", il);
+
+    auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+
+    ggml_tensor * output    = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
+
+    // Update the recurrent states
+    ggml_build_forward_expand(gf,
+            ggml_cpy(ctx0, new_state,
+                ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+    // z: reshape to [head_v_dim, num_v_heads, n_seq_tokens, n_seqs] to match the attention output layout
+    ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+
+    // Apply gated normalization: self.norm(core_attn_out, z)
+    ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
+
+    // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [head_dim * n_heads, n_tokens, n_seqs]
+    ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+    cb(final_output, "final_output", il);
+
+    // Output projection
+    cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s);
+    cb(cur, "linear_attn_out", il);
+
+    // Reshape back to original dimensions
+    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_qwen35moe::build_layer_ffn(ggml_tensor * cur, const int il) {
+    // every FFN layer in Qwen3.5-MoE is an MoE layer
+    GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr);
+
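+    // routed experts: softmax gating over n_expert experts, with n_expert_used active per token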
+    ggml_tensor * moe_out =
+        build_moe_ffn(cur,
+            model.layers[il].ffn_gate_inp,
+            model.layers[il].ffn_up_exps,
+            model.layers[il].ffn_gate_exps,
+            model.layers[il].ffn_down_exps,
+            nullptr,
+            n_expert, n_expert_used,
+            LLM_FFN_SILU, true,
+            hparams.expert_weights_scale,
+            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
+            nullptr, model.layers[il].ffn_gate_up_exps,
+            model.layers[il].ffn_up_exps_s,
+            model.layers[il].ffn_gate_exps_s,
+            model.layers[il].ffn_down_exps_s);
+    cb(moe_out, "ffn_moe_out", il);
+
+    // Add shared experts if present - following Qwen3Next reference implementation
+    if (model.layers[il].ffn_up_shexp != nullptr) {
+        ggml_tensor * ffn_shexp =
+            build_ffn(cur,
+                model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s,
+                model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s,
+                model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
+        cb(ffn_shexp, "ffn_shexp", il);
+
+        // Apply shared expert gating as in the reference implementation
+        // The shared expert has its own gate that is sigmoided
+        // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
+        ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
+        cb(shared_gate, "shared_expert_gate", il);
+
+        // Apply sigmoid to the gate
+        shared_gate = ggml_sigmoid(ctx0, shared_gate);
+        cb(shared_gate, "shared_expert_gate_sigmoid", il);
+
+
+        // Apply the gate to the shared expert output
+        ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
+        cb(ffn_shexp, "ffn_shexp_gated", il);
+
+        cur = ggml_add(ctx0, moe_out, ffn_shexp);
+        cb(cur, "ffn_out", il);
+    } else {
+        cur = moe_out;
+    }
+
+    return cur;
+}
diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp
index 888534fb..dba46618 100644
--- a/examples/talk-llama/models/qwen3moe.cpp
+++ b/examples/talk-llama/models/qwen3moe.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -30,13 +30,13 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap
         // self_attention
         {
             // compute Q and K and RoPE them
-            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s);
             cb(Qcur, "Qcur", il);
 
-            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
             cb(Kcur, "Kcur", il);
 
-            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
             cb(Vcur, "Vcur", il);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
@@ -68,6 +68,9 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap
             cur = build_attn(inp_attn,
                     model.layers[il].wo, model.layers[il].bo,
                     Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            if (model.layers[il].wo_s) {
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_s);
+            }
         }
         if (il == n_layer - 1 && inp_out_ids) {
             cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
@@ -91,9 +94,13 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                    il);
+                    il,
+                    nullptr, nullptr,
+                    model.layers[il].ffn_up_exps_s,
+                    model.layers[il].ffn_gate_exps_s,
+                    model.layers[il].ffn_down_exps_s);
         cb(moe_out, "ffn_moe_out", il);
         cur = moe_out;
 
diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp
index 57b6659b..cc479dd0 100644
--- a/examples/talk-llama/models/qwen3next.cpp
+++ b/examples/talk-llama/models/qwen3next.cpp
@@ -1,10 +1,9 @@
-#include "ggml.h"
 #include "models.h"
 
-#define CHUNK_SIZE 64
+#include "llama-memory-recurrent.h"
 
 llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params), model(model) {
+    llm_build_delta_net_base(params), model(model) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -16,27 +15,18 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_tensor * inp_pos     = build_inp_pos();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    ggml_tensor * causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
-                    GGML_TRI_TYPE_LOWER);
-
-    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
-    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
-
-    ggml_build_forward_expand(gf, causal_mask);
-    ggml_build_forward_expand(gf, identity);
-    ggml_build_forward_expand(gf, diag_mask);
-
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;
 
         cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
         cb(cur, "attn_norm", il);
 
+        ggml_build_forward_expand(gf, cur);
+
         // Determine layer type and build appropriate attention mechanism
         if (hparams.is_recurrent(il)) {
             // Linear attention layer (gated delta net)
-            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
+            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
         } else {
             // Full attention layer
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
@@ -94,348 +84,6 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t
         t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
 }
 
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
-        ggml_tensor * diag_mask,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q = ggml_scale(ctx0, q, scale);
-
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-
-    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(g, "g_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    // Do padding
-    const int64_t chunk_size = CHUNK_SIZE;
-
-    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
-
-    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
-    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
-    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
-    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
-
-    cb(q, "q_pad", il);
-    cb(k, "k_pad", il);
-    cb(v, "v_pad", il);
-    cb(beta, "beta_pad", il);
-    cb(g, "g_pad", il);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    cb(v_beta, "v_beta", il);
-    cb(k_beta, "k_beta", il);
-
-    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
-
-    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
-
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
-
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
-
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn                     = ggml_add(ctx0, attn, identity);
-    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
-
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
-
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
-    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
-    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
-    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
-
-
-    // vectorized calculation of key_gdiff
-    // improved from the chunked version:
-    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-    //   key_gdiff = key * g_diff.unsqueeze(-1)
-    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-    // get last element in g_cumsum along chunk_size dimension (ne0)
-    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
-    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
-                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
-                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
-    g_last = ggml_cont(ctx0, g_last);
-    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
-    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
-    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
-    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
-
-
-    // state to be updated per chunk
-    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
-    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
-
-    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
-    ggml_tensor * core_attn_out = nullptr;
-
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
-        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
-
-        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
-
-        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
-
-        // shape: (chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
-
-        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        // replaced by precomputed attn_kq
-        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
-        cb(attn_chunk, "attn_chunk", il);
-
-        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
-
-        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
-        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
-
-        // v_new = v_i - v_prime
-        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
-        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-        cb(v_new, "v_new_chunk", il);
-
-        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-        cb(attn_inter, "attn_inter_chunk", il);
-
-        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
-        cb(v_attn, "v_attn_chunk", il);
-
-        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
-        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-
-        core_attn_out = core_attn_out == nullptr
-            ? core_attn_out_chunk
-            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
-
-        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
-        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
-
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
-        new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
-            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
-    }
-
-    // truncate padded tokens
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
-            S_v, n_tokens, H_v, n_seqs,
-            ggml_row_size(core_attn_out->type, S_v),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-    cb(output_tokens, "output_tokens", il);
-
-    // permute back to (S_v, H_v, n_tokens, n_seqs)
-    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-
-    return {output_tokens, new_state};
-}
-
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q    = ggml_scale(ctx0, q, scale);
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(g, "g_in", il);
-
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
-
-    // Apply exponential to g_t
-    g_t = ggml_exp(ctx0, g_t);
-
-    // Apply the gated delta rule for the single timestep
-    // last_recurrent_state = last_recurrent_state * g_t
-    state = ggml_mul(ctx0, state, g_t);
-
-    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
-    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // we need to sum over dim=-2, so we transpose, sum, then transpose again
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
-
-    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
-    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
-    // delta = (v_t - kv_mem) * beta_t
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
-    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);
-
-    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
-    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
-    state                   = ggml_add(ctx0, state, k_t_delta);
-
-    // Compute the attention output
-    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
-    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // again, since it's over dim = -2, transpose, sum, transpose back
-    ggml_tensor * core_attn_out =
-        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
-
-    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
-    cb(core_attn_out, "output_tokens", il);
-    cb(state, "new_state", il);
-
-    return {core_attn_out, state};
-}
-
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
         ggml_tensor * input,
         ggml_tensor * weights,
@@ -452,8 +100,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
         ggml_tensor *             cur,
         ggml_tensor *             inp_pos,
         int                       il) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
 
@@ -466,39 +114,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     // Split Q projection into query and gate
     // The split should be along dimension 0 (the feature dimension)
     ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
-                                             Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
+                                            Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
+    cb(Qcur, "Qcur_view", il);
+
     ggml_tensor * gate =
         ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
                      Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
-    cb(Qcur, "Qcur", il);
     cb(gate, "gate", il);
 
-    // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
-    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-    cb(Qcur, "Qcur_reshaped", il);
-
-    // Apply Q normalization
-    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
-    cb(Qcur, "Qcur_normed", il);
-
     ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
     cb(Kcur, "Kcur", il);
 
     ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
     cb(Vcur, "Vcur", il);
 
-    // Apply K normalization
     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
+
     Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
     cb(Kcur, "Kcur_normed", il);
 
-    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
-    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
-    cb(gate, "gate_reshaped", il);
-
-    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
-
-    // Apply RoPE
     Qcur = ggml_rope_ext(
             ctx0, Qcur, inp_pos, nullptr,
             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -513,7 +151,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     cb(Kcur, "Kcur", il);
     cb(Vcur, "Vcur", il);
 
-    // Attention computation
     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
     cur = build_attn(inp,
@@ -521,10 +158,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
                 Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
     cb(cur, "attn_pregate", il);
 
-    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
-    cb(gate_sigmoid, "gate_sigmoid", il);
+    // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
 
-    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    gate = ggml_sigmoid(ctx0, gate);
+    cb(gate, "gate_sigmoid", il);
+
+    gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+
+    cur = ggml_mul(ctx0, cur, gate);
     cb(cur, "attn_gated", il);
 
     cur = build_lora_mm(model.layers[il].wo, cur);
@@ -554,7 +196,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
         cb(z, "z", il);
 
         return { qkv_mixed, z };
-
     } else {
         // legacy (slower) path
         ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
@@ -618,9 +259,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
 ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
         llm_graph_input_rs * inp,
         ggml_tensor *        cur,
-        ggml_tensor *        causal_mask,
-        ggml_tensor *        identity,
-        ggml_tensor *        diag_mask,
         int                  il) {
     const auto * mctx_cur = inp->mctx;
 
@@ -665,7 +303,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
                                    split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
     cb(a, "a", il);
 
-    ggml_tensor * beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
+    // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
+    b = ggml_cont(ctx0, b);
+
+    ggml_tensor * beta = ggml_sigmoid(ctx0, b);
 
     // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
@@ -673,15 +314,17 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
     ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
     cb(alpha_softplus, "a_softplus", il);
+
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
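+    // beta and gate get a leading singleton dim so they broadcast against the per-head tensors in the delta net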
+
     // Get convolution states from cache
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
-    // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
-
     // Build the convolution states tensor
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
@@ -690,11 +333,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
     const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
-    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
-    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
-    cb(qkv_mixed, "qkv_mixed_permuted", il);
+    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
+    cb(qkv_mixed, "qkv_mixed_transposed", il);
 
     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
     cb(conv_input, "conv_input", il);
@@ -712,9 +356,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(state_update_target, "state_update_target", il);
 
     ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
-    cb(conv_states_all, "conv_states_updated", il);
 
-    // Apply SSM convolution
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
     ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
     cb(conv_output_proper, "conv_output_raw", il);
 
@@ -728,28 +374,39 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
 
     // Extract the convolved Q, K, V from conv_output
-    ggml_tensor * q_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
+    ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            0);
+
+    ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+
+    ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_v_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
+
     cb(q_conv, "q_conv", il);
-    ggml_tensor * k_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
-                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(k_conv, "k_conv", il);
-    ggml_tensor * v_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
-                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(v_conv, "v_conv", il);
 
-    // Unsqueeze them
-    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+    const float eps_norm = hparams.f_norm_rms_eps;
 
-    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
-    cb(state, "state_predelta", il);
+    q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
+    k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
+
+    //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
+    // TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST]
     if (num_k_heads != num_v_heads) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
@@ -775,13 +432,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
-    if (n_seq_tokens == 1) {
-        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
-    } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
-    }
+    auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+
     ggml_tensor * output    = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
     cb(output, "attn_output", il);
@@ -789,19 +441,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
 
     // Update the recurrent states
     ggml_build_forward_expand(gf,
-                              ggml_cpy(ctx0, new_state,
-                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
-                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
-
-    // Reshape both attn_out_final and z to 2D tensors for normalization
-    // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+            ggml_cpy(ctx0, new_state,
+                ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
-    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+    ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
     ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
@@ -812,7 +460,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(cur, "linear_attn_out", il);
 
     // Reshape back to original dimensions
-    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+
     return cur;
 }
 
@@ -822,18 +471,23 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
         // MoE branch
         ggml_tensor * moe_out =
             build_moe_ffn(cur,
-                model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
-                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                model.layers[il].ffn_gate_inp,
+                model.layers[il].ffn_up_exps,
+                model.layers[il].ffn_gate_exps,
+                model.layers[il].ffn_down_exps,
                 nullptr,
-                n_expert, n_expert_used, LLM_FFN_SILU,
-                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+                n_expert, n_expert_used,
+                LLM_FFN_SILU, true,
+                hparams.expert_weights_scale,
+                LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
+                nullptr, model.layers[il].ffn_gate_up_exps);
         cb(moe_out, "ffn_moe_out", il);
 
         // Add shared experts if present - following Qwen3Next reference implementation
         if (model.layers[il].ffn_up_shexp != nullptr) {
             ggml_tensor * ffn_shexp =
                 build_ffn(cur,
-                    model.layers[il].ffn_up_shexp, NULL, NULL,
+                    model.layers[il].ffn_up_shexp,   NULL, NULL,
                     model.layers[il].ffn_gate_shexp, NULL, NULL,
                     model.layers[il].ffn_down_shexp, NULL, NULL,
                     NULL,
@@ -846,11 +500,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
             ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
             cb(shared_gate, "shared_expert_gate", il);
 
-            // Apply sigmoid to the gate
             shared_gate = ggml_sigmoid(ctx0, shared_gate);
             cb(shared_gate, "shared_expert_gate_sigmoid", il);
 
-            // Apply the gate to the shared expert output
             ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
             cb(ffn_shexp, "ffn_shexp_gated", il);
 
diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vl-moe.cpp
index f72f80a8..195daea6 100644
--- a/examples/talk-llama/models/qwen3vl-moe.cpp
+++ b/examples/talk-llama/models/qwen3vl-moe.cpp
@@ -2,11 +2,12 @@
 
 llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = hparams.n_embd;
-    const int64_t n_embd_head = hparams.n_embd_head_v;
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -16,17 +17,6 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
     int sections[4];
     std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
 
-    std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
-
-    if (ubatch.embd) {
-        // Image input: split main embd and deepstack embds
-        ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
-        for (size_t i = 0; i < n_deepstack_layers; i++) {
-            deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
-        }
-        inpL = inpL_main;
-    }
-
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
@@ -109,7 +99,7 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
         cb(moe_out, "ffn_moe_out", il);
@@ -120,8 +110,9 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_
         cur = build_cvec(cur, il);
         cb(cur, "l_out", il);
 
-        if (ubatch.embd && (size_t)il < n_deepstack_layers) {
-            cur = ggml_add(ctx0, cur, deepstack_features[il]);
+        if (il < (int) n_deepstack_layers) {
+            ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
+            cur = ggml_add(ctx0, cur, ds);
             cb(cur, "deepstack_out", il);
         }
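// ---------------------------------------------------------------------------------------------
// Editor's note: sketch, not part of the patch. It spells out how the deepstack slices above are
// carved out of the stacked input embedding. It assumes res->t_inp_embd is laid out as
// [(1 + n_deepstack_layers) * n_embd, n_tokens] -- the main embedding followed by one n_embd-wide
// block per deepstack layer -- which is what the (il + 1) * n_embd * sizeof(float) byte offset in
// the hunk implies. The helper name is local to this sketch.
#include "ggml.h"

static ggml_tensor * deepstack_slice(ggml_context * ctx, ggml_tensor * t_inp_embd,
                                     int64_t n_embd, int64_t n_tokens, int il) {
    return ggml_view_2d(ctx, t_inp_embd,
                        n_embd, n_tokens,                    // shape of one slice
                        t_inp_embd->nb[1],                   // keep the stacked tensor's row stride
                        (il + 1) * n_embd * sizeof(float));  // skip the main embedding and il earlier slices
}
// ---------------------------------------------------------------------------------------------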
 
diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp
index 0bae5223..bbd5f42b 100644
--- a/examples/talk-llama/models/qwen3vl.cpp
+++ b/examples/talk-llama/models/qwen3vl.cpp
@@ -2,11 +2,12 @@
 
 llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const size_t n_deepstack_layers = hparams.n_deepstack_layers;
-    const int64_t n_embd = hparams.n_embd;
-    const int64_t n_embd_head = hparams.n_embd_head_v;
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    const int64_t n_embd      = hparams.n_embd;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -16,17 +17,6 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
     int sections[4];
     std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
 
-    std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr);
-
-    if (ubatch.embd) {
-        // Image input: split main embd and deepstack embds
-        ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0);
-        for (size_t i = 0; i < n_deepstack_layers; i++) {
-            deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float));
-        }
-        inpL = inpL_main;
-    }
-
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
@@ -113,8 +103,9 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_
         cur = build_cvec(cur, il);
         cb(cur, "l_out", il);
 
-        if (ubatch.embd && (size_t)il < n_deepstack_layers) {
-            cur = ggml_add(ctx0, cur, deepstack_features[il]);
+        if (il < (int) n_deepstack_layers) {
+            ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float));
+            cur = ggml_add(ctx0, cur, ds);
             cb(cur, "deepstack_out", il);
         }
 
diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp
index ff5eb284..140700d9 100644
--- a/examples/talk-llama/models/refact.cpp
+++ b/examples/talk-llama/models/refact.cpp
@@ -1,9 +1,9 @@
 #include "models.h"
 
 llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp
index 46b3dc3e..c8e1f434 100644
--- a/examples/talk-llama/models/rnd1.cpp
+++ b/examples/talk-llama/models/rnd1.cpp
@@ -2,10 +2,10 @@
 
 // RND1 is a Qwen3Moe AR model converted to diffusion model.
 llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -93,7 +93,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
         cb(moe_out, "ffn_moe_out", il);
diff --git a/examples/talk-llama/models/rwkv6-base.cpp b/examples/talk-llama/models/rwkv6-base.cpp
index 7beed2da..83aeab72 100644
--- a/examples/talk-llama/models/rwkv6-base.cpp
+++ b/examples/talk-llama/models/rwkv6-base.cpp
@@ -1,5 +1,7 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model) {}
diff --git a/examples/talk-llama/models/rwkv7-base.cpp b/examples/talk-llama/models/rwkv7-base.cpp
index cda44653..7fcab777 100644
--- a/examples/talk-llama/models/rwkv7-base.cpp
+++ b/examples/talk-llama/models/rwkv7-base.cpp
@@ -1,5 +1,7 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model) {}
diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp
index 0dc33c50..a4d0b75d 100644
--- a/examples/talk-llama/models/seed-oss.cpp
+++ b/examples/talk-llama/models/seed-oss.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp
index 4c497ca7..e2155aac 100644
--- a/examples/talk-llama/models/smallthinker.cpp
+++ b/examples/talk-llama/models/smallthinker.cpp
@@ -2,10 +2,10 @@
 
 template 
 llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
@@ -93,7 +93,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model,
                     nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_RELU, true,
-                    false, 0.0,
+                    hparams.expert_weights_scale,
                     static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func),
                     il, probs);
 
diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp
index 97c30dee..e267fd8f 100644
--- a/examples/talk-llama/models/smollm3.cpp
+++ b/examples/talk-llama/models/smollm3.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp
index bed1915c..ff5aced9 100644
--- a/examples/talk-llama/models/stablelm.cpp
+++ b/examples/talk-llama/models/stablelm.cpp
@@ -1,9 +1,9 @@
 #include "models.h"
 
 llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp
index e197af4a..941cee98 100644
--- a/examples/talk-llama/models/starcoder.cpp
+++ b/examples/talk-llama/models/starcoder.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp
index e40ef2cb..a5965ace 100644
--- a/examples/talk-llama/models/starcoder2.cpp
+++ b/examples/talk-llama/models/starcoder2.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35-iswa.cpp
new file mode 100644
index 00000000..176209cd
--- /dev/null
+++ b/examples/talk-llama/models/step35-iswa.cpp
@@ -0,0 +1,165 @@
+#include "models.h"
+
+llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    ggml_tensor * inp_pos     = build_inp_pos();
+    auto        * inp_attn    = build_attn_inp_kv_iswa();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        const uint32_t n_head_l    = hparams.n_head(il);
+        const uint32_t n_head_kv_l = hparams.n_head_kv(il);
+
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        cur = inpL;
+
+        // dump pre-attn RMSNorm input to pinpoint layer boundary issues
+        cb(cur, "attn_norm_in", il);
+
+        // self-attention
+        {
+            cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
+
+            // Q/K per-head RMSNorm (Step35 q_norm / k_norm)
+            if (model.layers[il].attn_q_norm) {
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (model.layers[il].attn_k_norm) {
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
+            // RoPE (partial rotary factors per layer)
+            const bool is_swa = hparams.is_swa(il);
+            ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il);
+            const int64_t n_rot_l = hparams.n_rot(il);
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur_pos", il);
+            cb(Kcur, "Kcur_pos", il);
+
+            const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
+            ggml_tensor * attn_out = build_attn(inp_attn,
+                    nullptr, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+            cb(attn_out, "attn_out", il);
+            // head-wise attention gate: sigmoid(g_proj(x)) in torch
+            if (model.layers[il].wqkv_gate) {
+                ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens]
+                cb(gate, "attn_gate", il);
+
+                gate = ggml_sigmoid(ctx0, gate);
+                cb(gate, "attn_gate_sigmoid", il);
+
+                // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
+                ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
+                ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate,       1,          n_head_l, n_tokens);
+                cb(gate_3d, "attn_gate_3d", il);
+
+                attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
+                cb(attn_3d, "attn_gated_3d", il);
+
+                attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
+                cb(attn_out, "attn_gated", il);
+            }
+
+            // output projection
+            cur = build_lora_mm(model.layers[il].wo, attn_out);
+            cb(cur, "attn_proj", il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense MLP
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   nullptr,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE routed experts
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, hparams.expert_weights_norm,
+                    hparams.expert_weights_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // shared expert MLP (always added on MoE layers in Step35)
+            ggml_tensor * sh_out = build_ffn(cur,
+                    model.layers[il].ffn_up_shexp,   nullptr, nullptr,
+                    model.layers[il].ffn_gate_shexp, nullptr, nullptr,
+                    model.layers[il].ffn_down_shexp, nullptr, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(sh_out, "ffn_shared_out", il);
+
+            cur = ggml_add(ctx0, moe_out, sh_out);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    cur = build_lora_mm(model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
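// ---------------------------------------------------------------------------------------------
// Editor's note: illustration, not part of the patch. The head-wise gate in step35-iswa.cpp above
// carries one sigmoid value per (head, token); ggml_mul broadcasts it across the per-head value
// dimension because that dimension has size 1 in the gate tensor. The helper below is a local
// restatement of that shape bookkeeping, not an addition to the model code.
#include "ggml.h"

static ggml_tensor * gate_attention_heads(ggml_context * ctx,
                                          ggml_tensor * attn_out,  // [n_embd_head_v * n_head, n_tokens]
                                          ggml_tensor * gate,      // [n_head, n_tokens], post-sigmoid
                                          int64_t n_embd_head_v, int64_t n_head, int64_t n_tokens) {
    ggml_tensor * attn_3d = ggml_reshape_3d(ctx, attn_out, n_embd_head_v, n_head, n_tokens);
    ggml_tensor * gate_3d = ggml_reshape_3d(ctx, gate, 1, n_head, n_tokens);
    attn_3d = ggml_mul(ctx, attn_3d, gate_3d);  // broadcasts over dim 0 (the value head size)
    return ggml_reshape_2d(ctx, attn_3d, n_embd_head_v * n_head, n_tokens);
}
// ---------------------------------------------------------------------------------------------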
diff --git a/examples/talk-llama/models/t5-dec.cpp b/examples/talk-llama/models/t5-dec.cpp
index 297e450d..8ca8372b 100644
--- a/examples/talk-llama/models/t5-dec.cpp
+++ b/examples/talk-llama/models/t5-dec.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
     //const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/t5-enc.cpp b/examples/talk-llama/models/t5-enc.cpp
index 70e1d80d..395dfb51 100644
--- a/examples/talk-llama/models/t5-enc.cpp
+++ b/examples/talk-llama/models/t5-enc.cpp
@@ -1,9 +1,9 @@
 #include "models.h"
 
 llm_build_t5_enc::llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp
index 364797dd..3a8dfafc 100644
--- a/examples/talk-llama/models/xverse.cpp
+++ b/examples/talk-llama/models/xverse.cpp
@@ -1,10 +1,10 @@
 #include "models.h"
 
 llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_head = hparams.n_embd_head_v();
 
-    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-    GGML_ASSERT(n_embd_head == hparams.n_rot);
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
+    GGML_ASSERT(n_embd_head == n_rot);
 
     ggml_tensor * cur;
     ggml_tensor * inpL;
diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp
index b47dcbe6..122c8ca0 100644
--- a/examples/talk-llama/unicode.cpp
+++ b/examples/talk-llama/unicode.cpp
@@ -1,16 +1,10 @@
-#if defined(_MSC_VER)
-#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
-#endif
-
 #include "unicode.h"
 #include "unicode-data.h"
 
 #include 
 #include 
-#include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -199,27 +193,6 @@ static std::unordered_map unicode_utf8_to_byte_map() {
     return map;
 }
 
-static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
-#if defined(__clang__)
-    // disable C++17 deprecation warning for std::codecvt_utf8
-#    pragma clang diagnostic push
-#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#elif defined(__GNUC__)
-#    pragma GCC diagnostic push
-#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
-    std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
-
-#if defined(__clang__)
-#    pragma clang diagnostic pop
-#elif defined(__GNUC__)
-#    pragma GCC diagnostic pop
-#endif
-
-    return conv.from_bytes(s);
-}
-
 static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) {
     std::vector<std::string> bpe_encoded_words;
     for (const auto & word : bpe_words) {
@@ -497,49 +470,26 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) {
-    std::wregex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
+template <typename T>
+static std::vector<size_t> unicode_regex_split_stl(const std::basic_string<T> & text, const std::basic_string<T> & regex, const std::vector<size_t> & offsets) {
+    using BidirIt = typename std::basic_string<T>::const_iterator;
+#ifdef _MSC_VER
+    // Bypass bug in MSVC: https://github.com/ggml-org/llama.cpp/issues/17830
+    constexpr auto regex_flags = std::regex_constants::ECMAScript;
+#else
+    constexpr auto regex_flags = std::regex_constants::optimize | std::regex_constants::nosubs;
+#endif
+    std::basic_regex<T> expr(regex, regex_flags);
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
     size_t start = 0;
     for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
+        std::regex_iterator<BidirIt> it(text.begin() + start, text.begin() + start + offset, expr);
+        std::regex_iterator<BidirIt> end;
 
         int64_t start_idx = 0;
         while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
-    std::regex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
-    std::vector<size_t> bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::cmatch match = *it;
+            std::match_results<BidirIt> match = *it;
             if (match.position() > start_idx) {
                 bpe_offsets.emplace_back(match.position() - start_idx);
             }
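// ---------------------------------------------------------------------------------------------
// Editor's note: sketch, not part of the patch. The single template above folds the two removed
// overloads into one: instantiated with wchar_t it reproduces the old std::wregex path, and with
// char the old std::regex path. The wrappers below only illustrate the deduction; the real
// function is file-local (static) and is not called this way in the patch.
#include <string>
#include <vector>

static std::vector<size_t> split_narrow(const std::string & text, const std::string & re,
                                        const std::vector<size_t> & offsets) {
    return unicode_regex_split_stl(text, re, offsets);   // deduces T = char
}

static std::vector<size_t> split_wide(const std::wstring & text, const std::wstring & re,
                                      const std::vector<size_t> & offsets) {
    return unicode_regex_split_stl(text, re, offsets);   // deduces T = wchar_t
}
// ---------------------------------------------------------------------------------------------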
@@ -819,6 +769,12 @@ static std::vector unicode_regex_split_custom(const std::string & text,
     } else if (regex_expr == "\\p{AFMoE_digits}") {
         // AFMOE digit pattern - use custom implementation for proper splitting
         bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
+    } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") {
+        // tiny_aya digit grouping pattern from tokenizer.json:
+        //   {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"}
+        // Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567)
+        // TODO: Revisit this regex, in case there are any subtle tokenization differences with the original regex.
+        bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
     }
 
     return bpe_offsets;
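// ---------------------------------------------------------------------------------------------
// Editor's note: standalone illustration, not part of the patch. It shows the grouping the
// tiny_aya pattern above describes -- a run of digits split into groups of three counted from
// the right, e.g. "1234567" -> "1", "234", "567". The helper name is local to this sketch.
#include <string>
#include <vector>

static std::vector<std::string> group_digits_from_right(const std::string & digits) {
    std::vector<std::string> groups;
    const size_t first = digits.size() % 3;      // length of the leading (possibly short) group
    if (first > 0) {
        groups.push_back(digits.substr(0, first));
    }
    for (size_t i = first; i < digits.size(); i += 3) {
        groups.push_back(digits.substr(i, 3));   // remaining groups are exactly three digits
    }
    return groups;
}
// ---------------------------------------------------------------------------------------------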
@@ -1051,10 +1007,10 @@ std::vector unicode_regex_split(const std::string & text, const std
                     break;
                 }
             }
+            const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
 
             if (use_collapsed) {
                 // sanity-check that the original regex does not contain any non-ASCII characters
-                const auto cpts_regex = unicode_cpts_from_utf8(regex_expr);
                 for (size_t i = 0; i < cpts_regex.size(); ++i) {
                     if (cpts_regex[i] >= 128) {
                         throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported");
@@ -1110,7 +1066,7 @@ std::vector unicode_regex_split(const std::string & text, const std
                 bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
             } else {
                 // no unicode category used, we can use std::wregex directly
-                const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
+                std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end());
 
                 // std::wregex \s does not match non-ASCII whitespaces, using 0x0B as fallback
                 std::wstring wtext(cpts.begin(), cpts.end());
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 0176ca1c..44e58a52 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -1,10 +1,10 @@
-cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
+cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
 project("ggml" C CXX ASM)
 
 ### GGML Version
 set(GGML_VERSION_MAJOR 0)
 set(GGML_VERSION_MINOR 9)
-set(GGML_VERSION_PATCH 5)
+set(GGML_VERSION_PATCH 7)
 set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
 
 find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
@@ -228,6 +228,8 @@ option(GGML_WEBGPU_CPU_PROFILE              "ggml: enable WebGPU profiling (CPU)
 option(GGML_WEBGPU_GPU_PROFILE              "ggml: enable WebGPU profiling (GPU)"             OFF)
 option(GGML_WEBGPU_JSPI                     "ggml: use JSPI for WebGPU"                       ON)
 option(GGML_ZDNN                            "ggml: use zDNN"                                  OFF)
+option(GGML_VIRTGPU                         "ggml: use the VirtGPU/Virglrenderer API Remoting frontend"     OFF)
+option(GGML_VIRTGPU_BACKEND                 "ggml: build the VirtGPU/Virglrenderer API Remoting backend"    OFF)
 option(GGML_METAL                           "ggml: use Metal"                                 ${GGML_METAL_DEFAULT})
 option(GGML_METAL_NDEBUG                    "ggml: disable Metal debugging"                   OFF)
 option(GGML_METAL_SHADER_DEBUG              "ggml: compile Metal with -fno-fast-math"         OFF)
@@ -246,12 +248,14 @@ set   (GGML_SYCL_TARGET "INTEL" CACHE STRING
 set   (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
                                             "ggml: sycl device architecture")
 
+option(GGML_OPENVINO                        "ggml: use OPENVINO"                              OFF)
+
 option(GGML_OPENCL                          "ggml: use OpenCL"                                OFF)
 option(GGML_OPENCL_PROFILING                "ggml: use OpenCL profiling (increases overhead)" OFF)
 option(GGML_OPENCL_EMBED_KERNELS            "ggml: embed kernels"                             ON)
 option(GGML_OPENCL_USE_ADRENO_KERNELS       "ggml: use optimized kernels for Adreno"          ON)
 set   (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
-                                            "gmml: OpenCL API version to target")
+                                            "ggml: OpenCL API version to target")
 
 option(GGML_HEXAGON                         "ggml: enable Hexagon backend"                    OFF)
 set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
@@ -320,10 +324,12 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-opt.h
     include/ggml-metal.h
     include/ggml-rpc.h
+    include/ggml-virtgpu.h
     include/ggml-sycl.h
     include/ggml-vulkan.h
     include/ggml-webgpu.h
     include/ggml-zendnn.h
+    include/ggml-openvino.h
     include/gguf.h)
 
 set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
diff --git a/ggml/cmake/BuildTypes.cmake b/ggml/cmake/BuildTypes.cmake
deleted file mode 100644
index a9c7b6c9..00000000
--- a/ggml/cmake/BuildTypes.cmake
+++ /dev/null
@@ -1,54 +0,0 @@
-# Add new build types
-
-# ReleaseGG - Release with enabled asserts
-
-SET(CMAKE_CXX_FLAGS_RELEASEGG
-    "-O3"
-    CACHE STRING "Flags used by the c++ compiler during release builds with enabled asserts."
-    FORCE )
-SET(CMAKE_C_FLAGS_RELEASEGG
-    "-O3"
-    CACHE STRING "Flags used by the compiler during release builds with enabled asserts."
-    FORCE )
-SET(CMAKE_EXE_LINKER_FLAGS_RELEASEGG
-    ""
-    CACHE STRING "Flags used for linking binaries during release builds with enabled asserts."
-    FORCE )
-SET(CMAKE_SHARED_LINKER_FLAGS_RELEASEGG
-    ""
-    CACHE STRING "Flags used by the shared libraries linker during release builds with enabled asserts."
-    FORCE )
-MARK_AS_ADVANCED(
-    CMAKE_CXX_FLAGS_RELEASEGG
-    CMAKE_C_FLAGS_RELEASEGG
-    CMAKE_EXE_LINKER_FLAGS_RELEASEGG
-    CMAKE_SHARED_LINKER_FLAGS_RELEASEGG )
-
-# RelWithDebInfoGG - RelWithDebInfo with enabled asserts
-
-SET(CMAKE_CXX_FLAGS_RELWITHDEBINFOGG
-    "-O2 -g"
-    CACHE STRING "Flags used by the c++ compiler during release builds with debug symbols and enabled asserts."
-    FORCE )
-SET(CMAKE_C_FLAGS_RELWITHDEBINFOGG
-    "-O2 -g"
-    CACHE STRING "Flags used by the compiler during release builds with debug symbols and enabled asserts."
-    FORCE )
-SET(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG
-    ""
-    CACHE STRING "Flags used for linking binaries during release builds with debug symbols and enabled asserts."
-    FORCE )
-SET(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG
-    ""
-    CACHE STRING "Flags used by the shared libraries linker during release builds with debug symbols and enabled asserts."
-    FORCE )
-MARK_AS_ADVANCED(
-    CMAKE_CXX_FLAGS_RELWITHDEBINFOGG
-    CMAKE_C_FLAGS_RELWITHDEBINFOGG
-    CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG
-    CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG )
-
-if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
-    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
-    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ReleaseGG" "RelWithDebInfoGG")
-endif()
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index a9d17786..9fd3f7f3 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -259,7 +259,7 @@ extern "C" {
       Example usage:
 
         // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
-        // preferrably to run on the same backend as the buffer
+        // preferably to run on the same backend as the buffer
         ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
 
         sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true);
diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h
index b469e228..74af4653 100644
--- a/ggml/include/ggml-cann.h
+++ b/ggml/include/ggml-cann.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024 The ggml authors
+ * Copyright (c) 2023-2026 The ggml authors
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to
diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h
index 4f3b99c8..e3e067c9 100644
--- a/ggml/include/ggml-cpu.h
+++ b/ggml/include/ggml-cpu.h
@@ -19,6 +19,9 @@ extern "C" {
         // abort ggml_graph_compute when true
         ggml_abort_callback abort_callback;
         void *              abort_callback_data;
+
+        // use only reference implementations
+        bool use_ref;
     };
 
     // numa strategies
@@ -132,6 +135,8 @@ extern "C" {
     GGML_BACKEND_API void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
     GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
 
+    GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref);
+
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *,       float *, int64_t);
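// ---------------------------------------------------------------------------------------------
// Editor's note: minimal usage sketch, not part of the patch. Per the struct comment above, the
// new use_ref switch makes the CPU backend run only its reference implementations, which is
// handy when bisecting suspected bugs in the optimized kernels.
#include "ggml-cpu.h"

static void force_reference_kernels(ggml_backend_t backend_cpu) {
    ggml_backend_cpu_set_use_ref(backend_cpu, true);
}
// ---------------------------------------------------------------------------------------------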
diff --git a/ggml/include/ggml-openvino.h b/ggml/include/ggml-openvino.h
new file mode 100644
index 00000000..c43beb07
--- /dev/null
+++ b/ggml/include/ggml-openvino.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include "ggml-backend.h"
+
+#include 
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define GGML_OPENVINO_NAME "OPENVINO"
+
+// backend API
+GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device);
+
+GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend);
+
+GGML_BACKEND_API bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer);
+
+GGML_BACKEND_API bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft);
+
+GGML_BACKEND_API bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft);
+
+GGML_BACKEND_API size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer);
+
+// device buffer
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device);
+
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device);
+
+GGML_BACKEND_API int ggml_backend_openvino_get_device_count(void);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h
index 4703a05a..1c2ed79b 100644
--- a/ggml/include/ggml-opt.h
+++ b/ggml/include/ggml-opt.h
@@ -138,7 +138,7 @@ extern "C" {
     GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
     GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
 
-    // set gradients to zero, initilize loss, and optionally reset the optimizer
+    // set gradients to zero, initialize loss, and optionally reset the optimizer
     GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
 
     GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated statically
diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h
index df1ad2a5..1c11495b 100644
--- a/ggml/include/ggml-rpc.h
+++ b/ggml/include/ggml-rpc.h
@@ -8,7 +8,12 @@ extern "C" {
 
 #define RPC_PROTO_MAJOR_VERSION    3
 #define RPC_PROTO_MINOR_VERSION    6
-#define RPC_PROTO_PATCH_VERSION    0
+#define RPC_PROTO_PATCH_VERSION    1
+
+#ifdef  __cplusplus
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION");
+#endif
+
 #define GGML_RPC_MAX_SERVERS       16
 
 // backend API
diff --git a/ggml/include/ggml-virtgpu.h b/ggml/include/ggml-virtgpu.h
new file mode 100644
index 00000000..faaba8f2
--- /dev/null
+++ b/ggml/include/ggml-virtgpu.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg();
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index b69583dd..25f9601e 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -6,7 +6,7 @@
 // This documentation is still a work in progress.
 // If you wish some specific topics to be covered, feel free to drop a comment:
 //
-//   https://github.com/ggerganov/whisper.cpp/issues/40
+//   https://github.com/ggml-org/whisper.cpp/issues/40
 //
 // ## Overview
 //
@@ -427,7 +427,8 @@ extern "C" {
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
         GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)
-        GGML_TYPE_COUNT   = 40,
+        GGML_TYPE_NVFP4   = 40, // NVFP4 (4 blocks, E4M3 scale)
+        GGML_TYPE_COUNT   = 41,
     };
 
     // precision
@@ -463,6 +464,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
         GGML_FTYPE_MOSTLY_MXFP4   = 25, // except 1d tensors
+        GGML_FTYPE_MOSTLY_NVFP4   = 26, // except 1d tensors
     };
 
     // available tensor operations:
@@ -556,6 +558,7 @@ extern "C" {
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
         GGML_OP_SOLVE_TRI,
+        GGML_OP_GATED_DELTA_NET,
 
         GGML_OP_UNARY,
 
@@ -630,10 +633,11 @@ extern "C" {
 
     // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  =  1, // ...is an input for the GGML compute graph
-        GGML_TENSOR_FLAG_OUTPUT =  2, // ...is an output for the GGML compute graph
-        GGML_TENSOR_FLAG_PARAM  =  4, // ...contains trainable parameters
-        GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_INPUT   =  1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT  =  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM   =  4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS    =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
     };
 
     enum ggml_tri_type {
@@ -729,10 +733,6 @@ extern "C" {
     GGML_API size_t  ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
     GGML_API size_t  ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
 
-    GGML_DEPRECATED(
-    GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
-    "use ggml_row_size() instead");
-
     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
     GGML_API const char * ggml_op_symbol(enum ggml_op   op);
@@ -751,6 +751,7 @@ extern "C" {
     GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_empty     (const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_view      (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_scalar    (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_vector    (const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_matrix    (const struct ggml_tensor * tensor);
@@ -2465,6 +2466,17 @@ extern "C" {
         bool                  lower,
         bool                  uni);
 
+    // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
+    // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
+    GGML_API struct ggml_tensor * ggml_gated_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
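// ---------------------------------------------------------------------------------------------
// Editor's note: orientation only, not part of the patch and not a specification of the new op.
// The ggml_gated_delta_net declaration above is named after the gated delta rule, which in one
// common formulation keeps a per-head state matrix S and updates it per token t roughly as
//
//     S_t = exp(g_t) * S_{t-1} * (I - beta_t * k_t * k_t^T) + beta_t * v_t * k_t^T
//     o_t = S_t * q_t
//
// The exact conventions used here (tensor shapes, the layout of `state`, how `g` and `beta` are
// applied) are fixed by the backend implementations and the referenced PR, not by this note.
// ---------------------------------------------------------------------------------------------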
@@ -2577,11 +2589,42 @@ extern "C" {
         struct ggml_tensor *  grad,
         struct ggml_tensor *  sgd_params); // alpha, weight decay
 
+    // build forward multiple tensors and select one of them for computing
+    // this is useful for creating graphs that have constant topology but compute different things based on the input
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18550
     //
-    // automatic differentiation
+    // nodes:
+    //   | - build forward into the graph but do not compute
+    //   c - build forward into the graph and compute
     //
+    //    |  |  ...  c  ...  |
+    //    |  |  ...  c  ...  |
+    //    |  |  ...  c  ...  |
+    //   [0  1  ... idx ...  n-1]        <-- ggml_build_forward_select(..., n, idx)
+    //               c
+    //               c
+    //
+    // example:
+    //   struct ggml_tensor * curs[3];
+    //
+    //   curs[0]  = compute0(...);
+    //   curs[1]  = compute1(...);
+    //   curs[2]  = compute2(...);
+    //
+    //   int idx = select_branch(some_input);
+    //
+    //   struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);
+    //
+    GGML_API struct ggml_tensor * ggml_build_forward_select(
+            struct ggml_cgraph  * cgraph,
+            struct ggml_tensor ** tensors,
+            int                   n_tensors,
+            int                   idx);
+
+    GGML_API void ggml_build_forward_expand(
+            struct ggml_cgraph * cgraph,
+            struct ggml_tensor * tensor);
 
-    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(
         struct ggml_context *  ctx,        // context for gradient computation
         struct ggml_cgraph  *  cgraph,
@@ -2613,7 +2656,7 @@ extern "C" {
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
 
     // dump the graph into a file using the dot format
-    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);
 
     // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
     typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 6192a870..78853304 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -222,6 +222,7 @@ if (GGML_SCHED_NO_REALLOC)
 endif()
 
 add_library(ggml
+            ggml-backend-dl.cpp
             ggml-backend-reg.cpp)
 add_library(ggml::ggml ALIAS ggml)
 
@@ -451,6 +452,7 @@ ggml_add_backend(HIP)
 ggml_add_backend(METAL)
 ggml_add_backend(MUSA)
 ggml_add_backend(RPC)
+ggml_add_backend(VirtGPU)
 ggml_add_backend(SYCL)
 ggml_add_backend(Vulkan)
 ggml_add_backend(WebGPU)
@@ -458,6 +460,7 @@ ggml_add_backend(zDNN)
 ggml_add_backend(OpenCL)
 ggml_add_backend(Hexagon)
 ggml_add_backend(ZenDNN)
+ggml_add_backend(OPENVINO)
 
 foreach (target ggml-base ggml)
     target_include_directories(${target} PUBLIC    $ $)
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 41419b61..7f414b23 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -17,11 +17,6 @@
 //#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
 #define AT_PRINTF(...)
 
-
-static bool ggml_is_view(const struct ggml_tensor * t) {
-    return t->view_src != NULL;
-}
-
 // ops that return true for this function must not use restrict pointers for their backend implementations
 bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
@@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
     GGML_ASSERT(buffer_id >= 0);
     struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
 
-    if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
+    if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) {
         hn->allocated = true;
         assert(hn->addr.offset == 0);
 
@@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
 
                 struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
                 if (p_hn->n_children == 1 && p_hn->n_views == 0) {
-                    if (ggml_is_view(parent)) {
+                    if (ggml_impl_is_view(parent)) {
                         struct ggml_tensor * view_src = parent->view_src;
                         struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
                         if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
@@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
         // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
         // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
         // itself is never used and should not be considered a dependency
-        if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
+        if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {
             struct ggml_tensor * view_src = node->view_src;
             ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
         }
@@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
                 parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
 
             if (p_hn->n_children == 0 && p_hn->n_views == 0) {
-                if (ggml_is_view(parent)) {
+                if (ggml_impl_is_view(parent)) {
                     struct ggml_tensor * view_src = parent->view_src;
                     struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
                     view_src_hn->n_views -= 1;
diff --git a/ggml/src/ggml-backend-dl.cpp b/ggml/src/ggml-backend-dl.cpp
new file mode 100644
index 00000000..a65cf009
--- /dev/null
+++ b/ggml/src/ggml-backend-dl.cpp
@@ -0,0 +1,48 @@
+#include "ggml-backend-dl.h"
+
+#ifdef _WIN32
+
+dl_handle * dl_load_library(const fs::path & path) {
+    // suppress error dialogs for missing DLLs
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    HMODULE handle = LoadLibraryW(path.wstring().c_str());
+
+    SetErrorMode(old_mode);
+
+    return handle;
+}
+
+void * dl_get_sym(dl_handle * handle, const char * name) {
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    void * p = (void *) GetProcAddress(handle, name);
+
+    SetErrorMode(old_mode);
+
+    return p;
+}
+
+const char * dl_error() {
+    return "";
+}
+
+#else
+
+dl_handle * dl_load_library(const fs::path & path) {
+    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
+    return handle;
+}
+
+void * dl_get_sym(dl_handle * handle, const char * name) {
+    return dlsym(handle, name);
+}
+
+const char * dl_error() {
+    const char *rslt = dlerror();
+    return rslt != nullptr ? rslt : "";
+}
+
+#endif
diff --git a/ggml/src/ggml-backend-dl.h b/ggml/src/ggml-backend-dl.h
new file mode 100644
index 00000000..f74b7c94
--- /dev/null
+++ b/ggml/src/ggml-backend-dl.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#ifdef _WIN32
+#   define WIN32_LEAN_AND_MEAN
+#   ifndef NOMINMAX
+#       define NOMINMAX
+#   endif
+#   include 
+#   include 
+#else
+#    include 
+#    include 
+#endif
+#include <filesystem>
+
+namespace fs = std::filesystem;
+
+#ifdef _WIN32
+
+using dl_handle = std::remove_pointer_t<HMODULE>;
+
+struct dl_handle_deleter {
+    void operator()(HMODULE handle) {
+        FreeLibrary(handle);
+    }
+};
+
+#else
+
+using dl_handle = void;
+
+struct dl_handle_deleter {
+    void operator()(void * handle) {
+        dlclose(handle);
+    }
+};
+
+#endif
+
+using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+
+dl_handle * dl_load_library(const fs::path & path);
+void * dl_get_sym(dl_handle * handle, const char * name);
+const char * dl_error();
+
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 4181a714..05871092 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -1,5 +1,6 @@
 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
+#include "ggml-backend-dl.h"
 #include "ggml-impl.h"
 #include 
 #include 
@@ -69,6 +70,10 @@
 #include "ggml-rpc.h"
 #endif
 
+#ifdef GGML_USE_VIRTGPU_FRONTEND
+#include "ggml-virtgpu.h"
+#endif
+
 #ifdef GGML_USE_CANN
 #include "ggml-cann.h"
 #endif
@@ -77,105 +82,27 @@
 #include "ggml-zendnn.h"
 #endif
 
-// disable C++17 deprecation warning for std::codecvt_utf8
-#if defined(__clang__)
-#    pragma clang diagnostic push
-#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#elif defined(__GNUC__)
-#    pragma GCC diagnostic push
-#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#ifdef GGML_USE_OPENVINO
+#include "ggml-openvino.h"
 #endif
 
 namespace fs = std::filesystem;
 
 static std::string path_str(const fs::path & path) {
-    std::string u8path;
     try {
 #if defined(__cpp_lib_char8_t)
         // C++20 and later: u8string() returns std::u8string
-        std::u8string u8str = path.u8string();
-        u8path = std::string(reinterpret_cast<const char *>(u8str.c_str()));
+        const std::u8string u8str = path.u8string();
+        return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
 #else
         // C++17: u8string() returns std::string
-        u8path = path.u8string();
+        return path.u8string();
 #endif
     } catch (...) {
+        return std::string();
     }
-    return u8path;
 }
 
-#if defined(__clang__)
-#    pragma clang diagnostic pop
-#elif defined(__GNUC__)
-#    pragma GCC diagnostic pop
-#endif
-
-#ifdef _WIN32
-
-using dl_handle = std::remove_pointer_t<HMODULE>;
-
-struct dl_handle_deleter {
-    void operator()(HMODULE handle) {
-        FreeLibrary(handle);
-    }
-};
-
-static dl_handle * dl_load_library(const fs::path & path) {
-    // suppress error dialogs for missing DLLs
-    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
-    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
-
-    HMODULE handle = LoadLibraryW(path.wstring().c_str());
-
-    SetErrorMode(old_mode);
-
-    return handle;
-}
-
-static void * dl_get_sym(dl_handle * handle, const char * name) {
-    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
-    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
-
-    void * p = (void *) GetProcAddress(handle, name);
-
-    SetErrorMode(old_mode);
-
-    return p;
-}
-
-static const char * dl_error() {
-    return "";
-}
-
-#else
-
-using dl_handle = void;
-
-struct dl_handle_deleter {
-    void operator()(void * handle) {
-        dlclose(handle);
-    }
-};
-
-static void * dl_load_library(const fs::path & path) {
-    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
-
-    return handle;
-}
-
-static void * dl_get_sym(dl_handle * handle, const char * name) {
-    return dlsym(handle, name);
-}
-
-static const char * dl_error() {
-    const char *rslt = dlerror();
-    return rslt != nullptr ? rslt : "";
-}
-
-#endif
-
-using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
-
 struct ggml_backend_reg_entry {
     ggml_backend_reg_t reg;
     dl_handle_ptr handle;
@@ -196,7 +123,12 @@ struct ggml_backend_registry {
         register_backend(ggml_backend_sycl_reg());
 #endif
 #ifdef GGML_USE_VULKAN
+    // allow the Vulkan backend to be disabled at runtime via the GGML_DISABLE_VULKAN environment variable
+    if (getenv("GGML_DISABLE_VULKAN") == nullptr) {
         register_backend(ggml_backend_vk_reg());
+    } else {
+        GGML_LOG_DEBUG("Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\n");
+    }
 #endif
 #ifdef GGML_USE_WEBGPU
         register_backend(ggml_backend_webgpu_reg());
@@ -204,6 +136,10 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_ZDNN
         register_backend(ggml_backend_zdnn_reg());
 #endif
+#ifdef GGML_USE_VIRTGPU_FRONTEND
+        register_backend(ggml_backend_virtgpu_reg());
+#endif
+
 #ifdef GGML_USE_OPENCL
         register_backend(ggml_backend_opencl_reg());
 #endif
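// ---------------------------------------------------------------------------------------------
// Editor's note: usage sketch, not part of the patch. The GGML_DISABLE_VULKAN check above only
// tests whether the variable is present; its value is not inspected. An application can set it
// before the backend registry is first touched, for example:
#include <cstdlib>

static void disable_vulkan_backend(void) {
#ifdef _WIN32
    _putenv_s("GGML_DISABLE_VULKAN", "1");
#else
    setenv("GGML_DISABLE_VULKAN", "1", 1 /*overwrite*/);
#endif
    // ...then initialize/load backends as usual, e.g. via ggml_backend_load_all().
}
// ---------------------------------------------------------------------------------------------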
@@ -222,6 +158,9 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_RPC
         register_backend(ggml_backend_rpc_reg());
 #endif
+#ifdef GGML_USE_OPENVINO
+        register_backend(ggml_backend_openvino_reg());
+#endif
 #ifdef GGML_USE_CPU
         register_backend(ggml_backend_cpu_reg());
 #endif
@@ -539,9 +478,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
 
     int best_score = 0;
     fs::path best_path;
+    std::error_code ec;
 
     for (const auto & search_path : search_paths) {
-        if (std::error_code ec; !fs::exists(search_path, ec)) {
+        if (!fs::exists(search_path, ec)) {
             if (ec) {
                 GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
             } else {
@@ -551,7 +491,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
         }
         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
         for (const auto & entry : dir_it) {
-            if (entry.is_regular_file()) {
+            if (entry.is_regular_file(ec)) {
                 auto filename = entry.path().filename();
                 auto ext = entry.path().extension();
                 if (filename.native().find(file_prefix) == 0 && ext == file_extension) {
@@ -620,9 +560,11 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     ggml_backend_load_best("rpc", silent, dir_path);
     ggml_backend_load_best("sycl", silent, dir_path);
     ggml_backend_load_best("vulkan", silent, dir_path);
+    ggml_backend_load_best("virtgpu", silent, dir_path);
     ggml_backend_load_best("opencl", silent, dir_path);
     ggml_backend_load_best("hexagon", silent, dir_path);
     ggml_backend_load_best("musa", silent, dir_path);
+    ggml_backend_load_best("openvino", silent, dir_path);
     ggml_backend_load_best("cpu", silent, dir_path);
     // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
     const char * backend_path = std::getenv("GGML_BACKEND_PATH");
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 1b59924b..22c65699 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -258,6 +258,7 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
 
     if (backend->iface.set_tensor_async == NULL) {
+        ggml_backend_synchronize(backend);
         ggml_backend_tensor_set(tensor, data, offset, size);
     } else {
         backend->iface.set_tensor_async(backend, tensor, data, offset, size);
@@ -271,6 +272,7 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
     GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
 
     if (backend->iface.get_tensor_async == NULL) {
+        ggml_backend_synchronize(backend);
         ggml_backend_tensor_get(tensor, data, offset, size);
     } else {
         backend->iface.get_tensor_async(backend, tensor, data, offset, size);
@@ -874,9 +876,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         }
         if (sched->debug > 1) {
             ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
                 fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
-                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
+                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * src = node->src[j];
                 if (src == NULL) {
@@ -1922,6 +1924,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
         dst->view_offs = src->view_offs;
     }
     dst->op = src->op;
+    dst->flags = src->flags;
     memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
     ggml_set_name(dst, src->name);
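Note on the ggml-backend.cpp changes above: when a backend does not provide the async set/get hooks, the code now synchronizes the backend before falling back to the blocking copy, so the copy stays ordered after previously queued work; the graph-copy helper also propagates tensor flags. A toy sketch of the synchronize-then-fall-back dispatch follows; toy_backend and the other names are made up for illustration, not the ggml API.

// Toy sketch of "synchronize, then fall back to the blocking path".
// All names here are hypothetical; they only mirror the control flow above.
#include <cstddef>
#include <cstring>

struct toy_backend {
    bool has_async;  // whether the backend provides a dedicated async copy
    int  pending;    // stand-in for work queued on the backend's stream
};

static void toy_synchronize(toy_backend & be) {
    be.pending = 0;  // wait until everything queued so far has completed
}

static void toy_set_tensor_async(toy_backend & be, void * dst, const void * src, std::size_t n) {
    if (!be.has_async) {
        // No async hook: drain queued work first, then do a plain blocking copy,
        // so the write is still observed after earlier async operations.
        toy_synchronize(be);
        std::memcpy(dst, src, n);
        return;
    }
    std::memcpy(dst, src, n);  // stand-in for an actual asynchronous copy
    be.pending++;
}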
 
diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt
index fb0936f4..c27dc174 100644
--- a/ggml/src/ggml-blas/CMakeLists.txt
+++ b/ggml/src/ggml-blas/CMakeLists.txt
@@ -93,7 +93,7 @@ if (BLAS_FOUND)
     endif()
 
     target_link_libraries     (ggml-blas PRIVATE ${BLAS_LIBRARIES})
-    target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
+    target_include_directories(ggml-blas SYSTEM PRIVATE ${BLAS_INCLUDE_DIRS})
 else()
     message(FATAL_ERROR "BLAS not found, please refer to "
                         "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index 84956cbb..5de64b81 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -226,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 ggml_backend_blas_mul_mat(ctx, node);
@@ -335,8 +339,8 @@ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t
 }
 
 static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    // TODO
-    *free = 0;
+    // no memory to report
+    *free  = 0;
     *total = 0;
 
     GGML_UNUSED(dev);
diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp
index 7b7042a1..e95d3c4d 100644
--- a/ggml/src/ggml-cann/acl_tensor.cpp
+++ b/ggml/src/ggml-cann/acl_tensor.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024 The ggml authors
+ * Copyright (c) 2023-2026 The ggml authors
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to
diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h
index 7deac383..4737773a 100644
--- a/ggml/src/ggml-cann/acl_tensor.h
+++ b/ggml/src/ggml-cann/acl_tensor.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024 The ggml authors
+ * Copyright (c) 2023-2026 The ggml authors
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to
diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 6b718e01..fc7c3e3b 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024 The ggml authors
+ * Copyright (c) 2023-2026 The ggml authors
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to
@@ -58,6 +58,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -2338,20 +2339,21 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
 
     // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
     // TODO: acl_yarn_ramp_tensor use rope cache.
-    bool                 yarn_ramp_tensor_updated = false;
-    acl_tensor_ptr       acl_yarn_ramp_tensor;
+    bool           yarn_ramp_tensor_updated = false;
+    acl_tensor_ptr acl_yarn_ramp_tensor;
     if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
                             ctx.rope_cache.freq_scale != freq_scale)) {
         yarn_ramp_tensor_updated = true;
         if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
             ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float),
+                              ACL_MEM_MALLOC_HUGE_FIRST));
         // -rope_yarn_ramp
         // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
         // return MIN(1, MAX(0, y)) - 1;
-        acl_yarn_ramp_tensor =
-            ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+        acl_yarn_ramp_tensor      = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
+                                                            theta_scale_ne, theta_scale_nb, 1);
         float          zero_value = 0, one_value = 1;
         float          denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
         acl_scalar_ptr low              = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2382,8 +2384,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
     } else {
-        acl_yarn_ramp_tensor =
-            ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+        acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
+                                                       theta_scale_ne, theta_scale_nb, 1);
     }
     // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
     if (ext_factor != 0) {
@@ -2991,20 +2993,20 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
 }
 
-void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src1 = dst->src[1];
 
     // stride
-    int64_t s0 = ((const int32_t*)(dst->op_params))[0];
+    int64_t s0 = ((const int32_t *) (dst->op_params))[0];
 
-    acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
+    acl_tensor_ptr acl_input  = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
     acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
-    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
+    acl_tensor_ptr acl_dst    = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
 
     // get base information of input and kernel
-    int64_t input_len = *(src1->ne);
-    int64_t dst_len = *(dst->ne);
+    int64_t input_len   = *(src1->ne);
+    int64_t dst_len     = *(dst->ne);
     int64_t kernel_size = *(src0->ne);
 
     // set the max kernel size for each conv
@@ -3012,56 +3014,55 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
 
     // compute the partition of kernel
     int64_t part_num = 1;
-    part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
+    part_num         = (kernel_size + max_kernel_size - 1) / max_kernel_size;
 
     int64_t strideVal[1];
-    strideVal[0] = s0;
-    acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
-    int64_t paddingVal[] = {0};
-    acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
-    int64_t dilationVal[] = {1};
-    acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
-    bool transposed = true;
-    int64_t groups = 1;
-    int8_t cubeMathType = 0;
+    strideVal[0]                    = s0;
+    acl_int_array_ptr stride        = ggml_cann_create_int_array(strideVal, 1);
+    int64_t           paddingVal[]  = { 0 };
+    acl_int_array_ptr padding       = ggml_cann_create_int_array(paddingVal, 1);
+    int64_t           dilationVal[] = { 1 };
+    acl_int_array_ptr dilation      = ggml_cann_create_int_array(dilationVal, 1);
+    bool              transposed    = true;
+    int64_t           groups        = 1;
+    int8_t            cubeMathType  = 0;
 
 #ifdef ASCEND_310P
     cubeMathType = 1;
 #endif
 
     auto weight_type = ggml_cann_type_mapping(src0->type);
-    auto dst_type = ggml_cann_type_mapping(dst->type);
+    auto dst_type    = ggml_cann_type_mapping(dst->type);
 
     // slice the kernel to make each conv available
-    int64_t slice_dim = -1;
+    int64_t slice_dim   = -1;
     int64_t slice_start = 0;
-    int64_t slice_end = max_kernel_size;
-    int64_t slice_step = 1;
-    int64_t interval = max_kernel_size;
+    int64_t slice_end   = max_kernel_size;
+    int64_t slice_step  = 1;
+    int64_t interval    = max_kernel_size;
 
-    int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
+    int64_t left_pad_len  = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
     int64_t right_pad_len = 0;
 
-    acl_scalar_ptr alpha = nullptr;
-    float alphaValue = 1.0;
-    alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
+    acl_scalar_ptr alpha      = nullptr;
+    float          alphaValue = 1.0;
+    alpha                     = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
 
     // set zero to destination
     GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
 
-    for(int k = 0; k < part_num; k++){
-
+    for (int k = 0; k < part_num; k++) {
         // create part kernel tensor and slice from big kernel
         slice_start = max_kernel_size * k;
-        if(k == part_num - 1){
+        if (k == part_num - 1) {
             slice_end = kernel_size;
-            interval = kernel_size - max_kernel_size * k;
-        }else{
-            slice_end = max_kernel_size * (k+1);
+            interval  = kernel_size - max_kernel_size * k;
+        } else {
+            slice_end = max_kernel_size * (k + 1);
         }
 
         int64_t part_ne[4];
-        for(int i = 0; i < 4; i++) {
+        for (int i = 0; i < 4; i++) {
             part_ne[i] = *(src0->ne + i);
         }
         part_ne[0] = interval;
@@ -3074,16 +3075,17 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
 
         ggml_cann_pool_alloc part_kernel_allocator;
         part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
-        void* part_kernel_buf = part_kernel_allocator.get();
+        void * part_kernel_buf = part_kernel_allocator.get();
 
-        acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
-                                ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
+        acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0),
+                                                             part_ne, part_nb, 3, ACL_FORMAT_NCL);
 
-        GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step,
+                                part_kernel.get());
 
         // create the part conv result tensor
         int64_t part_dst_ne[4];
-        for(int i = 0; i < 4; i++){
+        for (int i = 0; i < 4; i++) {
             part_dst_ne[i] = *(dst->ne + i);
         }
         part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
@@ -3095,32 +3097,33 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
         }
         ggml_cann_pool_alloc part_dst_allocator;
         part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
-        void* part_dst_buf = part_dst_allocator.get();
+        void * part_dst_buf = part_dst_allocator.get();
 
         acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
-                                    part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
+                                                              part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
 
         // compute part conv transpose 1d
         GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
-        padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
+                                padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(),
+                                cubeMathType);
 
         // compute the position of part result in final result
         int64_t global_start = slice_start;
-        int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
+        int64_t global_end   = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
 
-        left_pad_len = global_start;
+        left_pad_len  = global_start;
         right_pad_len = dst_len - global_end;
 
-        std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
-        acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
+        std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len };
+        acl_int_array_ptr    padData    = ggml_cann_create_int_array(padDataVal.data(), 2);
 
-        acl_scalar_ptr pad_value = nullptr;
-        float pad_valueVal = 0.0;
-        pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
+        acl_scalar_ptr pad_value    = nullptr;
+        float          pad_valueVal = 0.0;
+        pad_value                   = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
 
         int64_t conv_result_ne[4];
-        for(int i = 0; i < 4; i++){
+        for (int i = 0; i < 4; i++) {
             conv_result_ne[i] = *(dst->ne + i);
         }
 
@@ -3132,13 +3135,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
 
         ggml_cann_pool_alloc conv_result_allocator;
         conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
-        void* conv_result_buf = conv_result_allocator.get();
+        void * conv_result_buf = conv_result_allocator.get();
 
         acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
-                                    conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
+                                                             conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
 
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
-        GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(),
+                                conv_result.get());
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
     }
 }
@@ -3282,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
 }
 
 /**
- * @brief Performs expert-specific matrix multiplication (MoE) with
- * quantized precision using the CANN backend.
+ * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)
+ * models using the CANN backend.
  *
- * This function executes a matrix multiplication operation tailored for
- * Mixture of Experts (MoE) models, where the input tensor is multiplied
- * with expert-specific quantized weight matrices. It leverages the CANN
- * backend to perform efficient low-precision computations and stores the
- * quantized result in the destination tensor `dst`.
+ * This function implements the MUL_MAT_ID operation for quantized weight matrices
+ * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on
+ * the provided expert indices, and computes matrix multiplication using CANN's
+ * WeightQuantBatchMatmulV2 operator.
  *
- * Quantization techniques reduce memory footprint and improve performance
- * by using lower-bit representations (e.g., int8) instead of floating-point.
- * This function is designed to work with such formats and may incorporate
- * optimizations like identity-based fast paths or routing masks for sparse
- * expert selection.
+ * The function performs the following steps:
+ * 1. Converts input/output tensors to F16 format if necessary
+ * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices
+ * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2
+ * 4. Converts output back to the target type if needed
  *
- * @param ctx The context for executing CANN backend operations.
- * @param dst The destination tensor where the quantized MoE multiplication result
- * will be stored.
+ * Tensor shapes:
+ * - dst:  [M, K, N, 1] - output tensor
+ * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)
+ * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)
+ * - ids:  [K, N] - expert indices for routing
  *
- * @note This function assumes quantized data types and is designed for
- * MoE architectures with potential sparse expert routing.
+ * @param ctx The CANN backend context for operation execution.
+ * @param dst The destination tensor where the multiplication result will be stored.
+ *
+ * @note Only Q4_0 and Q8_0 quantization formats are supported.
+ * @note The function handles automatic type conversion to/from F16 as needed by the hardware.
  */
 static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
-    // TODO: Use aclnnGroupedMatMul
-    //dst   [M, K, N, 1]
-    ggml_tensor * src0 = dst->src[0];  //src0	[D, M, A, 1]
-    ggml_tensor * src1 = dst->src[1];  //src1	[D, B, N, 1], B = K or B = 1
-    ggml_tensor * ids  = dst->src[2];  //ids	[K, N]
+    // dst:  [M, K, N, 1]
+    // src0: [D, M, A, 1] - quantized weights
+    // src1: [D, B, N, 1] - input activations, B = K or B = 1
+    // ids:  [K, N] - expert indices
+    ggml_tensor * src0 = dst->src[0];
+    ggml_tensor * src1 = dst->src[1];
+    ggml_tensor * ids  = dst->src[2];
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT(dst->ne[3] == 1);
+    GGML_ASSERT(src1->ne[2] == ids->ne[1]);
 
-    // copy index from npu to cpu
-    int64_t n_as  = ne02;        // A
-    int64_t n_ids = ids->ne[0];  // K
+    const int64_t        n_batches        = ids->ne[1];
+    const int64_t        n_select_experts = ids->ne[0];
+    const enum ggml_type type             = src0->type;
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
-                               ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
-    ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
+    const int32_t group_size = QK8_0;  // Both Q4_0 and Q8_0 use group size of 32
+    GGML_ASSERT(group_size == QK4_0);
 
-    char * src0_original = (char *) src0->data;
-    char * src1_original = (char *) src1->data;
-    char * dst_original  = (char *) dst->data;
+    // Calculate element size for quantized weights
+    const float weight_elem_size =
+        (type == GGML_TYPE_Q4_0) ? 0.5f :
+        (type == GGML_TYPE_Q8_0) ? 1.0f :
+                                   (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f);
 
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row  = *dst;
+    // Calculate scale offset in memory
+    const size_t weight_size     = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size;
+    const size_t scale_elem_size = sizeof(uint16_t);
+    char *       scale_data      = (char *) src0->data + weight_size;
 
-    const enum ggml_type type = dst->src[0]->type;
-    float                weight_elem_size;
-    if (type == GGML_TYPE_Q4_0) {
-        weight_elem_size = float(sizeof(uint8_t)) / 2;
-    } else if (type == GGML_TYPE_Q8_0) {
-        weight_elem_size = float(sizeof(uint8_t));
-    } else {
-        GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
-    }
+    // Allocate buffers for selected expert weights and scales
+    const size_t         selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size;
+    ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);
+    void *               selected_weight_buffer = selected_weight_alloc.get();
 
-    // src0_row [D, M, 1, 1] weight without permute
-    src0_row.ne[2]       = 1;
-    src0_row.ne[3]       = 1;
-    src0_row.nb[0]       = weight_elem_size;
-    src0_row.nb[1]       = weight_elem_size * ne00;
-    src0_row.nb[2]       = weight_elem_size * ne00;
-    src0_row.nb[3]       = weight_elem_size * ne00;
-    size_t weight_stride = ne00 * ne01 * weight_elem_size;
-    size_t weight_size   = weight_stride * ne02 * ne03;
+    const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;
+    ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);
+    void *               selected_scale_buffer = selected_scale_alloc.get();
 
-    // scale [D, M, 1, 1] -> scale && permute
-    size_t scale_elem_size = sizeof(uint16_t);
-    size_t scale_stride    = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
+    // Helper lambda to allocate and cast tensor to F16 if needed
+    constexpr size_t f16_elem_size      = sizeof(uint16_t);
+    auto             prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,
+                                  bool need_cast = false) -> void * {
+        if (tensor->type == GGML_TYPE_F16) {
+            return tensor->data;
+        }
 
-    // src1_row [D, 1, 1, 1] -> input
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
+        size_t total_size = f16_elem_size;
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            total_size *= tensor->ne[i];
+        }
+        void * buffer = allocator.alloc(total_size);
 
-    // dst_row [M, 1, 1, 1] -> out
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
+        if (need_cast == false) {
+            return buffer;
+        }
 
-    //create weight for one row
-    ggml_cann_pool_alloc weight_allocator(ctx.pool());
-    void *               weight_buffer = weight_allocator.alloc(nb02);
-    for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-        for (int64_t id = 0; id < n_ids; id++) {
-            // expert index
-            int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
-            GGML_ASSERT(i02 >= 0 && i02 < n_as);
+        int64_t ne[GGML_MAX_DIMS];
+        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            ne[i] = tensor->ne[i];
+            if (i > 0) {
+                nb[i] = nb[i - 1] * ne[i - 1];
+            }
+        }
 
-            // If B = 1 (broadcast), always use 0; otherwise, use id.
-            int64_t i11 = (ne11 == 1 ? 0 : id);
-            int64_t i12 = iid1;
+        acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor);
+        acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
+        aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);
 
-            int64_t i1 = id;
-            int64_t i2 = i12;
+        return buffer;
+    };
 
-            void * src0_tmp_ptr  = src0_original + i02 * weight_stride;
-            void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
-            void * src1_tmp_ptr  = src1_original + i11 * nb11 + i12 * nb12;
-            void * dst_tmp_ptr   = dst_original + i1 * nb1 + i2 * nb2;
+    // Prepare input and output buffers
+    ggml_cann_pool_alloc input_alloc(ctx.pool());
+    void *               input_buffer = prepare_f16_buffer(src1, input_alloc, true);
 
-            // mem cpy
-            ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-            void * scale_buffer = (char *) weight_buffer + weight_stride;
-            ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+    ggml_cann_pool_alloc output_alloc(ctx.pool());
+    void *               output_buffer = prepare_f16_buffer(dst, output_alloc, false);
 
-            src0_row.data  = weight_buffer;
-            src1_row.data  = src1_tmp_ptr;
-            dst_row.data   = dst_tmp_ptr;
-            dst_row.src[0] = &src0_row;
-            dst_row.src[1] = &src1_row;
+    // Process each batch
+    for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
+        // Create index tensor for current batch
+        const size_t   index_offset  = batch_idx * ids->nb[1];
+        acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);
 
-            ggml_cann_mul_mat(ctx, &dst_row);
+        // Select quantized weights using expert indices
+        // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte
+        const int64_t weight_d         = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0];
+        const int64_t weight_m         = src0->ne[1];
+        const int64_t weight_n_experts = src0->ne[2];
+
+        int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };
+        size_t  weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };
+
+        acl_tensor_ptr all_weights =
+            ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);
+
+        int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };
+        size_t  selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),
+                                          weight_d * weight_m * sizeof(int8_t) };
+
+        acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),
+                                                                  selected_weight_ne, selected_weight_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());
+
+        // Select scales using the same expert indices
+        const int64_t scale_d     = src0->ne[0] / group_size;
+        int64_t       scale_ne[3] = { scale_d, weight_m, weight_n_experts };
+        size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };
+
+        acl_tensor_ptr all_scales =
+            ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);
+
+        int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };
+        size_t  selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,
+                                         scale_d * weight_m * scale_elem_size };
+
+        acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,
+                                                                 selected_scale_ne, selected_scale_nb, 3);
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());
+
+        // Process each expert for current batch
+        // IndexSelect output layout: [D, M, K] in contiguous format
+        // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride
+        for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {
+            // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input
+            const size_t input_offset =
+                (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;
+            const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;
+
+            // Create weight view for current expert: [D, M, K] -> [M, D]
+            int64_t      weight_view_ne[2]  = { weight_m, src0->ne[0] };
+            float        weight_view_nb[2]  = { src0->ne[0] * weight_elem_size, weight_elem_size };
+            const size_t weight_view_offset = expert_idx * selected_weight_nb[2];
+
+            acl_tensor_ptr weight_view =
+                ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,
+                                        weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);
+
+            // Create scale view for current expert: [D, M, K] -> [M, D]
+            int64_t      scale_view_ne[2]  = { weight_m, scale_d };
+            size_t       scale_view_nb[2]  = { selected_scale_nb[1], selected_scale_nb[0] };
+            const size_t scale_view_offset = expert_idx * selected_scale_nb[2];
+
+            acl_tensor_ptr scale_view =
+                ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,
+                                        scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);
+
+            // Create input activation tensor [D, 1]
+            int64_t input_ne[2] = { src1->ne[0], 1 };
+            size_t  input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };
+
+            acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,
+                                                                  input_nb, 2, ACL_FORMAT_ND, input_offset);
+
+            // Create output tensor [M, 1]
+            int64_t output_ne[2] = { dst->ne[0], 1 };
+            size_t  output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };
+
+            acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,
+                                                                   output_nb, 2, ACL_FORMAT_ND, output_offset);
+
+            // Perform quantized matrix multiplication
+            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),
+                                    scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,
+                                    output_tensor.get());
         }
     }
-    return;
+
+    // Cast output back to original type if we used a temporary F16 buffer
+    if (dst->type != GGML_TYPE_F16) {
+        int64_t ne[GGML_MAX_DIMS];
+        size_t  nb[GGML_MAX_DIMS] = { f16_elem_size };
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            ne[i] = dst->ne[i];
+            if (i > 0) {
+                nb[i] = nb[i - 1] * ne[i - 1];
+            }
+        }
+
+        acl_tensor_ptr f16_output =
+            ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
+        acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);
+
+        aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));
+    }
 }
 
 void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
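For orientation, the dense (unquantized) semantics that the rewritten ggml_cann_mul_mat_id_quant above has to reproduce can be written as a few nested loops. The sketch below uses plain float data and illustrative names, ignoring the Q4_0/Q8_0 packing, the per-group scales, and the F16 staging.

// Reference semantics of MUL_MAT_ID on dense float data (names illustrative):
//   w   [n_expert][m][d]   expert weight matrices
//   x   [n_batch][b][d]    activations, with b == k_sel or b == 1 (broadcast)
//   ids [n_batch][k_sel]   expert index per selected slot
//   y   [n_batch][k_sel][m]
#include <cstddef>
#include <cstdint>
#include <vector>

using mat3 = std::vector<std::vector<std::vector<float>>>;

static void mul_mat_id_ref(const mat3 & w, const mat3 & x,
                           const std::vector<std::vector<std::int32_t>> & ids, mat3 & y) {
    for (std::size_t n = 0; n < ids.size(); n++) {
        for (std::size_t k = 0; k < ids[n].size(); k++) {
            const auto & W  = w[ids[n][k]];                    // routed expert weights [m][d]
            const auto & in = x[n][x[n].size() == 1 ? 0 : k];  // broadcast input if b == 1
            for (std::size_t m = 0; m < W.size(); m++) {
                float acc = 0.0f;
                for (std::size_t d = 0; d < in.size(); d++) {
                    acc += W[m][d] * in[d];
                }
                y[n][k][m] = acc;
            }
        }
    }
}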
@@ -3742,15 +3839,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     // we want a view:  ne_w = { nc, 1, nr }   // [K, 1, C]
     // so that reversed dims -> [C, 1, K] which matches
     //   [out_channels, in_channels/groups, kernel_size]
-    int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
+    int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 };  // [K, 1 input ch. per group, C groups]
     // Layout: src1 data is [K, C] with
     //   offset(k, c) = k*nb0 + c*nb1
     // We want offset_w(k, 0, c) = k*nb0 + c*nb1,
     // so we can reuse nb0 and nb1, and set nb2 = nb1.
-    size_t  w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
+    size_t  w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] };  // same as src1
 
-    acl_tensor_ptr acl_w = ggml_cann_create_tensor(
-        src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
+    acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type),
+                                                   ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
 
     // 3) Output: dst is { d_inner, n_t, n_s } (CLN)
     //
@@ -3768,11 +3865,12 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     //   nb_y[0] = nr * sizeof(float);           // step in L
     //   nb_y[1] = sizeof(float);                // step in C
     //   nb_y[2] = nr * n_t * sizeof(float);     // step in N
-    int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
-    size_t  y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
+    int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 };  // [L_out, C, N]
+    size_t  y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float),
+                                    dst->nb[3] };       // [nr, 1, nr * n_t]
 
-    acl_tensor_ptr acl_y = ggml_cann_create_tensor(
-        dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
+    acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
+                                                   ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
 
     // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
     int64_t strideVal[1]   = { 1 };
@@ -3791,22 +3889,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     cubeMathType = 1;
 #endif
 
-    GGML_CANN_CALL_ACLNN_OP(ctx,
-                            Convolution,
+    GGML_CANN_CALL_ACLNN_OP(ctx, Convolution,
                             acl_x.get(),    // input:  N, C, L_in = ncs
                             acl_w.get(),    // weight: [C, 1, K] with groups=nr
                             nullptr,        // bias
-                            stride.get(),
-                            padding.get(),
-                            dilation.get(),
-                            transposed,
-                            padding.get(),   // output padding (unused for non-transposed)
-                            groups,
-                            acl_y.get(),
-                            cubeMathType);
+                            stride.get(), padding.get(), dilation.get(), transposed,
+                            padding.get(),  // output padding (unused for non-transposed)
+                            groups, acl_y.get(), cubeMathType);
 }
 
-
 void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
                                      ggml_tensor *               add_node,
                                      ggml_tensor *               rms_norm_node) {
@@ -3860,3 +3951,71 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
                             eps,  // double type
                             acl_yout.get(), acl_rstd.get(), acl_xout.get());
 }
+
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * k = dst->src[0];
+    ggml_tensor * v = dst->src[1];
+    ggml_tensor * q = dst->src[2];
+    ggml_tensor * g = dst->src[3];
+    ggml_tensor * s = dst->src[4];
+
+    int64_t B = dst->src[4]->ne[1];
+    int64_t T = dst->src[0]->ne[2];
+    int64_t H = dst->src[0]->ne[1];
+    int64_t C = dst->ne[0];
+    int64_t D = C / H;
+    int64_t L = T / B;
+
+    int64_t ne_qkg[2] = { 1, D };
+    int64_t ne_s[2]   = { D, D };
+    int64_t ne_st[2]  = { ne_s[1], ne_s[0] };
+    int64_t ne_vo[2]  = { D, 1 };
+    int64_t ne_q[1]   = { D };
+    size_t  nb_base   = ggml_type_size(k->type);
+    size_t  nb_qkg[2] = { nb_base, nb_base };
+    size_t  nb_s[2]   = { nb_base, D * nb_base };
+    size_t  nb_st[2]  = { nb_s[1], nb_s[0] };
+    size_t  nb_vo[2]  = { nb_base, D * nb_base };
+    size_t  nb_q[1]   = { nb_base };
+
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    acl_tensor_ptr acl_s     = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND);
+    acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base);
+    cann_copy(ctx, acl_s.get(), new_state.get());
+
+    for (int64_t b = 0; b < B; b++) {
+        for (int64_t h = 0; h < H; h++) {
+            size_t         s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
+            // D * D
+            acl_tensor_ptr acl_s_new =
+                ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
+            acl_tensor_ptr acl_s_new_t =
+                ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
+            for (int64_t l = 0; l < L; l++) {
+                size_t               qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
+                // D * 1
+                acl_tensor_ptr       acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                acl_tensor_ptr       acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                // D
+                acl_tensor_ptr       acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+                // 1 * D
+                acl_tensor_ptr       acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
+                // D
+                acl_tensor_ptr       acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+                // k ⊗ v
+                size_t               buf_size = D * D * nb_base;
+                ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size);
+                acl_tensor_ptr       tmp_tensor = ggml_cann_create_tensor(
+                    buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2);
+                aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get());
+                //s_new = g ⊗ s_old + k ⊗ v
+                aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr);
+                aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr);
+                // compute output
+                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1);
+                aclnn_muls(ctx, acl_o.get(), scale, nullptr, true);
+            }
+        }
+    }
+}
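The new GATED_LINEAR_ATTN kernel above walks time steps per (batch, head) pair, updating the D x D state and producing one D-length output per step. Below is a scalar reference for a single step, assuming the broadcast orientation suggested by the in-code comments (g gates along the key dimension); all names are illustrative.

// One step of the gated linear attention recurrence, per (batch, head):
//   S[i][j] = g[i] * S[i][j] + k[i] * v[j]
//   o[j]    = scale * sum_i S[i][j] * q[i]
#include <cstddef>
#include <vector>

static void gla_step_ref(std::vector<std::vector<float>> & S,  // D x D state, updated in place
                         const std::vector<float> & k, const std::vector<float> & v,
                         const std::vector<float> & q, const std::vector<float> & g,
                         float scale, std::vector<float> & o) {
    const std::size_t D = k.size();
    for (std::size_t i = 0; i < D; i++) {
        for (std::size_t j = 0; j < D; j++) {
            S[i][j] = g[i] * S[i][j] + k[i] * v[j];  // s_new = g ⊗ s_old + k ⊗ v
        }
    }
    for (std::size_t j = 0; j < D; j++) {
        float acc = 0.0f;
        for (std::size_t i = 0; i < D; i++) {
            acc += S[i][j] * q[i];                   // o = S_new^T q
        }
        o[j] = scale * acc;
    }
}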
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 08ee7b1f..3effa1c2 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2023-2024 The ggml authors
+ * Copyright (c) 2023-2026 The ggml authors
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to
@@ -814,67 +814,20 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  */
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
-/*
- * @brief A generic wrapper for ACL resources with custom deleter support.
- */
-using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>;
-
 /**
- * @brief Trait structure used to define how to destroy a given ACL resource type.
+ * @brief Forward Gated Linear Attention on the CANN backend.
  *
- * @tparam T ACL resource type.
- */
-template <typename T> struct acl_resource_traits;
-
-/**
- * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
- */
-template <> struct acl_resource_traits<aclTensor> {
-    static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); }
-};
-
-/**
- * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
- */
-template <> struct acl_resource_traits<aclIntArray> {
-    static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); }
-};
-
-/**
- * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
- */
-template <> struct acl_resource_traits<aclScalar> {
-    static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); }
-};
-
-/**
- * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
- */
-template <> struct acl_resource_traits<aclTensorList> {
-    static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); }
-};
-
-/**
- * @brief Creates a generic ACL resource wrapper with proper destruction logic.
+ * Expects dst->src[0..4] = {k, v, q, g, s} with shape conventions:
+ *   k, v, q, g: [D] with outer dims T x H batched as ne[2]=T, ne[1]=H
+ *   s: initial state [B, H, D, D], where B is batch and D=C/H
+ * dst holds both outputs (o) and updated state; a scale factor is read from op params.
  *
- * @tparam T ACL resource type.
- * @param ptr Raw pointer to ACL resource.
- * @return any_acl_resource Smart pointer that handles destruction.
- */
-template <typename T> any_acl_resource make_acl_resource(T * ptr) {
-    return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); });
-}
-
-/**
- * @brief Registers multiple ACL resources into a vector for lifetime management.
+ * The kernel updates per time step l: S_new = g ⊗ S_old + k ⊗ v, then computes o = (S_new^T q) * scale.
  *
- * @tparam Args Variadic list of ACL resource types.
- * @param vec Target vector to hold ACL resources.
- * @param args Raw pointers to ACL resources.
+ * @param ctx Backend context providing stream/allocator utilities.
+ * @param dst Output tensor; src deps are k, v, q, g, s as above.
  */
-template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) {
-    (vec.emplace_back(make_acl_resource(args)), ...);
-}
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
 /**
  * @brief Launches an asynchronous task using the memory allocator.
@@ -894,19 +847,19 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) {
-#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                           \
-    do {                                                                                     \
-        uint64_t        workspaceSize = 0;                                                   \
-        aclOpExecutor * executor;                                                            \
-        void *          workspaceAddr = nullptr;                                             \
-        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
-        /* workspace should alloced in main thread to keep malloc order when using vmm. */   \
-        if (workspaceSize > 0) {                                                             \
-            ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);             \
-            workspaceAddr = workspace_allocator.get();                                       \
-        }                                                                                    \
-        ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));     \
-    } while (0)
+#    define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                           \
+        do {                                                                                     \
+            uint64_t        workspaceSize = 0;                                                   \
+            aclOpExecutor * executor;                                                            \
+            void *          workspaceAddr = nullptr;                                             \
+            ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
+            /* workspace should alloced in main thread to keep malloc order when using vmm. */   \
+            if (workspaceSize > 0) {                                                             \
+                ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);             \
+                workspaceAddr = workspace_allocator.get();                                       \
+            }                                                                                    \
+            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));     \
+        } while (0)
 
 /**
  * @brief   Performs sparse expert-based matrix multiplication using the CANN backend.
@@ -947,7 +900,9 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  * @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
  *                        and epsilon parameter.
  */
-void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
+                                     ggml_tensor *               add_node,
+                                     ggml_tensor *               rms_norm_node);
 
 /**
  * @brief   Check whether a tensor is a weight tensor for matrix multiplication.
@@ -1104,13 +1059,13 @@ void ggml_cann_op_unary_gated(std::function
 std::optional<std::string> get_env_as_lowercase(const std::string & name);
 bool                       parse_bool(const std::string & value);
@@ -382,7 +381,7 @@ struct ggml_cann_graph_lru_cache {
 
     std::list<ggml_cann_graph *> cache_list; /**< List storing cached graphs as raw pointers. */
 
-    ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); }
+    ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env_as_lowercase("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); }
 
     /**
      * @brief Push a new graph to the front of the cache.
@@ -574,7 +573,7 @@ struct ggml_backend_cann_context {
         description = aclrtGetSocName();
 
 #ifdef USE_ACL_GRAPH
-        acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
+        acl_graph_mode = parse_bool(get_env_as_lowercase("GGML_CANN_ACL_GRAPH").value_or("on"));
         GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
                       acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
 #endif
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index d7a93848..3f3de9f0 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2024 The ggml authors
+ * Copyright (c) 2023-2026 The ggml authors
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to
@@ -93,17 +93,6 @@ void ggml_cann_set_device(const int32_t device) {
     g_current_cann_device = device;
 }
 
-/**
- * @brief Retrieves the current device ID.
- *
- * @return The current device ID.
- */
-int32_t ggml_cann_get_device() {
-    int32_t id;
-    ACL_CHECK(aclrtGetDevice(&id));
-    return id;
-}
-
 /**
  * @brief Get the value of the specified environment variable (name) as lowercase.
  *        if not empty, return a std::string object
@@ -805,19 +794,44 @@ struct ggml_backend_cann_buffer_context {
     ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
 };
 
+// cann buffer type
 /**
- * @brief Check if a buffer is a CANN buffer.
- *
- * This function checks if a given buffer is a CANN buffer by comparing its
- * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
- *
- * @param buffer The buffer to check.
- * @return true if the buffer is a CANN buffer, false otherwise.
+ * @brief Structure representing context information for a specific backend
+ * buffer type.
  */
-static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
+struct ggml_backend_cann_buffer_type_context {
+    int32_t     device; /**< Device identifier associated with the buffer context. */
+    std::string name;   /**< Name associated with the buffer context. */
+};
 
-static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
-    return ggml_backend_buft_is_cann(buffer->buft);
+/**
+ * @brief Retrieves the name associated with a CANN buffer type.
+ *
+ * This function returns the descriptive name associated with the specified
+ * CANN buffer type context.
+ *
+ * @param buft Pointer to the buffer type context.
+ * @return Const pointer to the C-style string containing the name.
+ */
+static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
+
+    return buft_ctx->name.c_str();
+}
+
+/**
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
+ *
+ * This function checks whether the provided backend buffer type is associated
+ * with the CANN backend based on the comparison of its name retrieval function
+ * pointer.
+ *
+ * @param buft Pointer to the backend buffer type to check.
+ * @return bool Returns true if the buffer type is associated with the CANN
+ * backend, otherwise false.
+ */
+static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
 }
 
 /**
@@ -1282,7 +1296,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
 static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
                                                 const ggml_tensor *   src,
                                                 ggml_tensor *         dst) {
-    if (ggml_backend_buffer_is_cann(src->buffer)) {
+    if (ggml_backend_buft_is_cann(src->buffer->buft)) {
         ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
         ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
 
@@ -1346,31 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
     /* .reset           = */ NULL,
 };
 
-// cann buffer type
-/**
- * @brief Structure representing context information for a specific backend
- * buffer type.
- */
-struct ggml_backend_cann_buffer_type_context {
-    int32_t     device; /**< Device identifier associated with the buffer context. */
-    std::string name;   /**< Name associated with the buffer context. */
-};
-
-/**
- * @brief Retrieves the name associated with a CANN buffer type.
- *
- * This function returns the descriptive name associated with the specified
- * CANN buffer type context.
- *
- * @param buft Pointer to the buffer type context.
- * @return Const pointer to the C-style string containing the name.
- */
-static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
-
-    return buft_ctx->name.c_str();
-}
-
 /**
  * @brief Allocates a new CANN buffer of the specified type and size.
  *
@@ -1889,6 +1878,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
         case GGML_OP_OUT_PROD:
             ggml_cann_out_prod(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cann_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cann_ssm_conv(ctx, dst);
             break;
@@ -2005,7 +1997,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t      backend_src,
 
     GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));
 
-    if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) {
+    if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {
         return false;
     }
 
@@ -2154,6 +2146,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
                 continue;
             }
 
+            if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+                continue;
+            }
+
             bool ok = ggml_cann_compute_forward(*cann_ctx, node);
             if (!ok) {
                 GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -2454,6 +2450,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
         case GGML_OP_MEAN:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_OUT_PROD:
             {
@@ -2526,21 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
     GGML_UNUSED(dev);
 }
 
-/**
- * @brief Checks if the backend buffer type is associated with the CANN backend.
- *
- * This function checks whether the provided backend buffer type is associated
- * with the CANN backend based on the comparison of its name retrieval function
- * pointer.
- *
- * @param buft Pointer to the backend buffer type to check.
- * @return bool Returns true if the buffer type is associated with the CANN
- * backend, otherwise false.
- */
-static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
-}
-
 /**
  * @brief Records an event on the CANN backend stream.
  *
diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h
index 93ab7ea4..92cf739e 100644
--- a/ggml/src/ggml-common.h
+++ b/ggml/src/ggml-common.h
@@ -102,6 +102,9 @@ typedef sycl::half2 ggml_half2;
 #define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
 #define QR_MXFP4 2
 
+#define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4))
+#define QR_NVFP4 2
+
 #define QI5_0 (QK5_0 / (4 * QR5_0))
 #define QR5_0 2
 
@@ -194,6 +197,14 @@ typedef struct {
 } block_mxfp4;
 static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
 
+#define QK_NVFP4 64
+#define QK_NVFP4_SUB 16  // sub-block size for per-group scales
+typedef struct {
+    uint8_t d[QK_NVFP4/QK_NVFP4_SUB]; // UE4M3 scales (4 bytes, one per 16-element sub-block)
+    uint8_t qs[QK_NVFP4/2];           // packed 4-bit E2M1 values (32 bytes)
+} block_nvfp4;
+static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     ggml_half d;           // delta
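For context on the NVFP4 block added above: each block covers 64 values split into four 16-value sub-blocks, with one UE4M3 scale byte per sub-block and two E2M1 values packed per byte, i.e. 4 + 32 = 36 bytes per 64 weights, or 4.5 bits per weight. A standalone mirror of that arithmetic, kept separate from the real ggml-common.h definitions:

// Standalone mirror of the block_nvfp4 layout, only to make the size arithmetic explicit.
#include <cstdint>

constexpr int QK_NVFP4_MIRROR     = 64;  // values per block
constexpr int QK_NVFP4_SUB_MIRROR = 16;  // values per scale group

struct block_nvfp4_mirror {
    std::uint8_t d[QK_NVFP4_MIRROR / QK_NVFP4_SUB_MIRROR];  // 4 UE4M3 scale bytes
    std::uint8_t qs[QK_NVFP4_MIRROR / 2];                   // 32 bytes of packed E2M1 nibbles
};

static_assert(sizeof(block_nvfp4_mirror) == 36, "expected 36 bytes per 64 values (4.5 bits/weight)");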
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 7622d0bf..6ca3176a 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch)
     target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN})
     target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
     set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    # Disable LTO for the feature detection code to prevent cross-module optimization
+    # from inlining architecture-specific instructions into the score function.
+    # Without this, LTO can cause SIGILL when loading backends on older CPUs
+    # (e.g., loading power10 backend on power9 crashes before feature check runs).
+    target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto)
     target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME})
 endfunction()
 
@@ -561,35 +566,32 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.16.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.22.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "0a9e9008adb6031f9e8cf70dff4a3321")
+        set(KLEIDIAI_ARCHIVE_MD5  "54049037570ab0ee0a0d126b2ba5ece1")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
         endif()
 
+        # TODO: use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL once the minimum CMake version is bumped to 3.28+
+        # Until then, use FetchContent_Populate, since EXCLUDE_FROM_ALL support in FetchContent requires CMake 3.28
         FetchContent_Declare(KleidiAI_Download
             URL ${KLEIDIAI_DOWNLOAD_URL}
             DOWNLOAD_EXTRACT_TIMESTAMP NEW
             URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
 
-        FetchContent_MakeAvailable(KleidiAI_Download)
         FetchContent_GetProperties(KleidiAI_Download
             SOURCE_DIR  KLEIDIAI_SRC
             POPULATED   KLEIDIAI_POPULATED)
 
         if (NOT KLEIDIAI_POPULATED)
-            message(FATAL_ERROR "KleidiAI source downloaded failed.")
+            FetchContent_Populate(KleidiAI_Download)
+            FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
         endif()
 
         add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
 
-        # Remove kleidiai target after fetching it
-        if (TARGET kleidiai)
-            set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
-        endif()
-
         list(APPEND GGML_CPU_SOURCES
             ggml-cpu/kleidiai/kleidiai.cpp
             ggml-cpu/kleidiai/kernels.cpp
@@ -606,6 +608,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
 
         set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@@ -646,7 +649,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         if (NOT SME_ENABLED MATCHES -1)
             list(APPEND GGML_KLEIDIAI_SOURCES
-                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
@@ -654,10 +656,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_f16pmrx2_f32_neon.c
                 ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
-            set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
+            set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2+sme2+fp16")
         endif()
 
         if (NOT SVE_ENABLED MATCHES -1)
diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp
index 895a5713..9baf3e02 100644
--- a/ggml/src/ggml-cpu/amx/amx.cpp
+++ b/ggml/src/ggml-cpu/amx/amx.cpp
@@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ
 namespace ggml::cpu::amx {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        // handle only 2d gemm for now
-        auto is_contiguous_2d = [](const struct ggml_tensor * t) {
-            return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
-        };
-
-        if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
-            is_contiguous_2d(op->src[1]) &&                               // src1 must be contiguous
-            op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
-            op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
-            op->ne[0] % (TILE_N * 2) == 0 &&                              // out_features is 32x
-            (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
-            // src1 must be host buffer
-            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
-                return false;
-            }
-            // src1 must be float32
-            if (op->src[1]->type == GGML_TYPE_F32) {
-                return true;
-            }
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
         }
-        return false;
+        auto * src0 = op->src[0];
+        auto * src1 = op->src[1];
+
+        if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+            return false;
+        }
+        if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) {
+            return false;
+        }
+        if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) {
+            return false;
+        }
+        if (op->ne[0] % (TILE_N * 2)) {
+            return false;
+        }
+        int alignment;
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q8_0:
+                alignment = TILE_K;
+                break;
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+            case GGML_TYPE_Q6_K:
+            case GGML_TYPE_IQ4_XS:
+                alignment = 256; // QK_K
+                break;
+            case GGML_TYPE_F16:
+                alignment = 16;
+                break;
+            default:
+                return false;
+        }
+        if (src0->ne[0] % alignment) {
+            return false;
+        }
+        if (src1->type != GGML_TYPE_F32) {
+            return false;
+        }
+        return true;
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
diff --git a/ggml/src/ggml-cpu/amx/common.h b/ggml/src/ggml-cpu/amx/common.h
index f392e898..26a6ec1a 100644
--- a/ggml/src/ggml-cpu/amx/common.h
+++ b/ggml/src/ggml-cpu/amx/common.h
@@ -9,6 +9,8 @@
 
 #if defined(GGML_USE_OPENMP)
 #include <omp.h>
+#else
+#include <thread>
+#include <vector>
 #endif
 
 #define TILE_M 16
@@ -56,18 +58,40 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
 }
 
 template <typename func_t>
-inline void parallel_for(int n, const func_t& f) {
+inline void parallel_for(int n, const func_t & f) {
+    if (n <= 0) {
+        return;
+    }
 #if defined(GGML_USE_OPENMP)
-#pragma omp parallel
-{
-    int nth = omp_get_num_threads();
-    int ith = omp_get_thread_num();
-    int tbegin, tend;
-    balance211(n, nth, ith, tbegin, tend);
-    f(tbegin, tend);
-}
+    #pragma omp parallel
+    {
+        int nth = omp_get_num_threads();
+        int ith = omp_get_thread_num();
+        int tbegin, tend;
+        balance211(n, nth, ith, tbegin, tend);
+        f(tbegin, tend);
+    }
 #else
-    f(0, n);
+    int nth = std::thread::hardware_concurrency();
+    if (nth <= 1) {
+        f(0, n);
+        return;
+    }
+    if (nth > n) {
+        nth = n;
+    }
+    std::vector<std::thread> threads;
+    threads.reserve(nth);
+    for (int ith = 0; ith < nth; ++ith) {
+        threads.emplace_back([&f, n, ith, nth] {
+            int tbegin, tend;
+            balance211(n, nth, ith, tbegin, tend);
+            f(tbegin, tend);
+        });
+    }
+    for (auto & t : threads) {
+        t.join();
+    }
 #endif
 }
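A minimal usage sketch of the fallback path above (the callback and data are hypothetical): balance211 hands each worker a contiguous [tbegin, tend) slice, e.g. slices of 3,3,2,2 elements for n = 10 on a 4-thread machine.

// Sketch only: double each element of src into dst, letting parallel_for
// pick the thread count and split the index range with balance211.
static void scale_all(const float * src, float * dst, int n) {
    parallel_for(n, [&](int tbegin, int tend) {
        for (int i = tbegin; i < tend; ++i) {
            dst[i] = 2.0f * src[i];   // independent per-index work, no sharing across slices
        }
    });
}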
 
diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp
index 47c61b88..93a6d397 100644
--- a/ggml/src/ggml-cpu/amx/mmq.cpp
+++ b/ggml/src/ggml-cpu/amx/mmq.cpp
@@ -1,4 +1,3 @@
-
 #if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Wpedantic"
 #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
@@ -196,41 +195,33 @@ struct tile_config_t{
 // will be needed.
 //
 // Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16;
-// and the sinlge batch gemm (m=1) has a special fast path with `avx512-vnni`.
+// and the single batch gemm (m=1) has a special fast path with `avx512-vnni`.
 //
 // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/
 //    advanced-matrix-extensions-intrinsics-functions.html
 //
 
-#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
-void ggml_tile_config_init(void) {
-    static thread_local bool is_first_time = true;
+inline void ggml_tile_config_init(void) {
+    static thread_local bool done = false;
 
-    if (!is_first_time) {
+    if (done) {
         return;
     }
 
-    static thread_local tile_config_t tc;
-    tile_config_t current_tc;
-    _tile_storeconfig(&current_tc);
+    alignas(64) tile_config_t tc = {};
+    tc.palette_id = 1;
+    tc.start_row = 0;
+    tc.rows[0] = 8;   tc.colsb[0] = 64;
+    tc.rows[1] = 8;   tc.colsb[1] = 64;
+    tc.rows[2] = 16;  tc.colsb[2] = 32;
+    tc.rows[3] = 16;  tc.colsb[3] = 32;
+    tc.rows[4] = 16;  tc.colsb[4] = 64;
+    tc.rows[5] = 16;  tc.colsb[5] = 64;
+    tc.rows[6] = 16;  tc.colsb[6] = 64;
+    tc.rows[7] = 16;  tc.colsb[7] = 64;
 
-    // load only when config changes
-    if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
-                               memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
-        tc.palette_id = 1;
-        tc.start_row = 0;
-        TC_CONFIG_TILE(TMM0, 8, 64);
-        TC_CONFIG_TILE(TMM1, 8, 64);
-        TC_CONFIG_TILE(TMM2, 16, 32);
-        TC_CONFIG_TILE(TMM3, 16, 32);
-        TC_CONFIG_TILE(TMM4, 16, 64);
-        TC_CONFIG_TILE(TMM5, 16, 64);
-        TC_CONFIG_TILE(TMM6, 16, 64);
-        TC_CONFIG_TILE(TMM7, 16, 64);
-        _tile_loadconfig(&tc);
-    }
-
-    is_first_time = false;
+    _tile_loadconfig(&tc);
+    done = true;
 }
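For orientation, the fixed configuration above corresponds to the 2x2 tiling used by the quantized kernels; the following reading assumes TILE_M = TILE_N = 16, TILE_K = 32 and VNNI_BLK = 4 as defined in this backend (an annotation of the values above, not part of the patch):

// TMM0/TMM1: B (weight) tiles in VNNI layout: TILE_K/VNNI_BLK = 8 rows x TILE_N*VNNI_BLK = 64 bytes
// TMM2/TMM3: A (activation) tiles:            TILE_M = 16 rows x TILE_K = 32 bytes
// TMM4-TMM7: int32 accumulator tiles:         TILE_M = 16 rows x TILE_N*sizeof(int32_t) = 64 bytes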
 
 // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
@@ -268,33 +259,6 @@ int get_row_size(int K) {
     return row_size;
 }
 
-// vectorized dtype conversion
-inline float FP16_TO_FP32(ggml_half val) {
-    __m256i v = _mm256_setr_epi16(
-        val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
-    __m512 o = _mm512_cvtph_ps(v);
-    return _mm512_cvtss_f32(o);
-}
-
-inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
-    __m256i v = _mm256_set1_epi16(val);
-    return _mm512_cvtph_ps(v);
-}
-
-// horizontal reduce
-inline float _mm512_reduce_max_ps(const __m512 x) {
-    __m512 v = x;
-    __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_ps(v, v, 0x4E);
-    v = _mm512_max_ps(v, v1);
-    v1 = _mm512_shuffle_ps(v, v, 0xB1);
-    v = _mm512_max_ps(v, v1);
-    return _mm512_cvtss_f32(v);
-}
-
 // transpose utils
 #define SHUFFLE_EPI32(a, b, mask) \
     _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
@@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx
 
 #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE)                                \
     tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE>::apply(    \
-        K, (const float *)src1->data + mb_start * K,                                \
-        (const type *)src0->data + nb_start * K,                                    \
-        (float *)dst->data + mb_start * ldc + nb_start, ldc);
+        K, (const float *)src1->data + src1_offset + mb_start * K,                  \
+        (const type *)src0->data + src0_offset + nb_start * K,                      \
+        (float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc)
 
 
 // re-organize in the format {NB, KB, TILE_SIZE}:
@@ -1415,8 +1379,8 @@ struct tinygemm_kernel_vnni
-#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                         \
-    tinygemm_kernel_vnni::apply(   \
-        KB, (const char *)wdata + 0 * row_size_A,                                    \
-        (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE),     \
-        (float *) dst->data + 0 * N + nb_start, ldc)
+#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                                   \
+    tinygemm_kernel_vnni::apply(             \
+        KB, wdata_batch,                                                                       \
+        (const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \
+        (float *) dst->data + dst_offset + nb_start, ldc)
 
 template ::value, int>::type = 0>
@@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v
         _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
 
         if (need_unpack) {
-            unpack_B(Tile1, B_blk0);
+            unpack_B(Tile1, B_blk1);
             _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
         } else {
             _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
@@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d
     });
 }
 
+// ne2 is passed explicitly to help the compiler optimize repeated calls
+inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) {
+    const int64_t i2 = batch_idx % ne2;
+    const int64_t i3 = batch_idx / ne2;
+    return i3 * t->nb[3] + i2 * t->nb[2];
+}
+
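A short worked example of the flat-to-(i2, i3) mapping above (the shape values are hypothetical):

// Sketch only: with dst->ne[2] == 4 and dst->ne[3] == 2 there are 8 batches.
//   batch_idx = 5  ->  i2 = 5 % 4 = 1,  i3 = 5 / 4 = 1
// so ggml_batch_offset(dst, 5, /*ne2=*/4) == 1*dst->nb[3] + 1*dst->nb[2],
// matching the batch_idx = i / (MB * NB) decomposition used in the kernels below.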
 size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
     struct ggml_tensor * src0 = dst->src[0];
 
@@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
 
     const int M = dst->ne[1];
     const int K = src0->ne[0];
+    const int64_t n_batch = dst->ne[2] * dst->ne[3];
 
     size_t desired_wsize = 0;
 
     GGML_DISPATCH_QTYPES(TYPE, [&] {
         const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
-        desired_wsize = M * row_size_A;
+        desired_wsize = n_batch * M * row_size_A;
     });
 
     return desired_wsize;
@@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) {
 // src1: input  in shape of {M, K}, float32
 // dst:  output in shape of {M, N}, float32
 //
-// the function performs: dst = src1 @ src0.T
+// the function performs: dst = src1 @ src0.T for each batch
 //
 void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) {
     struct ggml_tensor * src0 = dst->src[0];
@@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     const int K = src0->ne[0];
     const int ldc = dst->nb[1] / dst->nb[0];
 
+    const int64_t ne2 = dst->ne[2];
+    const int64_t n_batch = ne2 * dst->ne[3];
+
     if (is_floating_type) {
         constexpr int BLOCK_M = 4;
         constexpr int BLOCK_N = 6;
         const int MB = div_up(M, BLOCK_M);
         const int NB = div_up(N, BLOCK_N);
 
-        parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
+        parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
             GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
                 for (int i = begin; i < end; ++i) {
-                    int mb = i / NB;
-                    int nb = i % NB;
+                    int batch_idx = i / (MB * NB);
+                    int remaining = i % (MB * NB);
+                    int mb = remaining / NB;
+                    int nb = remaining % NB;
+
+                    int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                    int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
+                    int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
 
                     int mb_start = mb * BLOCK_M;
                     int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     void * wdata = params->wdata;
 
     //TODO: performance improvement: merge quant A
-    if (params->ith == 0) {
+ // if (params->ith == 0) {
         GGML_DISPATCH_QTYPES(TYPE, [&] {
             const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
-            const size_t desired_wsize = M * row_size_A;
+            const size_t desired_wsize = n_batch * M * row_size_A;
             if (params->wsize < desired_wsize) {
                 GGML_ABORT("insufficient work space size");
             }
@@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
             // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size
             GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
 
-            const float * A_data = static_cast(src1->data);
-            for (int m = 0; m < M; ++m) {
-                from_float(A_data + m * K, (char *)wdata + m * row_size_A, K);
-            }
+            parallel_for_ggml(params, n_batch, [&](int begin, int end) {
+                for (int batch_idx = begin; batch_idx < end; ++batch_idx) {
+                    int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2);
+                    const float * A_data = (const float *)((const char *)src1->data + src1_offset);
+                    char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A;
+
+                    for (int m = 0; m < M; ++m) {
+                        from_float(A_data + m * K, wdata_batch + m * row_size_A, K);
+                    }
+                }
+            });
         });
-    }
+ // }
 
     ggml_barrier(params->threadpool);
 
@@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
         constexpr int BLOCK_N = TILE_N * kTilesN;
         const int NB = div_up(N, BLOCK_N);
 
-        parallel_for_ggml(params, NB, [&](int begin, int end) {
+        parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) {
             GGML_DISPATCH_QTYPES(TYPE, [&] {
                 const int KB = K / blck_size;
                 const int TILE_SIZE = get_tile_size();
                 const int row_size_A = KB * sizeof(vec_dot_type);
                 for (int i = begin; i < end; ++i) {
-                    int nb = i;
+                    int batch_idx = i / NB;
+                    int nb = i % NB;
+
+                    int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                    int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
+                    const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A;
+
                     int nb_start = nb * BLOCK_N;
                     int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96
 
@@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
     const int MB = div_up(M, BLOCK_M);
     const int NB = div_up(N, BLOCK_N);
 
-    parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
+    parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) {
         // init tile config for each thread
         ggml_tile_config_init();
 
@@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
             const int row_size_A = KB * sizeof(vec_dot_type);
 
             for (int i = begin; i < end; ++i) {
-                int mb = i / NB;
-                int nb = i % NB;
+                int batch_idx = i / (MB * NB);
+                int remaining = i % (MB * NB);
+                int mb = remaining / NB;
+                int nb = remaining % NB;
+
+                int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2);
+                int64_t dst_offset  = ggml_batch_offset(dst,  batch_idx, ne2);
+                const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A;
 
                 int mb_start = mb * BLOCK_M;
                 int mb_size = std::min(BLOCK_M, M - mb_start);
@@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te
 
                 tinygemm_kernel_amx(
                     mb_size, nb_size, KB,
-                    (const char *)wdata + mb_start * row_size_A,
-                    (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
-                    (float *) dst->data + mb_start * N + nb_start, ldc);
+                    wdata_batch + mb_start * row_size_A,
+                    (const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
+                    (float *) dst->data + dst_offset + mb_start * N + nb_start, ldc);
             }
         });
     });
diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h
index 3f8946ac..41da8293 100644
--- a/ggml/src/ggml-cpu/arch-fallback.h
+++ b/ggml/src/ggml-cpu/arch-fallback.h
@@ -1,3 +1,4 @@
+
 #pragma once
 
 // Rename `_generic` functions if no native implementation is available.
@@ -14,6 +15,7 @@
 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -38,21 +40,33 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
@@ -60,29 +74,44 @@
 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
+// quants.c
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__POWERPC__) || defined(__powerpc__)
 // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
@@ -94,21 +123,33 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__loongarch64)
@@ -118,6 +159,7 @@
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -126,64 +168,78 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__riscv)
 // quants.c
-#define quantize_row_q8_K_generic quantize_row_q8_K
-#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
-#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
-#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
-#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
-#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
-#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
-#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
-#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
-#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
-#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
-#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
-#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 // repack.cpp
+#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
-#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
+#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1
 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__s390x__)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -202,21 +258,33 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__wasm__)
@@ -234,6 +302,7 @@
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -242,21 +311,33 @@
 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
+#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
-#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
+#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
+#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
+#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
+#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
-#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
+#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
+#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #endif
diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c
index b390ab61..82b048bb 100644
--- a/ggml/src/ggml-cpu/arch/arm/quants.c
+++ b/ggml/src/ggml-cpu/arch/arm/quants.c
@@ -650,6 +650,90 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = sumf;
 }
 
+void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_NVFP4 == 0);
+
+    const block_nvfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    // Each NVFP4 super-block (64 elements) spans 2 q8_0 blocks
+    const int nb = n / QK_NVFP4;
+
+    float sumf = 0;
+
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    float32x4_t acc = vdupq_n_f32(0.0f);
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const uint8x16_t q4bits_0 = vld1q_u8(x[ib].qs);
+        const uint8x16_t q4bits_1 = vld1q_u8(x[ib].qs + 16);
+
+        const int8x16_t q4_lo_0 = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits_0, m4b));
+        const int8x16_t q4_hi_0 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_0, 4));
+        const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits_1, m4b));
+        const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4));
+
+        const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs);
+        const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16);
+        const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b));
+        const int8x16_t q8_hi_0 = vcombine_s8(vget_high_s8(q8_0a), vget_high_s8(q8_0b));
+
+        const int8x16_t q8_1a = vld1q_s8(y[2*ib+1].qs);
+        const int8x16_t q8_1b = vld1q_s8(y[2*ib+1].qs + 16);
+        const int8x16_t q8_lo_1 = vcombine_s8(vget_low_s8(q8_1a), vget_low_s8(q8_1b));
+        const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b));
+
+        const int32x4_t p0 = vaddq_s32(
+            ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
+            ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
+        const int32x4_t p1 = vaddq_s32(
+            ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
+            ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
+
+        const int32x4_t sums = vpaddq_s32(p0, p1);
+
+        // Decode 4 UE4M3 scales to f32 and multiply with q8 scales
+        const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
+        const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
+        const float32x4_t nvsc = {
+            ggml_ue4m3_to_fp32(x[ib].d[0]),
+            ggml_ue4m3_to_fp32(x[ib].d[1]),
+            ggml_ue4m3_to_fp32(x[ib].d[2]),
+            ggml_ue4m3_to_fp32(x[ib].d[3])
+        };
+        const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1});
+
+        acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales);
+    }
+    sumf = vaddvq_f32(acc);
+#else
+    for (int ib = 0; ib < nb; ++ib) {
+        for (int si = 0; si < 4; ++si) {
+            const float d = ggml_ue4m3_to_fp32(x[ib].d[si]);
+            const int q8b = si / 2;
+            const int q8o = (si % 2) * QK_NVFP4_SUB;
+            const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8b].d);
+
+            int sumi_lo = 0, sumi_hi = 0;
+            for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
+                const uint8_t qv = x[ib].qs[si*(QK_NVFP4_SUB/2) + j];
+                sumi_lo += y[2*ib + q8b].qs[q8o + j +               0] * kvalues_mxfp4[qv & 0xf];
+                sumi_hi += y[2*ib + q8b].qs[q8o + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >>  4];
+            }
+            sumf += dy * d * (sumi_lo + sumi_hi);
+        }
+    }
+#endif
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -968,7 +1052,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int vector_length = ggml_cpu_get_sve_cnt()*8;
 
-    //VLA Implemenation for SVE
+    //VLA Implementation for SVE
     switch (vector_length) {
         case 128:
             {
diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index b61220a1..80ff5ce5 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -25,9 +25,8 @@
 #define UNUSED GGML_UNUSED
 
 #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
-static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
-                                             int16x8_t *     out_mins,
-                                             int8_t *        out_scales) {
+// Helper for decoding scales and mins of Q4_K and Q5_K block formats
+static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
     constexpr uint32_t kmask1 = 0x3f3f3f3f;
     constexpr uint32_t kmask2 = 0x0f0f0f0f;
     constexpr uint32_t kmask3 = 0x03030303;
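For context, the kmask constants above unpack the standard K-quant 6-bit scale/min packing. A scalar sketch of the non-interleaved decode is shown below, assuming the usual Q4_K/Q5_K scales[12] encoding (the helper name is illustrative, not part of the patch):

// Sketch only: 12 bytes hold 8 scales and 8 mins of 6 bits each.
// Bytes 0-3:  low 6 bits of scales 0-3; top 2 bits are the high bits of scales 4-7.
// Bytes 4-7:  low 6 bits of mins 0-3;   top 2 bits are the high bits of mins 4-7.
// Bytes 8-11: low nibble = low 4 bits of scales 4-7, high nibble = low 4 bits of mins 4-7.
static inline void get_scale_min_6bit(int j, const uint8_t * q, uint8_t * sc, uint8_t * mn) {
    if (j < 4) {
        *sc = q[j] & 63;
        *mn = q[j + 4] & 63;
    } else {
        *sc = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *mn = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);
    }
}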
@@ -499,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    float * res_ptr = s;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+        float32x4_t sumf = vdupq_n_f32(0);
+        for (int l = 0; l < nb; l++) {
+            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+            int32x4_t sumi = vdupq_n_s32(0);
+            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+            float32x4_t b_d = {
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
+            };
+            float32x4_t d = a_d * b_d;
+
+            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+        }
+
+        vst1q_f32(res_ptr + x * 4, sumf);
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     constexpr int qk = QK_K;
     const int     nb = n / qk;
@@ -561,7 +635,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                 for (int i = 0; i < 2; i++) {
                     int8_t    aux_q4sb[8];
                     const int offset = sb * 24 + i * 12;
-                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                     q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                 }
 
@@ -701,13 +775,13 @@ void ggml_gemv_q4_K_8x8_q8_K(int                        n,
                 for (int i = 0; i < 2; i++) {
                     int8_t    aux_q4sb[8];
                     const int offset = sb * 24 + i * 12;
-                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                     q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                 }
 
                 const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
 
-                // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns
+                // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
                 // but still need the qs to use the low and hi bits from q4
                 const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
                 int8x16_t      q8_qs[8];
@@ -786,6 +860,842 @@ void ggml_gemv_q4_K_8x8_q8_K(int                        n,
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q5_K_8x4_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_groups = ncols_interleaved / 4;  // 0123 and 4567
+    const uint8x16_t m4b        = vdupq_n_u8(0x0f);
+    const uint8x16_t mone       = vdupq_n_u8(1);
+    const uint8x16_t mtwo       = vdupq_n_u8(2);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[col_groups];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q5_d_0        = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q5_d_1        = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d          = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
+            float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
+            float32x4_t q5_dmin_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q5_dmin_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0123   = vmulq_f32(q5_dmin_0, q8_d);
+            float32x4_t sb_min_4567   = vmulq_f32(q5_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            int32x4_t acc_lo[col_groups];
+            int32x4_t acc_hi[col_groups];
+
+            // Each bsum covers 16 elements; a pairwise add leaves the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t         bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+
+            uint8x16_t qh[col_groups][8];
+            for (int c = 0; c < col_groups; c++) {
+                for (int i = 0; i < 8; i++) {
+                    qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
+                }
+            }
+
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_groups; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q5sb_mins[2];
+                int16x8_t q5sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t    aux_q5sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+                }
+
+                int8x16_t q8_qs[4];
+                for (int i = 0; i < 4; i++) {
+                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
+                }
+
+                for (int c = 0; c < col_groups; c++) {
+                    uint8x16_t q5_cols[8];
+                    uint8x16_t hbit_lo[8];
+                    uint8x16_t hbit_hi[8];
+                    int8x16_t  q5_lo[8];
+                    int8x16_t  q5_hi[8];
+
+                    for (int i = 0; i < 8; i++) {
+                        q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
+                        hbit_lo[i] = vandq_u8(qh[c][i], mone);
+                        hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
+                        qh[c][i]   = vshrq_n_u8(qh[c][i], 2);
+                        q5_lo[i]   = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
+                        q5_hi[i]   = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
+                    }
+
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
+
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
+                }
+
+                // Scales
+                // row c0123 blk0 and blk1
+                const int16x4_t   sc_0123_lo = vget_low_s16(q5sb_scales[0]);
+                const int16x4_t   sc_0123_hi = vget_low_s16(q5sb_scales[1]);
+                const float32x4_t sumf_0123  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
+                                                                       vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
+                acc_f32[0]                   = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
+                // row c4567 blk0 and blk1
+                const int16x4_t   sc_4567_lo = vget_high_s16(q5sb_scales[0]);
+                const int16x4_t   sc_4567_hi = vget_high_s16(q5sb_scales[1]);
+                const float32x4_t sumf_4567  = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
+                                                                       vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
+                acc_f32[1]                   = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
+
+                // Bias Correction
+                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+            }  // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q5_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+    const uint8x16_t mone      = vdupq_n_u8(1);
+    const uint8x16_t mtwo      = vdupq_n_u8(2);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q5_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q5_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
+            float32x4_t q5_dmin_0  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q5_dmin_1  = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0   = vmulq_f32(q5_dmin_0, q8_d);
+            float32x4_t sb_min_1   = vmulq_f32(q5_dmin_1, q8_d);
+
+            // 2 sb each iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // Each bsum covers 16 elements; the pairwise add leaves the 8 per-32-element sums of the block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t         bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+
+            // Load qh once per block and shift after each subblock
+            const uint8_t * qh_base = q5_ptr[b].qh;
+            uint8x16_t      qh[col_pairs][4];
+            for (int cp = 0; cp < col_pairs; cp++) {
+                qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+                qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+                qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+                qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+            }
+
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q5sb_mins[2];  // int16 as it's needed for bias_acc later
+                int16x8_t q5sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t    aux_q5sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+                }
+
+                const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 q8_K quants, duplicating each 8-byte group across the register so one vecdot covers a pair of interleaved columns
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t      q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+                }
+
+                // Q5s column pair loop unrolled
+                {
+                    // Cols 01
+                    uint8x16_t qs_0 = vld1q_u8(qs_base);
+                    uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
+                    uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
+                    uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
+
+                    uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
+                    uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
+                    uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
+                    uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
+                    uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
+                    uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
+                    uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
+                    uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
+
+                    qh[0][0] = vshrq_n_u8(qh[0][0], 2);
+                    qh[0][1] = vshrq_n_u8(qh[0][1], 2);
+                    qh[0][2] = vshrq_n_u8(qh[0][2], 2);
+                    qh[0][3] = vshrq_n_u8(qh[0][3], 2);
+
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[0] = ggml_vdotq_s32(
+                        acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 23
+                    qs_0 = vld1q_u8(qs_base + 16);
+                    qs_1 = vld1q_u8(qs_base + 80);
+                    qs_2 = vld1q_u8(qs_base + 144);
+                    qs_3 = vld1q_u8(qs_base + 208);
+
+                    hbit_lo_0 = vandq_u8(qh[1][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[1][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[1][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[1][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
+
+                    qh[1][0] = vshrq_n_u8(qh[1][0], 2);
+                    qh[1][1] = vshrq_n_u8(qh[1][1], 2);
+                    qh[1][2] = vshrq_n_u8(qh[1][2], 2);
+                    qh[1][3] = vshrq_n_u8(qh[1][3], 2);
+
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[1] = ggml_vdotq_s32(
+                        acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 45
+                    qs_0 = vld1q_u8(qs_base + 32);
+                    qs_1 = vld1q_u8(qs_base + 96);
+                    qs_2 = vld1q_u8(qs_base + 160);
+                    qs_3 = vld1q_u8(qs_base + 224);
+
+                    hbit_lo_0 = vandq_u8(qh[2][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[2][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[2][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[2][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
+
+                    qh[2][0] = vshrq_n_u8(qh[2][0], 2);
+                    qh[2][1] = vshrq_n_u8(qh[2][1], 2);
+                    qh[2][2] = vshrq_n_u8(qh[2][2], 2);
+                    qh[2][3] = vshrq_n_u8(qh[2][3], 2);
+
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[2] = ggml_vdotq_s32(
+                        acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+
+                    // Cols 67
+                    qs_0 = vld1q_u8(qs_base + 48);
+                    qs_1 = vld1q_u8(qs_base + 112);
+                    qs_2 = vld1q_u8(qs_base + 176);
+                    qs_3 = vld1q_u8(qs_base + 240);
+
+                    hbit_lo_0 = vandq_u8(qh[3][0], mone);
+                    hbit_lo_1 = vandq_u8(qh[3][1], mone);
+                    hbit_lo_2 = vandq_u8(qh[3][2], mone);
+                    hbit_lo_3 = vandq_u8(qh[3][3], mone);
+                    hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
+                    hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
+                    hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
+                    hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
+
+                    qh[3][0] = vshrq_n_u8(qh[3][0], 2);
+                    qh[3][1] = vshrq_n_u8(qh[3][1], 2);
+                    qh[3][2] = vshrq_n_u8(qh[3][2], 2);
+                    qh[3][3] = vshrq_n_u8(qh[3][3], 2);
+
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
+                    acc_lo[3] = ggml_vdotq_s32(
+                        acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
+                                               q8_qs[4]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
+                                               q8_qs[5]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
+                                               q8_qs[6]);
+                    acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
+                                               q8_qs[7]);
+                }
+
+                // Prepare bsum vectors for bias computation
+                // Each pair of subblocks share the same bsums
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // Iterate over two column pairs at a time (4 columns) so a single 128-bit register is used
+                // p = 0 -> cols 0123, p = 2 -> cols 4567
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t   group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
+                    int16x4_t   group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
+                    int16x4_t   group_mins_lo   = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
+                    int16x4_t   group_mins_hi   = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
+                    float32x4_t sb_scale        = p == 0 ? sb_scale_0 : sb_scale_1;
+                    float32x4_t sb_min          = p == 0 ? sb_min_0 : sb_min_1;
+
+                    // 0123 or 4567
+                    float32x4_t sumf_0 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+
+                    // FUSED BIAS: Compute and subtract bias immediately
+                    // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
+                    int32x4_t bias       = vmull_s16(bsums_vec_lo, group_mins_lo);
+                    bias                 = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
+                    float32x4_t bias_f32 = vcvtq_f32_s32(bias);
+                    acc_f32[i]           = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
+                }
+            }  // for sb
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q6_K_8x4_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_groups = ncols_interleaved / 4;
+    const uint8x16_t m4b        = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_lo    = vdupq_n_u8(0x03);
+    const uint8x16_t mask_hi    = vdupq_n_u8(0x30);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[2];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
+
+            int32x4_t acc[col_groups];
+            for (int i = 0; i < col_groups; i++) {
+                acc[i] = vdupq_n_s32(0);
+            }
+
+            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
+            // Reused for bias and dequantization later
+            int16_t q6_scales[16 * 8];
+            for (int i = 0; i < 16; i++) {
+                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                vst1q_s16(q6_scales + i * 8, scales);
+            }
+
+            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
+            int32x4_t bias_lo = vdupq_n_s32(0);
+            int32x4_t bias_hi = vdupq_n_s32(0);
+
+            // Load bsums in chunks of 4 to process with vectorized operations
+            for (int i = 0; i < 16; i += 4) {
+                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);
+                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
+                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
+                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
+                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
+                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
+                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
+                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
+                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
+
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
+            }
+            bias_lo = vshlq_n_s32(bias_lo, 5);
+            bias_hi = vshlq_n_s32(bias_hi, 5);
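+            // The << 5 multiplies the accumulated scale*bsum terms by 32, folding the implicit
+            // -32 offset of q6_K quants out of the dot products above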
+
+            // Process two 128-value halves per superblock
+            for (int half = 0; half < 2; half++) {
+                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+                // A subblock (sb) is a set of weights that share a scale
+                // q6_K scales cover 16 elements each, so per half:
+                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) = 4
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
+                    const int8_t * q8_base_h = q8_base_l + 64;
+
+                    // Load and duplicate q8 values (each register covers four interleaved columns of q6)
+                    int8x16_t q8_l[4];
+                    int8x16_t q8_h[4];
+                    for (int i = 0; i < 4; i++) {
+                        q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
+                        q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
+                    }
+
+                    const int ql_off_base = sb * QK_K / 2;
+                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes
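+                    // Subblocks 0/2 (and 1/3) reuse the same qh bytes: each byte carries the 2-bit
+                    // high parts for both, and the shift below selects the upper bit pairs for sb 2 and 3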
+
+                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
+                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
+                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
+                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
+                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
+
+                    // Adjust qh for subblocks 2 and 3 (shift right by 2)
+                    if (sb > 1) {
+                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
+                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
+                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
+                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
+                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
+                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
+                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
+                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
+                    }
+
+                    const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
+                                                  q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
+                    const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
+                                                  q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
+
+                    // Process column groups (0-3, 4-7)
+                    for (int g = 0; g < col_groups; g++) {
+                        int32x4_t sb_acc_l = vdupq_n_s32(0);
+                        int32x4_t sb_acc_h = vdupq_n_s32(0);
+
+                        for (int chunk = 0; chunk < 4; chunk++) {
+                            const int idx = chunk * 2 + g;
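+                            // The interleaved layout alternates column groups every 16 bytes,
+                            // so the register for this chunk and group is at index chunk * 2 + g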
+
+                            const uint8x16_t q6_qs_l = q6_ql[idx];
+                            const uint8x16_t q6_qs_h = q6_qh[idx];
+
+                            // Extract high 2 bits for upper nibble reconstruction
+                            const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
+
+                            // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
+                            const int8x16_t q6_l =
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
+                            const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
+
+                            sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
+                            sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
+                        }
+
+                        const int scale_idx_l = half * 8 + sb;
+                        const int scale_idx_h = half * 8 + sb + 4;
+
+                        const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
+                        const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
+
+                        acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
+                        acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
+                    }
+                }
+            }  // for half
+
+            // Bias correction
+            acc[0] = vsubq_s32(acc[0], bias_lo);
+            acc[1] = vsubq_s32(acc[1], bias_hi);
+
+            // Apply superblock scale (no mins for q6_K)
+            // acc[g] has [c0, c1, c2, c3]
+            float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
+            float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
+
+            acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
+            acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q6_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b       = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_lo   = vdupq_n_u8(0x03);
+    const uint8x16_t mask_hi   = vdupq_n_u8(0x30);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[2];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+        acc_f32[0] = vdupq_n_f32(0);
+        acc_f32[1] = vdupq_n_f32(0);
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
+
+            int32x2_t acc[col_pairs];
+            for (int i = 0; i < col_pairs; i++) {
+                acc[i] = vdup_n_s32(0);
+            }
+
+            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
+            // Reused for bias and dequantization later
+            int16_t q6_scales[16 * 8];
+            for (int i = 0; i < 16; i++) {
+                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                vst1q_s16(q6_scales + i * 8, scales);
+            }
+
+            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
+            int32x4_t bias_lo = vdupq_n_s32(0);
+            int32x4_t bias_hi = vdupq_n_s32(0);
+
+            // Load bsums in chunks of 4 to process with vectorized operations
+            for (int i = 0; i < 16; i += 4) {
+                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);
+                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
+                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
+                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
+                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
+                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
+                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
+                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
+                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
+
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
+            }
+            bias_lo = vshlq_n_s32(bias_lo, 5);
+            bias_hi = vshlq_n_s32(bias_hi, 5);
+
+            // Process two 128-value halves per superblock
+            for (int half = 0; half < 2; half++) {
+                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+                // A subblock (sb) is a set of weights that share a scale
+                // q6_K scales cover 16 elements each, so per half:
+                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) = 4
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
+                    const int8_t * q8_base_h = q8_base_l + 64;
+
+                    // Load and duplicate q8 values (each register covers two interleaved columns of q6)
+                    int8x16_t q8_l[2];
+                    int8x16_t q8_h[2];
+                    for (int i = 0; i < 2; i++) {
+                        q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8));
+                        q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
+                    }
+
+                    const int ql_off_base = sb * QK_K / 2;
+                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes
+
+                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
+                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
+                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
+                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
+                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
+
+                    // Adjust qh for subblocks 2 and 3 (shift right by 2)
+                    if (sb > 1) {
+                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
+                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
+                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
+                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
+                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
+                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
+                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
+                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
+                    }
+
+                    // Process column pairs (0-1, 2-3, 4-5, 6-7)
+                    for (int cp = 0; cp < col_pairs; cp++) {
+                        const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp];
+                        const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp];
+                        const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp];
+                        const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp];
+
+                        // Extract high 2 bits for upper nibble reconstruction
+                        const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
+                        const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
+
+                        // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
+                        const int8x16_t q6_l0 = vreinterpretq_s8_u8(
+                            vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4));
+                        const int8x16_t q6_l1 = vreinterpretq_s8_u8(
+                            vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4));
+                        const int8x16_t q6_h0 =
+                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh));
+                        const int8x16_t q6_h1 =
+                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh));
+
+                        int32x4_t sb_acc_l = vdupq_n_s32(0);
+                        sb_acc_l           = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]);
+                        sb_acc_l           = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]);
+
+                        int32x4_t sb_acc_h = vdupq_n_s32(0);
+                        sb_acc_h           = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]);
+                        sb_acc_h           = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]);
+
+                        // Pairwise add to get per-column sums: [col0, col1]
+                        int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l));
+                        int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h));
+
+                        const int scale_idx_l = half * 8 + sb;
+                        const int scale_idx_h = half * 8 + sb + 4;
+
+                        // Access scales using array indexing (scales are interleaved by column)
+                        const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2],
+                                                        (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] };
+                        const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2],
+                                                        (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] };
+
+                        // Accumulate scaled results
+                        acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l);
+                        acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h);
+                    }
+                }
+            }  // for half
+
+            // Bias correction
+            acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo));
+            acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo));
+            acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi));
+            acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi));
+
+            // Apply superblock scale (no mins for q6_K)
+            // acc[cp] has [c0, c1]
+            float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0));
+            float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0));
+            float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1));
+            float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1));
+
+            acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23));
+            acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67));
+        }  // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q8_0_4x4_q8_0(int                        n,
                              float * GGML_RESTRICT      s,
                              size_t                     bs,
@@ -2329,6 +3239,87 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+            float32x4_t sumf[4];
+            for (int m = 0; m < 4; m++) {
+                sumf[m] = vdupq_n_f32(0);
+            }
+
+            for (int l = 0; l < nb; l++) {
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                float32x4_t b_d = {
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
+                };
+
+                int32x4_t sumi_0 = vdupq_n_s32(0);
+                int32x4_t sumi_1 = vdupq_n_s32(0);
+                int32x4_t sumi_2 = vdupq_n_s32(0);
+                int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                for (int k = 0; k < 4; k++) {
+                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
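+                    // vqtbl1q_s8 maps each 4-bit mxfp4 code (low and high nibbles) to its signed
+                    // integer value from the kvalues_mxfp4 lookup table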
+
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                }
+
+                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+            }
+
+            for (int m = 0; m < 4; m++) {
+                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+            }
+        }
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     constexpr int qk = QK_K;
     const int     nb = n / qk;
@@ -2431,7 +3422,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     for (int i = 0; i < 2; i++) {
                         int8_t    aux_q4sb[8];
                         const int offset = sb * 24 + i * 12;
-                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
                         q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
                     }
 
@@ -2529,6 +3520,235 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_q5_K_8x4_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    q8_k_blocklen = 4;
+    constexpr int    acc_size      = 2 * 4;  // 4 rows x 2 column quartets
+    constexpr int    col_groups    = ncols_interleaved / 4;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mone          = vdupq_n_u8(1);
+    const uint8x16_t mtwo          = vdupq_n_u8(2);
+
+    // 8 accumulators: 4 rows x 2 column quartets (c0123 and c4567)
+    float32x4_t acc_f32[acc_size];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < acc_size; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // d5 0 1 2 3, 4 5 6 7
+                float32x4_t q5_d_0123    = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
+                float32x4_t q5_d_4567    = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
+                // d8 0 1 2 3
+                float32x4_t q8_d_0123    = vld1q_f32(q8_ptr[b].d);
+                // mins
+                float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
+                float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
+
+                // Precomputation of scales and mins
+                float32x4_t sbd_scale_0123[q8_k_blocklen];
+                float32x4_t sbd_scale_4567[q8_k_blocklen];
+                float32x4_t sbd_min_0123[q8_k_blocklen];
+                float32x4_t sbd_min_4567[q8_k_blocklen];
+
+                sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
+                sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
+                sbd_min_0123[0]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
+                sbd_min_4567[0]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
+
+                sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
+                sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
+                sbd_min_0123[1]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
+                sbd_min_4567[1]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
+
+                sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
+                sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
+                sbd_min_0123[2]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
+                sbd_min_4567[2]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
+
+                sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
+                sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
+                sbd_min_0123[3]   = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
+                sbd_min_4567[3]   = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
+
+                // Precompute bsums: each vpaddq folds one row's 16 bsums into the 8 per-32-element sums
+                const int16x8_t bsums[q8_k_blocklen] = {
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[QK_K / 64][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                // interleaved bias_acc: [0]->r0 c0123, [1]->r0 c4567, [2]->r1 c0123, [3]->r1 c4567, ..
+                int32x4_t bias_acc[acc_size];
+                for (int i = 0; i < acc_size; i++) {
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                uint8x16_t qh[col_groups][8];
+                for (int c = 0; c < col_groups; c++) {
+                    for (int i = 0; i < 8; i++) {
+                        qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
+                    }
+                }
+
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Int accumulators for qs vecdot (4 row * 2 col quartets)
+                    int32x4_t acc_lo[acc_size];
+                    int32x4_t acc_hi[acc_size];
+                    for (int i = 0; i < acc_size; i++) {
+                        acc_lo[i] = vdupq_n_s32(0);
+                        acc_hi[i] = vdupq_n_s32(0);
+                    }
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int16x8_t q5sb_scales[2];
+                    int16x8_t q5sb_mins[2];
+                    for (int i = 0; i < 2; i++) {
+                        int8_t    aux_q5sb[8];
+                        const int offset = sb * 24 + i * 12;
+                        decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+                        q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+                    }
+
+                    constexpr int reads_per_sb = 8;  // 8 reads of 16 bytes each => 32 qs * 4 rows
+                    for (int k = 0; k < reads_per_sb; k++) {
+                        const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+                        const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+                        // 0..3 & 32..35
+                        const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
+                        const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+                        // NOTE: This is the only difference with q4_K
+                        const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
+                        const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
+                        qh[0][k]                      = vshrq_n_u8(qh[0][k], 2);
+                        const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
+                        const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
+                        qh[1][k]                      = vshrq_n_u8(qh[1][k], 2);
+                        // From here, same as q4_K
+
+                        const int8x16_t q5_0123_lo =
+                            vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
+                        const int8x16_t q5_0123_hi =
+                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
+
+                        acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0);  //  0..3  r0 c0123
+                        acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1);  //  0..3  r1 c0123
+                        acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2);  //  0..3  r2 c0123
+                        acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3);  //  0..3  r3 c0123
+
+                        acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0);  // 32..35 r0 c0123
+                        acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1);  // 32..35 r1 c0123
+                        acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2);  // 32..35 r2 c0123
+                        acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3);  // 32..35 r3 c0123
+
+                        const int8x16_t q5_4567_lo =
+                            vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
+                        const int8x16_t q5_4567_hi =
+                            vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
+
+                        acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0);  //  0..3  r0 c4567
+                        acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1);  //  0..3  r1 c4567
+                        acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2);  //  0..3  r2 c4567
+                        acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3);  //  0..3  r3 c4567
+
+                        acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0);  // 32..35 r0 c4567
+                        acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1);  // 32..35 r1 c4567
+                        acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2);  // 32..35 r2 c4567
+                        acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3);  // 32..35 r3 c4567
+                    }
+
+                    // Scale and bias application
+                    // acc is stored interleaved to match output layout
+                    const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
+                    const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
+                    const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
+                    const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
+                    for (int row = 0; row < q8_k_blocklen; row++) {
+                        // Scale application
+                        // row c0123 blk0 and blk1
+                        const float32x4_t sumf_0123 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
+                                                    vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+                        acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+                        // row c4567 blk0 and blk1
+                        const float32x4_t sumf_4567 =
+                            vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
+                                                    vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+                        acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+                        // Bias
+                        const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+                        const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+                        // row c0123 blk0 and blk1
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+                        bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+
+                        // row c4567 blk0 and blk1
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+                        bias_acc[2 * row + 1] =
+                            vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+                    }
+                }  // for sb
+
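+                // Subtract the accumulated bsum*min bias, scaled by the precomputed per-row dmin factors (sbd_min_*)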
+                for (int row = 0; row < q8_k_blocklen; row++) {
+                    acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
+                    acc_f32[2 * row + 1] =
+                        vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+                }
+            }  // for b
+
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_K_8x8_q8_K(int                        n,
                              float * GGML_RESTRICT      s,
                              size_t                     bs,
@@ -2550,6 +3770,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
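+    // SVE + int8 matmul path, taken only when the runtime SVE vector length is 256 bits (svcntb() == 32)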
+    if (svcntb() * 8 == 256) {
+        constexpr int    q8_k_blocklen = 4;
+        const svuint8_t m4b_1          = svdup_n_u8(0x0f);
+        // 4 output accumulators: one 8-lane register per q8 row, covering all 8 interleaved columns
+        svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
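+        // Lane-permutation tables for svtbl, used below to put scale and accumulator lanes in the interleaved output column order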
+        uint32_t idx_arr[8] = { 0, 2, 4, 6,  1, 3, 5, 7 };
+        svbool_t pg = svptrue_pat_b32(SV_VL8);
+        svuint32_t idx = svld1(pg, idx_arr);
+
+        static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
+        svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
+
+        for (int y = 0; y < nr / q8_k_blocklen; y++) {
+            const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+                acc_f32_01 = svdup_n_f32(0);
+                acc_f32_23 = svdup_n_f32(0);
+                acc_f32_45 = svdup_n_f32(0);
+                acc_f32_67 = svdup_n_f32(0);
+
+                for (int b = 0; b < nb; b++) {
+                    // Adjacent bsums belong to the same q8_k subblock, so they are summed in pairs
+                    // (64 values loaded per block, reduced to 32 pair sums)
+                    const int16x8_t bsums[4]{
+                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                        vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                    };
+
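+                    // Widen the paired bsums to int32 and spill them to an array for per-subblock scalar indexing below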
+                    int32_t bsums_arr32[4][8];
+
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        int16x8_t v16 = bsums[q8_row];
+
+                        // low 4
+                        int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
+                        vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
+
+                        // high 4
+                        int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
+                        vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
+                    }
+
+                    svint32_t sb_acc_0 = svdup_n_s32(0);
+                    svint32_t sb_acc_2 = svdup_n_s32(0);
+
+                    svint32_t acc_00 = svdup_n_s32(0);
+                    svint32_t acc_11 = svdup_n_s32(0);
+                    svint32_t acc_22 = svdup_n_s32(0);
+                    svint32_t acc_33 = svdup_n_s32(0);
+                    svint32_t acc_44 = svdup_n_s32(0);
+                    svint32_t acc_55 = svdup_n_s32(0);
+                    svint32_t acc_66 = svdup_n_s32(0);
+                    svint32_t acc_77 = svdup_n_s32(0);
+
+                    svint32_t bias_acc_00 = svdup_n_s32(0);
+                    svint32_t bias_acc_22 = svdup_n_s32(0);
+                    svint32_t bias_acc_44 = svdup_n_s32(0);
+                    svint32_t bias_acc_66 = svdup_n_s32(0);
+
+                    for (int sb = 0; sb < QK_K / 64; sb++) {
+                        // Need scales for the low and high nibbles
+                        // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                        svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
+                        svint32_t q4sb_mins_0, q4sb_mins_1;
+                        {
+                            // Decode the two 12-byte scale/min groups of the current subblock
+                            const int offset = sb * 24 + 0 * 12;
+                            const uint8_t * scales_in = &q4_ptr[b].scales[offset];
+
+                            const int offset1 = sb * 24 + 12;
+                            const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
+
+                            constexpr uint32_t kmask1 = 0x3f3f3f3f;
+                            constexpr uint32_t kmask2 = 0x0f0f0f0f;
+                            constexpr uint32_t kmask3 = 0x03030303;
+                            constexpr uint8_t  scales_size = 12;
+
+                            uint32_t sm[3];
+                            memcpy(sm, scales_in, scales_size);
+
+                            uint32_t sm1[3];
+                            memcpy(sm1, scales_in1, scales_size);
+
+                            const uint32_t mins_0_3 = sm[1] & kmask1;
+                            const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+
+                            const uint32_t mins_0_3_1 = sm1[1] & kmask1;
+                            const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
+
+                            svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
+                            svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
+
+                            /* reinterpret u32 → u8 */
+                            svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
+                            svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
+
+                            /* widen u8 → u16 → u32 (lower half only) */
+                            svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
+                            svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
+
+                            q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
+                            q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
+
+                            uint32_t scales_u32_0 = sm[0] & kmask1;
+                            uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+                            uint32_t scales_u32_2 = sm1[0] & kmask1;
+                            uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
+
+                            svuint32_t S01 = svdup_n_u32(scales_u32_0);
+                            svuint32_t S23 = svdup_n_u32(scales_u32_1);
+                            svuint32_t R01 = svdup_n_u32(scales_u32_2);
+                            svuint32_t R23 = svdup_n_u32(scales_u32_3);
+
+                            svint8_t S01_b = svreinterpret_s8_u32(S01);
+                            svint8_t S23_b = svreinterpret_s8_u32(S23);
+                            svint8_t R01_b = svreinterpret_s8_u32(R01);
+                            svint8_t R23_b = svreinterpret_s8_u32(R23);
+
+                            svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
+                            svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
+                            svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
+                            svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
+
+                            block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
+                            block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
+                            block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
+                            block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
+                        }
+
+                        const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
+
+                        // Load 32 bytes per row pair, 1 subblock at a time
+                        // predicate selecting the first 16 int8 lanes
+                        const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                        // predicate selecting the remaining 16 int8 lanes
+                        const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
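+                        // Each q8 vector merges two 16-byte chunks: one loaded into the first 16 lanes, one into the
+                        // last 16 (the predicated loads zero inactive lanes, so the add simply combines them)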
+                        svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
+                        svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
+                        svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
+                        svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
+
+                        svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
+                        svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
+                        svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
+                        svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
+
+                        // Q4s columns iterated in pairs (01, 23, 45, 67)
+                        for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+
+                            sb_acc_0 = svdup_n_s32(0);
+                            sb_acc_2 = svdup_n_s32(0);
+
+                            svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
+                            svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
+                            svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
+                            svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
+
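+                            // Keep the low nibbles in the first 16 lanes and shift the high nibbles down in the last
+                            // 16 lanes, matching the q8 vectors assembled above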
+                            svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
+                            svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
+                            svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
+                            svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
+
+                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
+                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
+
+                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
+                            sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
+
+                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
+                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
+
+                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
+                            sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
+
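+                            // Accumulate this column pair's i8mm partial sums, weighted by the per-subblock scales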
+                            if (cp == 0) {
+                                acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
+                                acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
+                            } else if (cp == 1) {
+                                acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
+                                acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
+                            } else if (cp == 2) {
+                                acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
+                                acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
+                            } else if (cp == 3) {
+                                acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
+                                acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
+                            }
+                        }
+
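+                        // Accumulate bsum * min per q8 row; this bias term is removed after the subblock loop using the dmin scales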
+                        bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
+                        bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
+
+                        bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
+                        bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
+
+                        bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
+                        bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
+
+                        bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
+                        bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
+                    }  // for sb
+
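+                    // Fold the upper 4 lanes of each 8-lane accumulator onto the lower 4, leaving one 4-lane partial result per accumulator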
+                    acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
+                    acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
+                    acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
+                    acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
+                    acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
+                    acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
+                    acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
+                    acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
+
+                    svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
+                    svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
+
+                    svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
+                    svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
+
+                    // Broadcast q8 scalar
+                    svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
+
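+                    // Convert the per-column fp16 superblock scales (d) and mins (dmin) to fp32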
+                    svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
+
+                    svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
+
+                    svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                    svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+                    acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
+                    acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
+
+                    q8_d = svdup_f32(q8_ptr[b].d[1]);
+
+                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+                    acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
+                    acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
+
+                    q8_d = svdup_f32(q8_ptr[b].d[2]);
+
+                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+                    acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
+                    acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
+
+                    q8_d = svdup_f32(q8_ptr[b].d[3]);
+
+                    scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                    dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
+
+                    acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
+                    acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
+
+                }  // for b
+
+                // With the previous reorder, the tile is already in the correct memory layout.
+                // Predicate for exactly 4 lanes
+                svbool_t pg4 = svptrue_pat_b32(SV_VL4);
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    int row = y * q8_k_blocklen + i;
+                    for (int j = 0; j < 2; j++) {
+                        int col    = x * ncols_interleaved + j * 4;
+                        int offset = row * bs + col;
+
+                        if (i == 0 && j == 0) {
+                            // acc_f32_0 → lower half of acc_f32_01
+                            svst1_f32(pg4, s + offset, acc_f32_01);
+                        } else if (i == 0 && j == 1) {
+                            // acc_f32_1 → upper half of acc_f32_01
+                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
+                        } else if (i == 1 && j == 0) {
+                            // acc_f32_2
+                            svst1_f32(pg4, s + offset, acc_f32_23);
+                        } else if (i == 1 && j == 1) {
+                            // acc_f32_3
+                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
+                        } else if (i == 2 && j == 0) {
+                            // acc_f32_4
+                            svst1_f32(pg4, s + offset, acc_f32_45);
+                        } else if (i == 2 && j == 1) {
+                            // acc_f32_5
+                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
+                        } else if (i == 3 && j == 0) {
+                            // acc_f32_6
+                            svst1_f32(pg4, s + offset, acc_f32_67);
+                        } else if (i == 3 && j == 1) {
+                            // acc_f32_7
+                            svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
+                        }
+                    }
+                }
+            }  // for x
+        }  // for y
+        return;
+    }
+#endif  // SVE compile-time end
+
 #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
     constexpr int    q8_k_blocklen = 4;
     const uint8x16_t m4b           = vdupq_n_u8(0x0f);
@@ -2595,7 +4125,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
                     int16x8_t q4sb_mins[2];  // int16 as its needed for bias_acc later
                     for (int i = 0; i < 2; i++) {
                         const int offset = sb * 24 + i * 12;
-                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+                        decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
                     }
 
                     // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
@@ -2660,16 +4190,17 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
 
                         // Scales[i] corresponds to column i
                         const int scale_offset = cp * 2;
-                        for (int blk = 0; blk < 2; blk++) {
-                            const int32x4_t block_scale = {
-                                (int32_t) q4sb_scales[blk][scale_offset],
-                                (int32_t) q4sb_scales[blk][scale_offset],
-                                (int32_t) q4sb_scales[blk][scale_offset + 1],
-                                (int32_t) q4sb_scales[blk][scale_offset + 1],
-                            };
-                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
-                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
-                        }
+                        const int32_t scale_00 = q4sb_scales[0][scale_offset];
+                        const int32_t scale_01 = q4sb_scales[0][scale_offset + 1];
+                        const int32_t scale_10 = q4sb_scales[1][scale_offset];
+                        const int32_t scale_11 = q4sb_scales[1][scale_offset + 1];
+                        const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01));
+                        const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11));
+
+                        acc[cp]     = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0);
+                        acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0);
+                        acc[cp]     = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1);
+                        acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1);
                     }
 
                     // Multiply Acc bsum + mins
@@ -2738,6 +4269,671 @@ void ggml_gemm_q4_K_8x8_q8_K(int                        n,
     ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_q5_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    constexpr int    q8_k_blocklen = 4;
+    constexpr int    col_pairs     = ncols_interleaved / 2;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mone          = vdupq_n_u8(1);
+    const uint8x16_t mtwo          = vdupq_n_u8(2);
+
+    // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
+    float32x4_t acc_f32[blocklen];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < blocklen; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // Adjacent bsums belong to the same q8_k subblock and are summed in pairs
+                const int16x8_t bsums[4]{
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                int32x4_t sb_acc[4];    // Aux accumulators to store subblock (partial) results
+                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
+                int32x4_t bias_acc[8];  // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
+                for (int i = 0; i < 8; i++) {
+                    acc[i]      = vdupq_n_s32(0);
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
+                // Load qh once per block and shift after each subblock
+                const uint8_t * qh_base = q5_ptr[b].qh;
+                uint8x16_t      qh[col_pairs][4];
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
+                    qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
+                    qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
+                    qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
+                }
+
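+                // Each qh byte holds the fifth bit for one low-nibble and one high-nibble element per subblock;
+                // two bits are consumed per subblock and the byte is shifted right by 2 afterwards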
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Need scales for the low and high nibbles
+                    // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                    int8_t    q5sb_scales[2][8];
+                    int16x8_t q5sb_mins[2];  // int16 as it's needed for bias_acc later
+                    for (int i = 0; i < 2; i++) {
+                        const int offset = sb * 24 + i * 12;
+                        decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
+                    }
+
+                    // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
+                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+                    int8x16_t q8_qs_01[8];
+                    int8x16_t q8_qs_23[8];
+
+                    // Load 32 bytes per row pair, 1 subblock at a time
+                    for (int i = 0; i < 8; i++) {
+                        const int offset = i * 32;  // 16 for row 01, 16 for row 23
+                        q8_qs_01[i]      = vld1q_s8(q8_base + offset);
+                        q8_qs_23[i]      = vld1q_s8(q8_base + offset + 16);
+                    }
+
+                    const int8x16_t q8s[2][8] = {
+                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
+                         q8_qs_01[7] },
+                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
+                         q8_qs_23[7] },
+                    };
+
+                    // Q5s columns iterated in pairs (01, 23, 45, 67)
+                    for (int cp = 0; cp < col_pairs; cp++) {
+                        for (int i = 0; i < 4; i++) {
+                            sb_acc[i] = vdupq_n_s32(0);
+                        }
+
+                        uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
+                        uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
+                        uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
+                        uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
+
+                        // This is the only part of the algorithm that differs from Q4_K:
+                        // extract the high bits and pack them into 5-bit weights
+                        uint8x16_t hbit_lo_0    = vandq_u8(qh[cp][0], mone);
+                        uint8x16_t hbit_hi_0    = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
+                        qh[cp][0]               = vshrq_n_u8(qh[cp][0], 2);
+                        // Same as Q4_K: use i8mm (vmmlaq_s32) on the reconstructed weights
+                        const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
+                        int32x4_t       acc_0   = sb_acc[0];
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
+                        int32x4_t acc_2         = sb_acc[2];
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
+                        const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
+                        int32x4_t       acc_1   = sb_acc[1];
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
+                        int32x4_t acc_3         = sb_acc[3];
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
+
+                        // Repeat for the other 3 columns (8..15, 16..23, 24..31)
+                        uint8x16_t hbit_hi_1    = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
+                        uint8x16_t hbit_lo_1    = vandq_u8(qh[cp][1], mone);
+                        qh[cp][1]               = vshrq_n_u8(qh[cp][1], 2);
+                        const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
+                        const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
+
+                        uint8x16_t hbit_hi_2    = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
+                        uint8x16_t hbit_lo_2    = vandq_u8(qh[cp][2], mone);
+                        qh[cp][2]               = vshrq_n_u8(qh[cp][2], 2);
+                        const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
+                        const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
+
+                        uint8x16_t hbit_lo_3    = vandq_u8(qh[cp][3], mone);
+                        uint8x16_t hbit_hi_3    = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
+                        qh[cp][3]               = vshrq_n_u8(qh[cp][3], 2);
+                        const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
+                        acc_0                   = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
+                        sb_acc[0]               = acc_0;
+                        acc_2                   = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
+                        sb_acc[2]               = acc_2;
+
+                        // Scales[i] corresponds to column i
+                        const int       scale_offset = cp * 2;
+                        const int32_t   s0           = q5sb_scales[0][scale_offset];
+                        const int32_t   s1           = q5sb_scales[0][scale_offset + 1];
+                        const int32x4_t block_scale  = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
+                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
+                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
+
+                        const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
+                        acc_1                   = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
+                        sb_acc[1]               = acc_1;
+                        acc_3                   = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
+                        sb_acc[3]               = acc_3;
+
+                        const int32_t   s2           = q5sb_scales[1][scale_offset];
+                        const int32_t   s3           = q5sb_scales[1][scale_offset + 1];
+                        const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
+                        acc[cp]                      = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
+                        acc[cp + 4]                  = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
+                    }
+
+                    // Multiply Acc bsum + mins
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        // Each pair of subblocks shares the same bsums
+                        // Load scalar bsum → broadcast to a vector (vdup_n_s16)
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+                    }
+                }  // for sb
+
+                // Reorder of i8mm output with bias and output layout
+                for (int i = 0; i < 8; i++) {
+                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
+                }
+                int32x4_t reorder_acc[8] = {
+                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                };
+
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    for (int j = 0; j < 2; j++) {
+                        float32x4_t       q8_d    = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t       q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins   = vmulq_f32(q5_dmin, q8_d);
+
+                        float32x4_t       q5_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q5_d, q8_d);
+
+                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+                        acc_f32[2 * i + j] =
+                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                    }
+                }
+            }  // for b
+
+            // With the previous reorder, the tile is already in the correct memory layout.
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q6_K_8x4_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    q8_k_blocklen = 4;
+    constexpr int    col_groups    = ncols_interleaved / 4;
+    constexpr int    acc_size      = q8_k_blocklen * col_groups;  // 4 rows, 2 column groups
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);
+    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);
+    const int8x16_t  m32s          = vdupq_n_s8(32);
+
+    float32x4_t acc_f32[acc_size];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < acc_size; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
+                float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
+                float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
+
+                float32x4_t sbd_scale_0123[q8_k_blocklen];
+                float32x4_t sbd_scale_4567[q8_k_blocklen];
+
+                sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
+                sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
+                sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
+                sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
+                sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
+                sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
+                sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
+                sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
+
+                int32x4_t acc_s32[acc_size];
+                for (int i = 0; i < acc_size; i++) {
+                    acc_s32[i] = vdupq_n_s32(0);
+                }
+
+                int16_t q6_scales[8 * 16];
+                for (int i = 0; i < 16; i++) {
+                    int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                    vst1q_s16(q6_scales + i * 8, scales);
+                }
+
+                for (int half = 0; half < 2; half++) {
+                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+                    for (int sb = 0; sb < QK_K / 64; sb++) {
+                        int32x4_t acc_lo[acc_size];
+                        int32x4_t acc_hi[acc_size];
+                        for (int i = 0; i < acc_size; i++) {
+                            acc_lo[i] = vdupq_n_s32(0);
+                            acc_hi[i] = vdupq_n_s32(0);
+                        }
+
+                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
+                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
+
+                        // 4 rows * 16 elements per scale
+                        // 4 reads of 16 bytes each
+                        constexpr int reads_per_sb = 4;
+                        int8x16_t     q8_l[reads_per_sb];
+                        int8x16_t     q8_h[reads_per_sb];
+                        for (int k = 0; k < reads_per_sb; k++) {
+                            q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
+                            q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
+                        }
+
+                        const int ql_off_base = sb * QK_K / 2;
+                        const int qh_off_base = ql_off_base & 255;
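+                        // qh is half the size of ql, so its offset wraps after 256 bytes; subblocks 2 and 3 reuse
+                        // these bytes via the shift below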
+
+                        uint8x16_t q6_ql_0123[reads_per_sb];
+                        uint8x16_t q6_ql_4567[reads_per_sb];
+                        uint8x16_t q6_qh_0123[reads_per_sb];
+                        uint8x16_t q6_qh_4567[reads_per_sb];
+
+                        for (int k = 0; k < reads_per_sb; k++) {
+                            q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
+                            q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
+                            q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
+                            q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
+                        }
+
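+                        // Subblocks 2 and 3 take the next 2-bit field from the same qh bytes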
+                        if (sb > 1) {
+                            for (int k = 0; k < reads_per_sb; k++) {
+                                q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
+                                q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
+                            }
+                        }
+
+                        for (int k = 0; k < reads_per_sb; k++) {
+                            // q = (ql | qh) - 32
+                            const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
+                            const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
+                            const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
+                            const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
+
+                            const int8x16_t q6_0123_lo = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
+                            const int8x16_t q6_0123_hi = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
+
+                            acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0);  //  0..3  r0 c0123
+                            acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1);  //  0..3  r1 c0123
+                            acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2);  //  0..3  r2 c0123
+                            acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3);  //  0..3  r3 c0123
+
+                            acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0);  // 64..67 r0 c0123
+                            acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1);  // 64..67 r1 c0123
+                            acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2);  // 64..67 r2 c0123
+                            acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3);  // 64..67 r3 c0123
+
+                            const int8x16_t q6_4567_lo = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
+                            const int8x16_t q6_4567_hi = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
+
+                            acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0);  //  0..3  r0 c4567
+                            acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1);  //  0..3  r1 c4567
+                            acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2);  //  0..3  r2 c4567
+                            acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3);  //  0..3  r3 c4567
+
+                            acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0);  // 64..67 r0 c4567
+                            acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1);  // 64..67 r1 c4567
+                            acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2);  // 64..67 r2 c4567
+                            acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3);  // 64..67 r3 c4567
+                        }
+
+                        // Apply the subblock scales
+                        const int scale_idx_l = half * 8 + sb;
+                        const int scale_idx_h = half * 8 + sb + 4;
+
+                        for (int g = 0; g < col_groups; g++) {
+                            const int16x4_t scales_l16  = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
+                            const int16x4_t scales_h16  = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
+                            const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
+                            const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
+                            const int       acc_offset  = g * q8_k_blocklen;
+
+                            for (int row = 0; row < q8_k_blocklen; row++) {
+                                const int idx = row * 2 + g;
+                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
+                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
+                            }
+                        }
+                    }
+                }
+
+                // Finally we apply the superblock scales
+                for (int row = 0; row < q8_k_blocklen; row++) {
+                    const int       idx0     = 2 * row;
+                    const int       idx1     = 2 * row + 1;
+                    const int32x4_t acc_0123 = acc_s32[idx0];
+                    const int32x4_t acc_4567 = acc_s32[idx1];
+
+                    acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
+                    acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
+                }
+            }  // for b
+
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q6_K_8x8_q8_K(int                        n,
+                             float * GGML_RESTRICT      s,
+                             size_t                     bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int                        nr,
+                             int                        nc) {
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    constexpr int    q8_k_blocklen = 4;
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);
+    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);
+    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);
+    const int8x16_t  m32s          = vdupq_n_s8(32);
+
+    // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7)
+    float32x4_t acc_f32[blocklen];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < blocklen; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                int32x4_t acc[8];  // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
+                for (int i = 0; i < 8; i++) {
+                    acc[i] = vdupq_n_s32(0);
+                }
+
+                // Q6_K has simple 8-bit scales, 16 per block (one per 16 values)
+                // Widened to int16 here; used for the per-subblock scaling below
+                int16_t q6_scales[16 * 8];
+                for (int i = 0; i < 16; ++i) {
+                    int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                    vst1q_s16(q6_scales + i * 8, s16);
+                }
+
+                // Process two 128-value halves per superblock
+                for (int half = 0; half < 2; half++) {
+
+                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
+                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
+
+                    // A subblock (sb) groups the weights processed together in one iteration. q6_K scales cover
+                    // 16 elements each; every (half, sb) iteration consumes two of the 16 scales per column
+                    // (one for the low-nibble group, one for the high), so the loop runs QK_K / 64 = 4 times per half
+                    for (int sb = 0; sb < QK_K / 64; sb++) {
+                        // Q6_K weight index increasing by 64 instead of 32 requires
+                        // loading various q8 memory regions
+                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
+                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
+
+                        int8x16_t q8_l_01[2];
+                        int8x16_t q8_l_23[2];
+                        for (int i = 0; i < 2; i++) {
+                            const int offset = i * 32;
+                            q8_l_01[i]       = vld1q_s8(q8_base_l + offset);       // 0..7 & 8..15 (r01)
+                            q8_l_23[i]       = vld1q_s8(q8_base_l + offset + 16);  // 0..7 & 8..15 (r23)
+                        }
+
+                        int8x16_t q8_h_01[2];
+                        int8x16_t q8_h_23[2];
+                        for (int i = 0; i < 2; i++) {
+                            const int offset = i * 32;
+                            q8_h_01[i]       = vld1q_s8(q8_base_h + offset);
+                            q8_h_23[i]       = vld1q_s8(q8_base_h + offset + 16);
+                        }
+
+                        const int ql_off_base = sb * QK_K / 2;
+
+                        uint8x16_t q6_ql_0[4];
+                        uint8x16_t q6_ql_1[4];
+                        for (int k = 0; k < 4; k++) {
+                            q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k);
+                            q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k);
+                        }
+
+                        const int  qh_off_base = (sb * QK_K / 2) & 255;  // wrap after 256 bytes
+                        uint8x16_t q6_qh_0[4];
+                        uint8x16_t q6_qh_1[4];
+                        for (int k = 0; k < 4; k++) {
+                            q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k);
+                            q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k);
+                        }
+
+                        // Adjust for the proper high bits (Sb 2 and 3)
+                        if (sb > 1) {
+                            for (int k = 0; k < 4; k++) {
+                                q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2);
+                                q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2);
+                            }
+                        }
+
+                        // Process column pairs (0-1, 2-3, 4-5, 6-7)
+                        for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+                            const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp];
+                            const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp];
+                            const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp];
+                            const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp];
+
+                            // Extract high 2 bits for upper nibble reconstruction
+                            const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi);
+                            const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi);
+
+                            // q6 = (low4 | high2<<4) - 32
+                            // Use vsliq_n_u8 to do the shift-left and insert in a single instruction (as in Q5_K)
+                            const int8x16_t q6_l0 = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)),
+                                m32s);
+                            const int8x16_t q6_l1 = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)),
+                                m32s);
+                            const int8x16_t q6_h0 = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s);
+                            const int8x16_t q6_h1 = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s);
+
+                            // row pair 0, base_l
+                            int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]);
+                            sb_acc_0l           = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]);
+                            // row pair 0, base_h
+                            int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]);
+                            sb_acc_0h           = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]);
+                            // row pair 1, base_l
+                            int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]);
+                            sb_acc_1l           = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]);
+                            // row pair 1, base_h
+                            int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]);
+                            sb_acc_1h           = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]);
+
+                            const int scale_idx_l = half * 8 + sb;
+                            const int scale_idx_h = half * 8 + sb + 4;
+
+                            const int32x4_t scale_vec_l = {
+                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],
+                                q6_scales[scale_idx_l * 8 + cp * 2 + 0],
+                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],
+                                q6_scales[scale_idx_l * 8 + cp * 2 + 1],
+                            };
+                            const int32x4_t scale_vec_h = {
+                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],
+                                q6_scales[scale_idx_h * 8 + cp * 2 + 0],
+                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],
+                                q6_scales[scale_idx_h * 8 + cp * 2 + 1],
+                            };
+
+                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l);
+                            acc[cp]     = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h);
+                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l);
+                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h);
+                        }
+                    }
+                }  // for half
+
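+                // Each vmmlaq_s32 accumulator holds a 2x2 tile of (weight column x activation row)
+                // dot products; the zip below swaps the inner lanes and the combines then gather
+                // four consecutive columns of the same output row into one vector.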
+                // Reorder i8mm output to match memory layout
+                for (int i = 0; i < 8; i++) {
+                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                    acc[i]          = vcombine_s32(aux.val[0], aux.val[1]);
+                }
+                int32x4_t reorder_acc[8] = {
+                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                };
+
+                // Apply superblock scale (no mins for q6_K)
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    for (int j = 0; j < 2; j++) {
+                        float32x4_t       q8_d  = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t       q6_d  = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q6_d, q8_d);
+
+                        acc_f32[2 * i + j] =
+                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                    }
+                }
+            }  // for b
+
+            // Store results
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
 
 void ggml_gemm_q8_0_4x4_q8_0(int                        n,
                              float * GGML_RESTRICT      s,
diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c
index ae0ebb3c..826055dd 100644
--- a/ggml/src/ggml-cpu/arch/riscv/quants.c
+++ b/ggml/src/ggml-cpu/arch/riscv/quants.c
@@ -113,6 +113,104 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
 #endif
 }
 
+void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    assert(k % QK_K == 0);
+    block_q8_K * y_blocks = (block_q8_K *)y;
+    size_t nb = k / QK_K;
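+    // Each 256-value block stores one float scale d (chosen so the value with the largest
+    // magnitude maps to +/-127), 256 int8 quants, and 16 bsums holding the sum of each
+    // group of 16 quants for use by the K-quant dot products.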
+
+#if defined(__riscv_v_intrinsic)
+    const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8();
+
+    for (size_t i = 0; i < nb; i++) {
+        const float* x_block = x + i * QK_K;
+        block_q8_K* y_block = &y_blocks[i];
+
+        // 1. Calculate Min/Max
+        vfloat32m8_t max_v = __riscv_vfmv_v_f_f32m8(-__builtin_inff(), vlmax_f32m8);
+        vfloat32m8_t min_v = __riscv_vfmv_v_f_f32m8(__builtin_inff(), vlmax_f32m8);
+
+        size_t rem = QK_K;
+        size_t offset = 0;
+        while (rem > 0) {
+            size_t vl = __riscv_vsetvl_e32m8(rem);
+            vfloat32m8_t v_curr = __riscv_vle32_v_f32m8(x_block + offset, vl);
+            max_v = __riscv_vfmax_vv_f32m8(max_v, v_curr, vl);
+            min_v = __riscv_vfmin_vv_f32m8(min_v, v_curr, vl);
+            rem -= vl;
+            offset += vl;
+        }
+
+        vfloat32m1_t v_init_max = __riscv_vfmv_s_f_f32m1(-__builtin_inff(), 1);
+        vfloat32m1_t v_init_min = __riscv_vfmv_s_f_f32m1(__builtin_inff(), 1);
+
+        vfloat32m1_t v_scalar_max = __riscv_vfredmax_vs_f32m8_f32m1(max_v, v_init_max, vlmax_f32m8);
+        vfloat32m1_t v_scalar_min = __riscv_vfredmin_vs_f32m8_f32m1(min_v, v_init_min, vlmax_f32m8);
+
+        float max_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_max);
+        float min_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_min);
+
+        float amax = fabsf(max_val) > fabsf(min_val) ? fabsf(max_val) : fabsf(min_val);
+
+        if (amax == 0.0f) {
+            y_block->d = 0.0f;
+            memset(y_block->qs, 0, QK_K);
+            memset(y_block->bsums, 0, sizeof(y_block->bsums));
+            continue;
+        }
+
+        const float iscale = -127.f / (fabsf(max_val) > fabsf(min_val) ? max_val : min_val);
+        y_block->d = 1.0f / iscale;
+
+        // 2. Quantize and Calculate Sums
+        offset = 0;
+        rem = QK_K;
+        vint16m1_t v_zero_sum = __riscv_vmv_v_x_i16m1(0, 1);
+
+        while (rem > 0) {
+            size_t vl = __riscv_vsetvl_e32m8(rem);
+            vfloat32m8_t v_f = __riscv_vle32_v_f32m8(x_block + offset, vl);
+
+            v_f = __riscv_vfmul_vf_f32m8(v_f, iscale, vl);
+
+            vint32m8_t v_i32 = __riscv_vfcvt_x_f_v_i32m8_rm(v_f, __RISCV_FRM_RNE, vl);
+            vint16m4_t v_i16 = __riscv_vnclip_wx_i16m4(v_i32, 0, __RISCV_VXRM_RNE, vl);
+            vint8m2_t v_q = __riscv_vnclip_wx_i8m2(v_i16, 0, __RISCV_VXRM_RNE, vl);
+
+            __riscv_vse8_v_i8m2(y_block->qs + offset, v_q, vl);
+
+            // bsums: reduce each 16-quant chunk; handle the first chunk here, the rest via slidedown below
+
+            int sum_idx;
+            vint8m1_t chunk_m1;
+            vint16m1_t v_sum;
+            sum_idx = offset / 16;
+            chunk_m1 = __riscv_vget_v_i8m2_i8m1(v_q, 0);
+            v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16);
+            y_block->bsums[sum_idx] = (int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum);
+
+            // remaining iterations
+            vint8m2_t slid_q = v_q;
+            for (size_t k = 16; k < vl; k += 16) {
+                slid_q = __riscv_vslidedown_vx_i8m2(slid_q, 16, vl);
+
+                sum_idx = (offset + k) / 16;
+                chunk_m1 = __riscv_vget_v_i8m2_i8m1(slid_q, 0);
+
+                v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16);
+                y_block->bsums[sum_idx] = (int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum);
+            }
+
+            rem -= vl;
+            offset += vl;
+        }
+    }
+#else
+    GGML_UNUSED(nb);
+    GGML_UNUSED(y_blocks);
+    // fall back to the scalar reference implementation
+    quantize_row_q8_K_ref(x, y, k);
+#endif
+}
+
 //===================================== Dot products =================================
 
 void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -1954,3 +2052,1558 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
 
+static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
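+    // Layout per 256-value superblock: 8 sub-blocks of 32 values; each sub-block has four qs
+    // bytes and one 16-bit qh word. Bits 0-11 of qh add 3 high index bits to each qs byte
+    // (an 11-bit index into the 8-value iq1s_grid entries), bits 12-14 give the sub-block
+    // scale (ls = 2*s + 1) and bit 15 the sign of the IQ1S_DELTA shift.
+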
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        // Load qh once for the entire superblock.
+        vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8);
+
+        // Calculate ls.
+        vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8);
+        temp = __riscv_vand_vx_u16mf2(temp, 7, 8);
+        vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8));
+        ls = __riscv_vadd_vx_i32m1(ls, 1, 8);
+
+        // Calculate delta.
+        vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8);
+        vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8);
+        vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8);
+        vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8);
+
+        // Load qs.
+        vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32);
+
+        // Prepare the indices.
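+        // The packed constant below broadcasts per-lane shift amounts {0, 3, 6, 9}: each qh word
+        // is replicated to four lanes and contributes 3 extra index bits per qs byte; the whole
+        // index is then shifted left by 3 to form byte offsets into the 8-byte iq1s_grid entries.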
+        const uint64_t shift = 0x0009000600030000;
+        vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8));
+        vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2(
+            __riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32));
+        vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh));
+        vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32);
+        qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32);
+        qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32);
+        qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32);
+        qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32);
+        vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32);
+
+        // Final lsums.
+        int32_t lsums_s[8];
+        vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1);
+
+        // Sub-blocks 1-4
+        {
+            vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0);
+            vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16));
+            vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128);
+            vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128);
+            lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32));
+            lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32));
+            lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32));
+            lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32));
+        }
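+        // Empty asm statements act as compiler barriers between the two halves, presumably to
+        // limit register pressure by keeping their vector lifetimes separate.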
+        __asm__ __volatile__("" ::: "memory");
+        // Sub-blocks 5-8
+        {
+            vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1);
+            vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16));
+            vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128);
+            vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128);
+            lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32));
+            lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32));
+            lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32));
+            lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32));
+        }
+        __asm__ __volatile__("" ::: "memory");
+        vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8);
+
+        // Calculate the bsums.
+        vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16);
+        const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0));
+        const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8));
+        const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8));
+        const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8);
+
+        // Accumulation.
+        vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8);
+        vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8);
+
+        // Update sumf.
+        int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0, 1), 8));
+        int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0, 1), 8));
+        sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_m * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    iq1m_scale_t scale;
+    float sumf = 0.0f;
+    for (int i = 0; i < nb; ++i) {
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        // Accumulators.
+        vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16);
+        vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16);
+
+        // We process 4 sub-blocks together.
+        for (int ib = 0; ib < QK_K/128; ib++) {
+            // Load qh for 4 sub-blocks.
+            const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8);
+            const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8);
+            const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8);
+            const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1(
+                __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16);
+            qh += 8;
+
+            // Prepare grid indices.
+            const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16);
+            const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8));
+            vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16);
+            index = __riscv_vsll_vx_u16m1(index, 3, 16);
+            qs += 16;
+
+            // Load the grid.
+            const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4(
+                __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16)));
+
+            // Prepare the deltas.
+            const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16(
+                __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16);
+            const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16);
+            const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16);
+            const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4(
+                __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16));
+
+            // Load q8 for sub-blocks.
+            const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);
+            q8 += 128;
+
+            // Calculate the lsums.
+            const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128);
+            const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128);
+
+            // Prepare the scales.
+            const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1;
+            const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1;
+            const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1;
+            const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1;
+            const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1;
+            const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1;
+            const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1;
+            const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1;
+            sc += 2;
+
+            // Accumulate in acc0 and acc1 for each sub-block.
+            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16);
+            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16);
+            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16);
+            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16);
+            //
+            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16);
+            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16);
+            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16);
+            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16);
+            //
+            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16);
+            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16);
+            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16);
+            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16);
+            //
+            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16);
+            acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16);
+            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16);
+            acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16);
+        }
+
+        // Reduce and accumulate in `sumf`.
+        vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1);
+        int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16));
+        int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16));
+        sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2);
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
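+// Helper tables for expanding packed sign bytes: the gather indices replicate each of 8 sign
+// bytes across 8 consecutive lanes, and the bit masks then select bit j in lane 8*i + j, so a
+// non-zero AND marks the elements whose quant must be negated.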
+static const uint8_t sign_gather_indices_arr[64] = {
+    0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,
+    4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7
+};
+
+static const uint8_t sign_bit_masks_arr[64] = {
+    1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128,
+    1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128
+};
+
+
+static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
+
+    const block_iq2_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+    const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
+
+    // Pre-load Constants
+    vuint8m2_t v_ids = __riscv_vid_v_u8m2(32);
+    vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32);
+    vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32);
+    vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32);
+    vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32);
+    uint16_t shift_qh_arr[4] = {11, 9, 7, 5};
+    vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4);
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint8_t * GGML_RESTRICT scales = x[i].scales;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        const uint8_t * signs_ptr = qs + 32;
+        float sum_block = 0.0f;
+
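+        // Each of the 8 iterations covers 32 weights: 4 qs bytes (low 8 bits of the 10-bit grid
+        // index), 1 qh byte (2 high index bits per qs byte), 4 sign bytes (1 bit per weight) and
+        // 1 scale byte whose two nibbles scale the two 16-weight halves.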
+        for (int ib = 0; ib < 8; ++ib) {
+
+            // Load Low Bits [4 bytes]
+            vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4);
+            qs += 4;
+
+            // Load 1 byte. It contains bits for 4 mini-blocks.
+            uint8_t qh_val = *qh++;
+
+            // Combine low and high bits into the 10-bit grid indices
+            vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4);
+            vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4);
+            vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4);
+            v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4);
+            vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4);
+            vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4);
+            vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4);
+
+            // Lookup Grid
+            vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4)));
+
+            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4);
+            signs_ptr += 4;
+            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
+            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32);
+
+            // generating sign mask
+            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32);
+            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32);
+
+            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32);
+            q8 += 32;
+
+            // apply signs
+            vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 32);
+            vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32);
+
+            // Reduction
+            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
+
+            // Reduce 0-15 (First Half)
+            int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
+                __riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16));
+
+            // Reduce 16-31 (Second Half)
+            int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(
+                __riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16));
+
+            // Apply sub-block scales
+            uint8_t sc = *scales++;
+
+            sum_block += s0 * (2 * (sc & 0xF) + 1);
+            sum_block += s1 * (2 * (sc >> 4)  + 1);
+        }
+        sumf += sum_block * combined_scale;
+    }
+    *s = 0.125f * sumf;
+}
+
+static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs);
+
+    const block_iq2_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+    const uint64_t * grid64 = (const uint64_t *)iq2s_grid;
+
+    // --- Pre-load Constants ---
+    uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1};
+    vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8);
+    uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5};
+    vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8);
+
+    // Constants for sign extraction
+    vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
+    vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint8_t * GGML_RESTRICT scales = x[i].scales;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        const uint8_t * signs_ptr = qs + 32;
+
+        float sum_block = 0.0f;
+
+        for (int ib = 0; ib < 4; ++ib) {
+            // Combine low + high bits
+            vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8);
+            qs += 8;
+            uint16_t qh_val;
+            memcpy(&qh_val, qh, 2);
+            qh += 2;
+            vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2);
+            vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2);
+            vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16);
+            vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8);
+            v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8);
+
+            // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000
+            v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8);
+            vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8);
+
+            // Multiply by 8 to get byte offset, instead of element offset
+            v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8);
+            vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8);
+
+            // Lookup Grid using Byte Offsets
+            vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8);
+
+            vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals);
+            vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8);
+
+            // Load signs and generate sign mask
+            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8);
+            signs_ptr += 8;
+
+            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
+            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
+
+            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
+            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
+
+            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
+            q8 += 64;
+
+            vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
+            vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64);
+
+            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
+
+            int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
+                __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16));
+            int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
+                __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16));
+            int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
+                __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16));
+            int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
+                __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16));
+
+            uint8_t sc0 = scales[0];
+            uint8_t sc1 = scales[1];
+            scales += 2;
+
+            sum_block += s0 * (2 * (sc0 & 0xF) + 1);
+            sum_block += s1 * (2 * (sc0 >> 4)  + 1);
+            sum_block += s2 * (2 * (sc1 & 0xF) + 1);
+            sum_block += s3 * (2 * (sc1 >> 4)  + 1);
+        }
+        sumf += sum_block * combined_scale;
+    }
+    *s = 0.125f * sumf;
+}
+
+void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 128:
+            ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        case 256:
+            ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+#if defined(__riscv_v)
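+// Sign lookup table shared by the iq2_xxs/iq2_xs/iq3_xxs kernels: entry k (k < 128) expands a
+// 7-bit sign index into 8 bytes of +/-1, with the 8th sign chosen so the number of -1s is even.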
+static const int8_t keven_signs_q2xs[1024] = {
+     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
+     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
+     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,
+     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,
+     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,
+     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
+     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,
+     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
+     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,
+     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,
+     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,
+     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,
+     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,
+     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,
+     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,
+     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,
+     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,
+     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
+     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,
+     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,
+     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,
+     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,
+     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,
+     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,
+     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,
+     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,
+     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
+     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,
+     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,
+     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,
+     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
+     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+#endif
+
+static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xs * GGML_RESTRICT x = vx;
+    const block_q8_K   * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+    const uint64_t * grid64  = (const uint64_t *)iq2xs_grid;
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * GGML_RESTRICT qs = x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
+        const uint8_t  * GGML_RESTRICT scales = x[i].scales;
+
+        int32_t sum_int = 0;
+
+        // Loop over 4 subblocks of 64 elements (QK_K = 256)
+        for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) {
+            // Load 8 uint16 indices (controls 64 values)
+            vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8);
+            qs += 8;
+
+            // Extract indices for grid (low 9 bits) and signs (high 7 bits)
+            // Multiply by 8 (<< 3) for byte offsets into the uint64 tables
+            vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8);
+            vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8);
+
+            vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8);
+            vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8);
+
+            vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64));
+            vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64));
+
+            vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64);
+
+            vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64);
+            q8 += 64;
+
+            vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64);
+
+            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
+
+            int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
+                           __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16));
+            int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
+                           __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16));
+            int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
+                           __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16));
+            int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
+                           __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16));
+
+            const uint8_t scale_byte_1 = scales[0];
+            const uint8_t scale_byte_2 = scales[1];
+            scales += 2;
+
+            sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1);
+            sum_int += sum1 * ((scale_byte_1 >> 4)   * 2 + 1);
+            sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1);
+            sum_int += sum3 * ((scale_byte_2 >> 4)   * 2 + 1);
+        }
+
+        sumf += d * sum_int;
+    }
+    *s = 0.125f * sumf;
+}
+
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xxs * GGML_RESTRICT x = vx;
+    const block_q8_K    * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+    const uint64_t * grid64  = (const uint64_t *)iq2xxs_grid;
+
+    uint32_t shift_constants[4] = {0, 7, 14, 21};
+    vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shift_constants, 4);
+
+    float sumf = 0.0f;
+    for (int i = 0; i < nb; ++i) {
+        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t  * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
+
+        float sum = 0.0f;
+
+        #pragma GCC unroll 1
+        for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) {
+            vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32;
+            vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32;
+
+            vuint8mf4_t v_raw_q2_1 = __riscv_vle8_v_u8mf4(q2_ptr, 4);
+            vuint8mf4_t v_raw_q2_2 = __riscv_vle8_v_u8mf4(q2_ptr + 8, 4);
+
+            vuint16mf2_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_1, 4);
+            vuint16mf2_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_2, 4);
+
+            vidx_q2_1 = __riscv_vsll_vx_u16mf2(vidx_q2_1, 3, 4);
+            vidx_q2_2 = __riscv_vsll_vx_u16mf2(vidx_q2_2, 3, 4);
+
+            uint32_t s_packed_1, s_packed_2;
+            memcpy(&s_packed_1, q2_ptr + 4, 4);
+            memcpy(&s_packed_2, q2_ptr + 12, 4);
+
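+            // Each packed 32-bit word carries four 7-bit sign indices (bits 0-27, one per group
+            // of 8 weights) and a 4-bit sub-scale in bits 28-31 (applied as 2*s + 1 below).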
+            vuint32m1_t v_s_1 = __riscv_vmv_v_x_u32m1(s_packed_1, 4);
+            vuint32m1_t v_s_2 = __riscv_vmv_v_x_u32m1(s_packed_2, 4);
+            v_s_1 = __riscv_vsrl_vv_u32m1(v_s_1, v_shifts, 4);
+            v_s_2 = __riscv_vsrl_vv_u32m1(v_s_2, v_shifts, 4);
+
+            v_s_1 = __riscv_vand_vx_u32m1(v_s_1, 127, 4);
+            v_s_2 = __riscv_vand_vx_u32m1(v_s_2, 127, 4);
+
+            vuint16mf2_t vidx_s2_1 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_1, 4), 3, 4);
+            vuint16mf2_t vidx_s2_2 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_2, 4), 3, 4);
+
+            vuint64m2_t vq2_64_1 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_1, 4);
+            vuint64m2_t vq2_64_2 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_2, 4);
+
+            vint8m2_t q2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_1));
+            vint8m2_t q2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_2));
+
+            vuint64m2_t vs2_64_1 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_1, 4);
+            vuint64m2_t vs2_64_2 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_2, 4);
+            vint8m2_t s2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_1));
+            vint8m2_t s2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_2));
+
+            vint8m2_t q8s_1 = __riscv_vmul_vv_i8m2(q8_1, s2_1, 32);
+            vint8m2_t q8s_2 = __riscv_vmul_vv_i8m2(q8_2, s2_2, 32);
+
+            vint16m4_t dot1 = __riscv_vwmul_vv_i16m4(q8s_1, q2_1, 32);
+            vint16m4_t dot2 = __riscv_vwmul_vv_i16m4(q8s_2, q2_2, 32);
+
+            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
+            vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m4_i32m1(dot1, zero_vec, 32);
+            vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m4_i32m1(dot2, zero_vec, 32);
+
+            int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1);
+            int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2);
+
+            int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1;
+            int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1;
+
+            sum += scalar_sum1 * scale1 + scalar_sum2 * scale2;
+            q2_ptr += 16;
+        }
+        sumf += sum * combined_scale;
+    }
+    *s = 0.125f * sumf;
+}
+
+static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xxs * GGML_RESTRICT x = vx;
+    const block_q8_K    * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+    const uint64_t * grid64  = (const uint64_t *)iq2xxs_grid;
+
+    uint32_t shift_constants[4] = {0, 7, 14, 21};
+    vuint32mf2_t v_shifts = __riscv_vle32_v_u32mf2(shift_constants, 4);
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t  * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs;
+        const int8_t   * GGML_RESTRICT q8 = y[i].qs;
+
+        float sum = 0.0f;
+
+        for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) {
+            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32;
+            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32;
+
+            vuint8mf8_t v_raw_q2_1 = __riscv_vle8_v_u8mf8(q2_ptr, 4);
+            vuint8mf8_t v_raw_q2_2 = __riscv_vle8_v_u8mf8(q2_ptr + 8, 4);
+
+            vuint16mf4_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_1, 4);
+            vuint16mf4_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_2, 4);
+
+            vidx_q2_1 = __riscv_vsll_vx_u16mf4(vidx_q2_1, 3, 4);
+            vidx_q2_2 = __riscv_vsll_vx_u16mf4(vidx_q2_2, 3, 4);
+
+            uint32_t s_packed_1, s_packed_2;
+            memcpy(&s_packed_1, q2_ptr + 4, 4);
+            memcpy(&s_packed_2, q2_ptr + 12, 4);
+
+            vuint32mf2_t v_s_1 = __riscv_vmv_v_x_u32mf2(s_packed_1, 4);
+            vuint32mf2_t v_s_2 = __riscv_vmv_v_x_u32mf2(s_packed_2, 4);
+
+            v_s_1 = __riscv_vsrl_vv_u32mf2(v_s_1, v_shifts, 4);
+            v_s_2 = __riscv_vsrl_vv_u32mf2(v_s_2, v_shifts, 4);
+
+            v_s_1 = __riscv_vand_vx_u32mf2(v_s_1, 127, 4);
+            v_s_2 = __riscv_vand_vx_u32mf2(v_s_2, 127, 4);
+
+            // Narrow u32 -> u16 (vncvt) and Scale by 8 to get byte offsets
+            vuint16mf4_t vidx_s2_1 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_1, 4), 3, 4);
+            vuint16mf4_t vidx_s2_2 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_2, 4), 3, 4);
+
+            // Load q2 values from lookup grid
+            vuint64m1_t vq2_64_1 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_1, 4);
+            vuint64m1_t vq2_64_2 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_2, 4);
+            vint8m1_t q2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_1));
+            vint8m1_t q2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_2));
+
+            // Load sign values
+            vuint64m1_t vs2_64_1 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_1, 4);
+            vuint64m1_t vs2_64_2 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_2, 4);
+            vint8m1_t s2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_1));
+            vint8m1_t s2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_2));
+
+            // Apply signs to q8
+            vint8m1_t q8s_1 = __riscv_vmul_vv_i8m1(q8_1, s2_1, 32);
+            vint8m1_t q8s_2 = __riscv_vmul_vv_i8m1(q8_2, s2_2, 32);
+
+            // Multiply q2 by q8 (widening to 16-bit)
+            vint16m2_t dot1 = __riscv_vwmul_vv_i16m2(q8s_1, q2_1, 32);
+            vint16m2_t dot2 = __riscv_vwmul_vv_i16m2(q8s_2, q2_2, 32);
+
+            vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
+            vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m2_i32m1(dot1, zero_vec, 32);
+            vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m2_i32m1(dot2, zero_vec, 32);
+            int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1);
+            int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2);
+            int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1;
+            int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1;
+
+            sum += scalar_sum1 * scale1 + scalar_sum2 * scale2;
+            q2_ptr += 16;
+        }
+        sumf += sum * combined_scale;
+    }
+    *s = 0.125f * sumf;
+}
+
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 128:
+            ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_s * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    const uint32_t * grid32 = (const uint32_t *)iq3s_grid;
+
+    // --- Pre-load Constants ---
+    const uint16_t qh_bit_shifts_arr[16] = {
+        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
+    };
+    vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
+    vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
+    vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16);
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d);
+        const float combined_scale = d * y[i].d;
+
+        const uint8_t * GGML_RESTRICT qs = x[i].qs;
+        const uint8_t * GGML_RESTRICT qh = x[i].qh;
+        const uint8_t * GGML_RESTRICT scales = x[i].scales;
+        const uint8_t * GGML_RESTRICT signs = x[i].signs;
+        const int8_t  * GGML_RESTRICT q8 = y[i].qs;
+
+        float sum_block = 0.0f;
+
+        // Loop: Process 64 weights (16 mini-blocks of 4) per iteration
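+        // Per group: 16 qs bytes (low 8 index bits), 2 qh bytes (one extra index bit per qs byte,
+        // giving a 9-bit index into the 4-byte iq3s_grid entries), 8 sign bytes and one scale
+        // byte whose two nibbles (2*s + 1) scale the two 32-weight halves.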
+        for (int ib = 0; ib < 4; ++ib) {
+
+            vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16);
+            qs += 16;
+
+            uint16_t qh_val;
+            memcpy(&qh_val, qh, 2);
+            qh += 2;
+
+            vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16);
+            // Extract bits: (qh >> i) & 1
+            v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16);
+            v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16);
+
+            vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16);
+            v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16);
+            v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16);
+            vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16);
+
+            // Grid value is 4xuint8
+            vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 16);
+            vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed);
+            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8);
+            signs += 8;
+
+            // Generate sign mask
+            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
+            vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
+            vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
+            vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);
+
+            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
+            q8 += 64;
+
+            // Apply Signs
+            vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
+            vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64);
+
+            // Reduction
+            vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
+            vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
+            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
+
+            int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32));
+            int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32));
+
+            // Apply sub-scales
+            uint8_t sc_byte = *scales++;
+            int sc_lo = (sc_byte & 0xF) * 2 + 1;
+            int sc_hi = (sc_byte >> 4)  * 2 + 1;
+
+            sum_block += s_lo * sc_lo + s_hi * sc_hi;
+        }
+        sumf += sum_block * combined_scale;
+    }
+    *s = sumf;
+}
+
+void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_xxs * GGML_RESTRICT x = vx;
+    const block_q8_K    * GGML_RESTRICT y = vy;
+    const int nb = n / QK_K;
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+    const uint32_t * grid32  = (const uint32_t *)iq3xxs_grid;
+
+    // constants for unpacking logic
+    const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21};
+    vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shifts_val, 8);
+
+    const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1};
+    vuint32m1_t v_gather_idx = __riscv_vle32_v_u32m1(gather_idx_val, 8);
+
+    uint32_t aux32[2];
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * GGML_RESTRICT q3_indices = x[i].qs;
+        const uint8_t * GGML_RESTRICT metadata   = x[i].qs + QK_K/4;
+        const int8_t  * GGML_RESTRICT q8         = y[i].qs;
+
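+        // x[i].qs holds 64 grid-index bytes (one 4-value iq3xxs_grid entry per byte) followed by
+        // 8 packed 32-bit words, each with four 7-bit sign indices and a 4-bit scale (bits 28-31).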
+        float block_sum = 0.0f;
+
+        for (int ib = 0; ib < QK_K / 64; ++ib) {
+            // Load q8 (64 bytes)
+            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
+            q8 += 64;
+
+            // Load the packed sign/scale words via memcpy (avoids unaligned access)
+            memcpy(aux32, metadata, 2 * sizeof(uint32_t));
+            metadata += 2 * sizeof(uint32_t);
+
+            // Load q3 indices and gather magnitudes
+            vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 16);
+            q3_indices += 16;
+
+            vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 16);
+            vuint32m2_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 16);
+            vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u32m2_u8m2(v_q3_magnitudes_u32));
+
+            // --- Unpacking of Sign Indices ---
+
+            // 1. Load the 2 auxiliary 32-bit integers into a vector
+            vuint32m1_t v_aux = __riscv_vle32_v_u32m1(aux32, 2);
+
+            // 2. Broadcast/Gather: replicate aux[0] to first 4 lanes, aux[1] to next 4 lanes
+            vuint32m1_t v_aux_expanded = __riscv_vrgather_vv_u32m1(v_aux, v_gather_idx, 8);
+
+            // 3. Apply Shifts and Mask: ((val >> shift) & 127)
+            vuint32m1_t v_s_vals_raw = __riscv_vand_vx_u32m1(__riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 8), 127, 8);
+
+            // 4. Narrow to u16 (required for vluxei index) and multiply by 8 (byte offset for u64 table)
+            vuint16mf2_t sign_indices_byte_offset = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_vals_raw, 8), 3, 8);
+
+            // 5. Gather Signs
+            vuint64m2_t v_s_vals_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_indices_byte_offset, 8);
+            vint8m2_t v_s_vals = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(v_s_vals_u64));
+
+            vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_s_vals, 64);
+            vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_signed, 64);
+
+            vint16m2_t v_dot_1 = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
+            vint16m2_t v_dot_2 = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
+
+            vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
+            vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_1, v_zero, 32);
+            vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_2, v_zero, 32);
+
+            int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1);
+            int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2);
+
+            const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1);
+            const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1);
+
+            block_sum += sum1_i * scale1_f + sum2_i * scale2_f;
+        }
+
+        sumf += d * block_sum;
+    }
+    *s = 0.25f * sumf;
+}
+
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK4_NL == 0);
+    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+    const block_iq4_nl * GGML_RESTRICT x = vx;
+    const block_q8_0   * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK4_NL;
+
+    int ib = 0;
+    float sumf = 0;
+
+    // Load the lookup table once.
+    const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16);
+    int acc1, acc2;
+
+    // We process 2 blocks at once.
+    for (; ib + 1 < nb; ib += 2) {
+        // Weights and activations.
+        vuint8m1_t iq4_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16);
+        vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32);
+        vuint8m1_t iq4_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16);
+        vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32);
+
+        // Unpack the weight blocks.
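+        // The low nibbles hold elements 0..15 of the block and the high nibbles
+        // elements 16..31, so concatenating them lines the weights up with the
+        // 32 q8 values loaded above.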
+        vuint8m2_t iq4bits1;
+        iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 0, __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16));
+        iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 1, __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16));
+        vuint8m2_t iq4bits2;
+        iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 0, __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16));
+        iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 1, __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16));
+
+        // Gather values from the lookup table.
+        vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32);
+        vint8m2_t iq4b2 = __riscv_vrgather_vv_i8m2(values, iq4bits2, 32);
+
+        // Accumulation.
+        vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, iq4b1, 32);
+        vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, iq4b2, 32);
+        __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
+        __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
+        sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1));
+        sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2));
+    }
+
+    *s = sumf;
+}
+
+static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK4_NL == 0);
+    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+    const block_iq4_nl * GGML_RESTRICT x = vx;
+    const block_q8_0   * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK4_NL;
+
+    int ib = 0;
+    float sumf = 0;
+
+    // Load the lookup table once.
+    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
+    int acc1, acc2;
+
+    // We process 2 blocks at once.
+    for (; ib + 1 < nb; ib += 2) {
+        // Weights and activations.
+        vuint8mf2_t iq4_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16);
+        vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16);
+        vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16);
+        vuint8mf2_t iq4_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16);
+        vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16);
+        vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16);
+
+        // Unpack the weight blocks.
+        vuint8mf2_t iq4bits_lo1 = __riscv_vand_vx_u8mf2(iq4_packed1, 0xf, 16);
+        vuint8mf2_t iq4bits_hi1 = __riscv_vsrl_vx_u8mf2(iq4_packed1, 4, 16);
+        vuint8mf2_t iq4bits_lo2 = __riscv_vand_vx_u8mf2(iq4_packed2, 0xf, 16);
+        vuint8mf2_t iq4bits_hi2 = __riscv_vsrl_vx_u8mf2(iq4_packed2, 4, 16);
+
+        // Gather values from the lookup table.
+        vint8mf2_t iq4b_lo1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo1, 16);
+        vint8mf2_t iq4b_hi1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi1, 16);
+        vint8mf2_t iq4b_lo2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo2, 16);
+        vint8mf2_t iq4b_hi2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi2, 16);
+
+        // Accumulation.
+        vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, iq4b_lo1, 16);
+        sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, iq4b_hi1, 16);
+        vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, iq4b_lo2, 16);
+        sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, iq4b_hi2, 16);
+        __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);
+        __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);
+        sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1));
+        sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2));
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 128:
+            ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_K == 0);
+
+    const block_iq4_xs * GGML_RESTRICT x = vx;
+    const block_q8_K   * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined __riscv_v_intrinsic
+    const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);
+    float sumf = 0;
+    int acc[4];
+
+    // Indices for re-ordering IQ4 data.
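+    //
+    // After the nibble split below, lanes 0..7 (viewed as 64-bit chunks) hold the
+    // low-nibble values and lanes 8..15 the high-nibble values; this gather
+    // interleaves them in 16-byte groups to match the sequential q8 layout.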
+    uint64_t index[16] = {
+        0, 1, 8, 9,
+        2, 3, 10, 11,
+        4, 5, 12, 13,
+        6, 7, 14, 15,
+    };
+    vuint64m4_t i_vec = __riscv_vle64_v_u64m4(index, 16);
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const int8_t  * q8 = y[ibl].qs;
+        const uint8_t * iq4 = x[ibl].qs;
+        uint16_t h = x[ibl].scales_h;
+
+        int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
+
+        for (int ib = 0; ib < QK_K / 128; ++ib) {
+            // Weights and activations.
+            vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64);
+            vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128);
+            iq4 += 64;
+            q8 += 128;
+
+            // Unpack the weight blocks.
+            vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64);
+            vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64);
+            vuint8m4_t iq4bits;
+            iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 0, iq4bits_lo);
+            iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 1, iq4bits_hi);
+            vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgather_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16));
+            vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128);
+
+            // Multiply with activations.
+            vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128);
+
+            // Reduce separately.
+            __riscv_vse32_v_i32m1(&acc[0],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
+            __riscv_vse32_v_i32m1(&acc[1],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
+            __riscv_vse32_v_i32m1(&acc[2],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
+            __riscv_vse32_v_i32m1(&acc[3],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
+
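+            // Reassemble the four 6-bit sub-block scales: the low 4 bits come from
+            // scales_l and the top 2 bits from the rolling scales_h word, with a
+            // bias of -32.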
+            int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf)  | ((h << 4) & 0x30)) - 32;
+            int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >>  4)  | ((h << 2) & 0x30)) - 32;
+            int ls3 = ((x[ibl].scales_l[ib * 2 + 1] &  0xf) | ((h << 0) & 0x30)) - 32;
+            int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >>  4)  | ((h >> 2) & 0x30)) - 32;
+            h >>= 8;
+
+            sumi1 += acc[0] * ls1;
+            sumi2 += acc[1] * ls2;
+            sumi3 += acc[2] * ls3;
+            sumi4 += acc[3] * ls4;
+        }
+
+        sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4);
+    }
+
+    *s = sumf;
+
+#else
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(nb);
+    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq1_0 * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    float sumf = 0.0f;
+    uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
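+    // Per-lane powers of three for decoding the packed qh trits: each of the 4 qh
+    // bytes is replicated across four lanes, and multiplying by 1/3/9/27 before
+    // the (x * 3) >> 8 step extracts a different base-3 digit in each lane group.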
+
+    for (int i = 0; i < nb; i++) {
+        // First loop.
+        vint32m4_t suml1;
+        {
+            const int vl = 32;
+            vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl);
+
+            vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl);
+            vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl);
+            vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl);
+            vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl);
+            vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl);
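+            // Fixed-point trit extraction: each qs byte packs five base-3 digits,
+            // and ((q * 3^k) * 3) >> 8 isolates digit k as a value in {0, 1, 2};
+            // subtracting 1 below recenters it to {-1, 0, 1}.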
+
+            vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl);
+            vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl);
+            vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl);
+            vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl);
+            vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl);
+
+            vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl);
+            vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl);
+            vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl);
+            vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl);
+            vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl);
+
+            vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl);
+            vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl);
+            suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl);
+        }
+
+        // Second loop.
+        vint32m2_t suml2;
+        {
+            const int vl = 16;
+            vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl);
+
+            vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl);
+            vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl);
+            vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl);
+            vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl);
+            vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl);
+
+            vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl);
+            vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl);
+            vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl);
+            vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl);
+            vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl);
+
+            vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
+            vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl);
+            vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl);
+            vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl);
+            vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl);
+
+            vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl);
+            vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl);
+            suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl);
+        }
+
+        // Third loop.
+        vint32m2_t suml3;
+        {
+            const int vl = 16;
+
+            uint32_t qh;
+            memcpy(&qh, &x[i].qh[0], 4);
+            // Prevent fusion with vmv.
+            __asm__ __volatile__("" : "+r"(qh));
+            vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4));
+
+            vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl);
+
+            vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl);
+
+            vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl);
+
+            vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl);
+            suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl);
+        }
+
+        vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16);
+        sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16);
+        sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16);
+
+        vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16);
+        sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq2_0 * GGML_RESTRICT x = vx;
+    const block_q8_K  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_K;
+
+    float sumf = 0.0f;
+    for (int i = 0; i < nb; ++i) {
+        int32_t sumi = 0;
+
+        for (size_t j = 0; j < sizeof(x[0].qs); j += 32) {
+            const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32];
+            const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32];
+            const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32];
+            const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32];
+            const uint8_t* px  = &x[i].qs[j];
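+            // Each weight byte packs four 2-bit values; plane l (bits 2l+1:2l) of
+            // the 32 bytes at qs[j] pairs with the 32 activations at y.qs[j*4 + l*32].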
+
+            size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32);
+            vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2);
+
+            size_t vl = __riscv_vsetvl_e8m1(32);
+
+            vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl);
+
+            vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0, vl);
+            vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl);
+            vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl);
+            vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl);
+
+            // l=0 (bits 1:0)
+            vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl);
+            vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl);
+
+            // l=1 (bits 3:2)
+            vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl);
+            vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl);
+
+            // l=2 (bits 5:4)
+            vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl);
+            vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl);
+
+            // l=3 (bits 7:6)
+            vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros
+            vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl);
+
+            // 4. Multiply and accumulate
+            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl);
+            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl);
+            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl);
+            vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl);
+
+            vlmax_16m2 = __riscv_vsetvl_e16m2(32);
+            vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1);
+            vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2);
+
+            sumi += __riscv_vmv_x_s_i32m1_i32(vred32);
+        }
+        const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
+        sumf += (float)sumi * d;
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 256:
+            ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
+static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+    // Load the lookup table once.
+    const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_mxfp4, 16);
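+    // Same structure as the iq4_nl kernels above: the nibbles index kvalues_mxfp4,
+    // and the weight-side block scale comes from the E8M0 exponent.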
+    int acc1, acc2;
+
+    // We process 2 blocks at once.
+    for (; ib + 1 < nb; ib += 2) {
+        // Weights and activations.
+        vuint8m1_t mx_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16);
+        vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32);
+        vuint8m1_t mx_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16);
+        vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32);
+
+        // Unpack the weight blocks.
+        vuint8m2_t mxbits1;
+        mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 0, __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16));
+        mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 1, __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16));
+        vuint8m2_t mxbits2;
+        mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 0, __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16));
+        mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 1, __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16));
+
+        // Gather values from the lookup table.
+        vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32);
+        vint8m2_t mxb2 = __riscv_vrgather_vv_i8m2(values, mxbits2, 32);
+
+        // Accumulation.
+        vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, mxb1, 32);
+        vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, mxb2, 32);
+        __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
+        __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1);
+        sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1));
+        sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2));
+    }
+
+    *s = sumf;
+}
+
+static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+    // Load the lookup table once.
+    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_mxfp4, 16);
+    int acc1, acc2;
+
+    // We process 2 blocks at once.
+    for (; ib + 1 < nb; ib+=2) {
+        // Weights and activations.
+        vuint8mf2_t mx_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16);
+        vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16);
+        vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16);
+        vuint8mf2_t mx_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16);
+        vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16);
+        vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16);
+
+        // Unpack the weight blocks.
+        vuint8mf2_t mxbits_lo1 = __riscv_vand_vx_u8mf2(mx_packed1, 0xf, 16);
+        vuint8mf2_t mxbits_hi1 = __riscv_vsrl_vx_u8mf2(mx_packed1, 4, 16);
+        vuint8mf2_t mxbits_lo2 = __riscv_vand_vx_u8mf2(mx_packed2, 0xf, 16);
+        vuint8mf2_t mxbits_hi2 = __riscv_vsrl_vx_u8mf2(mx_packed2, 4, 16);
+
+        // Gather values from the lookup table.
+        vint8mf2_t mxb_lo1 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo1, 16);
+        vint8mf2_t mxb_hi1 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi1, 16);
+        vint8mf2_t mxb_lo2 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo2, 16);
+        vint8mf2_t mxb_hi2 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi2, 16);
+
+        // Accumulation.
+        vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, mxb_lo1, 16);
+        sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, mxb_hi1, 16);
+        vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, mxb_lo2, 16);
+        sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, mxb_hi2, 16);
+        __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);
+        __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1);
+        sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1));
+        sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2));
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+#if defined __riscv_v_intrinsic
+    switch (__riscv_vlenb() * 8) {
+        case 128:
+            ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+        default:
+            ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);
+            break;
+    }
+#else
+    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp
index 2a35ff9a..cd580787 100644
--- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp
@@ -24,6 +24,94 @@
 
 #define UNUSED GGML_UNUSED
 
+void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+#if defined(__riscv_v_intrinsic)
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
+    const size_t vl_calc = __riscv_vsetvl_e32m8(QK8_0);
+    const size_t vl_save = __riscv_vsetvl_e64m2(4);
+    vfloat32m1_t v_scalar_zero = __riscv_vfmv_s_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1));
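+    // Each block quantizes 32 floats from each of the 4 source rows (stride k
+    // apart): per-row scale d = amax/127, values scaled by 1/d, rounded (FRM=4,
+    // round-to-nearest-max-magnitude) and narrowed to int8, then the four rows
+    // are interleaved in 8-byte groups by the segmented store at the end of the
+    // loop. The empty asm statements act as compiler barriers between the rows,
+    // presumably to keep vector register pressure down.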
+
+    for (int i = 0; i < nb; i++) {
+        const float *x_block_base = x + i * QK8_0;
+        vint8m2_t q_r0, q_r1, q_r2, q_r3;
+        {
+            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 0 * k, vl_calc);
+            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
+            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
+            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
+
+            float d = amax / 127.0f;
+            y[i].d[0] = GGML_CPU_FP32_TO_FP16(d);
+
+            float id = d ? 1.0f / d : 0.0f;
+            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
+            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
+            q_r0 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
+        }
+        asm volatile ("" ::: "memory");
+
+        {
+            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 1 * k, vl_calc);
+            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
+            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
+            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
+
+            float d = amax / 127.0f;
+            y[i].d[1] = GGML_CPU_FP32_TO_FP16(d);
+            float id = d ? 1.0f / d : 0.0f;
+
+            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
+            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
+            q_r1 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
+        }
+        asm volatile ("" ::: "memory");
+        {
+            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 2 * k, vl_calc);
+            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
+            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
+            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
+
+            float d = amax / 127.0f;
+            y[i].d[2] = GGML_CPU_FP32_TO_FP16(d);
+            float id = d ? 1.0f / d : 0.0f;
+
+            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
+            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
+            q_r2 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
+        }
+        asm volatile ("" ::: "memory");
+        {
+            vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 3 * k, vl_calc);
+            vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc);
+            vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc);
+            float amax = __riscv_vfmv_f_s_f32m1_f32(v_max);
+
+            float d = amax / 127.0f;
+            y[i].d[3] = GGML_CPU_FP32_TO_FP16(d);
+            float id = d ? 1.0f / d : 0.0f;
+
+            vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc);
+            vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc);
+            q_r3 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc);
+        }
+        vint64m2_t v_q64_r0 = __riscv_vreinterpret_v_i8m2_i64m2(q_r0);
+        vint64m2_t v_q64_r1 = __riscv_vreinterpret_v_i8m2_i64m2(q_r1);
+        vint64m2_t v_q64_r2 = __riscv_vreinterpret_v_i8m2_i64m2(q_r2);
+        vint64m2_t v_q64_r3 = __riscv_vreinterpret_v_i8m2_i64m2(q_r3);
+        vint64m2x4_t v_quant_tuple = __riscv_vcreate_v_i64m2x4(v_q64_r0, v_q64_r1, v_q64_r2, v_q64_r3);
+        __riscv_vsseg4e64_v_i64m2x4((int64_t*)y[i].qs, v_quant_tuple, vl_save);
+    }
+#else
+    UNUSED(nb);
+    ggml_quantize_mat_q8_0_4x8_generic(x, vy, k);
+#endif
+}
+
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -115,6 +203,486 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);
+
+        // 1x16 Accumulator
+        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+        for (int l = 0; l < nb; l++) {
+            // 1x16 Integer Accumulator
+            vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
+            vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
+
+            // Accumulation loop.
+            for (int i = 0; i < QK4_0 / 2; i++) {
+                // Load `b_ptr`.
+                const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
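+                // Sign-extend the low nibble with a shift-left/arithmetic-shift-right
+                // pair; an arithmetic shift of the packed byte yields the high nibble.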
+                const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16);
+                const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16);
+
+                sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16);
+                sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16);
+            }
+
+            const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16);
+
+            const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
+            const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
+
+            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
+        }
+
+        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
+    }
+    return;
+#endif
+    ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);
+
+        // 1x16 Accumulator
+        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+        for (int l = 0; l < nb; l++) {
+            vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);
+
+            // Load `dmin`.
+            const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2(
+                __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16);
+
+            // We process 4 sub-blocks at once.
+            for (int j = 0; j < QK_K / 128; j++) {
+                // Extract the scales and the mins.
+                //
+                // Low bits.
+                vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64);
+                vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64);
+                vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64);
+
+                // High bits.
+                vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64);
+                vuint8m2_t scales_hi;
+                vuint8m2_t mins_hi;
+                if (!j) {
+                    scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64);
+                    mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64);
+                } else {
+                    scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64);
+                    mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64);
+                }
+                vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64);
+                vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64));
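+                // The low 4 bits of each 6-bit scale/min come from the per-sub-block
+                // bytes; the top 2 bits come from the shared bytes at offset 128,
+                // selected according to the sub-block pair `j`.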
+
+                // Reduce the mins, multiply by `dmin`, and subtract the
+                // correction from `sumf`.
+                vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16);
+                bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
+                bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
+                bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
+                bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
+
+                sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16);
+
+                // Accumulation for 2 sub-blocks.
+                //
+                // This might overflow, so we accumulate in two steps.
+                //
+                // Recheck.
+                for (int k = 0; k < 2; k++) {
+                    vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                    vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+
+                    for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
+                        // Load `b_ptr`.
+                        const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16);
+                        const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
+                        const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
+
+                        sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16);
+                        sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16);
+                    }
+
+                    sumi = __riscv_vwmacc_vv_i32m2(sumi,
+                        __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
+                        sumi_s_0_16, 16);
+                    sumi = __riscv_vwmacc_vv_i32m2(sumi,
+                        __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
+                        sumi_s_1_16, 16);
+                }
+                // Accumulation for 2 sub-blocks.
+                //
+                // This might overflow, so we accumulate in two steps.
+                //
+                // Recheck.
+                for (int k = 0; k < 2; k++) {
+                    vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                    vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+
+                    for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
+                        // Load `b_ptr`.
+                        const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16);
+                        const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
+                        const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
+
+                        sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16);
+                        sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16);
+                    }
+
+                    sumi = __riscv_vwmacc_vv_i32m2(sumi,
+                        __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
+                        sumi_s_0_16, 16);
+                    sumi = __riscv_vwmacc_vv_i32m2(sumi,
+                        __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
+                        sumi_s_1_16, 16);
+                }
+            }
+
+            const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16);
+            const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16);
+
+            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
+        }
+
+        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
+    }
+    return;
+#endif
+    ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
+
+        // 1x16 Accumulator
+        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+        for (int l = 0; l < nb; l++) {
+            // 1x16 integer accumulator
+            vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);
+
+            // Accumulation loop.
+            for (int i = 0; i < QK4_NL / 2; i++) {
+                // Load `b_ptr`.
+                const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
+                const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
+                const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
+                // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
+                // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
+
+                const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16);
+                const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16);
+                sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16);
+            }
+
+            const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
+            const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
+
+            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
+        }
+
+        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
+    }
+    return;
+#endif
+    ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);
+
+        // 1x16 Accumulator
+        vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+        for (int l = 0; l < nb; l++) {
+            // 1x16 Integer Accumulator
+            vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16);
+
+            // Accumulation loop.
+            for (int i = 0; i < QK8_0; i++) {
+                // Load `b_ptr`.
+                const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
+                // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16);
+
+                sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16);
+            }
+
+            const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
+            const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
+
+            sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
+        }
+
+        __riscv_vse32_v_f32m2(s + x * 16, sumf, 16);
+    }
+    return;
+#endif
+    ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    assert(n % QK_K == 0);
+    assert(nr == 1);
+    assert(nc % 16 == 0);
+
+    UNUSED(bs);
+
+    const int N_COLS_TILE = 16;
+    const int num_k_blocks = n / QK_K;
+
+    const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE);
+    for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) {
+
+        const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy;
+        const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks;
+
+        vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl);
+
+        for (int k_block = 0; k_block < num_k_blocks; ++k_block) {
+            const block_q8_K* lhs_current = &lhs_base_ptr[k_block];
+            const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block];
+
+            // 1. Prepare Global Min Scales
+            vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);
+            vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);
+
+            vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl);
+
+            vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl);
+
+            const uint8_t* rhs_qs_ptr = rhs_current->qs;
+            const uint8_t* rhs_sc_ptr = rhs_current->scales;
+            const int8_t*  lhs_qs_ptr = lhs_current->qs;
+
+            // --- Phase Loop (4 phases x 64 elements) ---
+            for (int phase = 0; phase < 4; ++phase) {
+
+                // A. Load Scales/Mins
+                vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3;
+                vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3;
+
+                {
+                    vuint8mf2_t v_raw;
+                    // Sub-block 0
+                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);
+                    v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
+                    v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+
+                    // Sub-block 1
+                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);
+                    v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
+                    v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+
+                    // Sub-block 2
+                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);
+                    v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
+                    v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+
+                    // Sub-block 3
+                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);
+                    v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
+                    v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+
+                    rhs_sc_ptr += 64;
+                }
+
+                int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);
+                int k_offsets[4] = {0, 32, 64, 96};
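+                // Each phase covers 64 activations: within the 128-value half
+                // selected by the phase, sub-block l pairs 2-bit plane l of the
+                // weight bytes with the 16 activations at base_k_phase + l*32.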
+
+                // B. Inner Dot Product Loop
+                for (int l = 0; l < 16; ++l) {
+                    vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl);
+                    rhs_qs_ptr += 16;
+
+                    // Sub-block 0
+                    {
+                        vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl);
+                        vint16m1_t v_w = __riscv_vmul_vv_i16m1(
+                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
+                            __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl);
+
+                        int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l];
+                        v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
+                    }
+                    // Sub-block 1
+                    {
+                        vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl);
+                        vint16m1_t v_w = __riscv_vmul_vv_i16m1(
+                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
+                            __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl);
+
+                        int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l];
+                        v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
+                    }
+                    // Sub-block 2
+                    {
+                        vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl);
+                        vint16m1_t v_w = __riscv_vmul_vv_i16m1(
+                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
+                            __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl);
+
+                        int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l];
+                        v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
+                    }
+                    // Sub-block 3
+                    {
+                        vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl);
+                        vint16m1_t v_w = __riscv_vmul_vv_i16m1(
+                            __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
+                            __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl);
+
+                        int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l];
+                        v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl);
+                    }
+                }
+
+                // correction
+                int sb_base_abs = base_k_phase / 16;
+
+                // Sub-block 0
+                {
+                    int sb_idx = sb_base_abs + (k_offsets[0] / 16);
+                    int16_t bsum = lhs_current->bsums[sb_idx];
+                    vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0);
+                    vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
+                    vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
+                    v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
+                }
+                // Sub-block 1
+                {
+                    int sb_idx = sb_base_abs + (k_offsets[1] / 16);
+                    int16_t bsum = lhs_current->bsums[sb_idx];
+                    vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1);
+                    vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
+                    vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
+                    v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
+                }
+                // Sub-block 2
+                {
+                    int sb_idx = sb_base_abs + (k_offsets[2] / 16);
+                    int16_t bsum = lhs_current->bsums[sb_idx];
+                    vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2);
+                    vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
+                    vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
+                    v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
+                }
+                // Sub-block 3
+                {
+                    int sb_idx = sb_base_abs + (k_offsets[3] / 16);
+                    int16_t bsum = lhs_current->bsums[sb_idx];
+                    vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3);
+                    vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl);
+                    vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl);
+                    v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl);
+                }
+
+            } // End Phase Loop
+
+            // Apply global Scales
+            vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl);
+            vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl);
+
+            vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl);
+            vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl);
+            v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl);
+            v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl);
+
+        } // End K-Block
+        __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl);
+
+    }
+}
+
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -340,3 +908,826 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
     ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
+
+void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);
+
+            // 4x16 Accumulators
+            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+            for (int l = 0; l < nb; l++) {
+                // 4x16 integer accumulators
+                vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0, 16);
+
+                // Accumulation loop.
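+                // Each packed byte of `b_ptr` holds two 4-bit weights: the low
+                // nibble is element i of the block and the high nibble element
+                // i + 16; shift-left-then-arithmetic-right sign-extends the low
+                // nibble, a plain arithmetic shift extracts the high one. The
+                // q8 activations are interleaved one byte per row (4 rows), so
+                // the second half of the block starts at offset 64 = 4 * 16.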
+                for (int i = 0; i < QK4_0 / 2; i++) {
+                    // Load `b_ptr`.
+                    const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
+                    const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16);
+                    const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16);
+
+                    sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16);
+                    sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16);
+                    sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16);
+                    sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16);
+
+                    sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16);
+                    sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16);
+                    sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16);
+                    sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16);
+                }
+
+                // Do the final accumulation in i32 to prevent overflow.
+                const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16);
+                const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16);
+                const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16);
+                const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16);
+
+                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
+                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
+                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
+                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
+                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
+
+                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
+                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
+                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
+                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
+            }
+
+            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
+        }
+    }
+    return;
+#endif
+    ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);
+
+            // 4x16 Accumulators
+            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+            for (int l = 0; l < nb; l++) {
+                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);
+
+                // Load `dmin`.
+                const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16);
+
+                // We process 4 sub-blocks at once.
+                for (int j = 0; j < QK_K / 128; j++) {
+                    // Extract the scales and the mins.
+                    //
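+                    // Scales and mins are 6-bit values: their low 4 bits sit in
+                    // the first 128 bytes (scale in the low nibble, min in the
+                    // high nibble of each byte) and their top 2 bits are packed
+                    // four per byte in the last 64 bytes.
+                    //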
+                    // Low bits.
+                    vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64);
+                    vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64);
+                    vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64);
+
+                    // High bits.
+                    vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64);
+                    vuint8m2_t scales_hi;
+                    vuint8m2_t mins_hi;
+                    if (!j) {
+                        scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64);
+                        mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64);
+                    } else {
+                        scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64);
+                        mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64);
+                    }
+                    vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64);
+                    vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64));
+
+                    // Reduce the mins and multiply with `dmin`.
+                    //
+                    // Correct in `sumf`.
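+                    //
+                    // `bsums` holds the per-row sums of each 16-element chunk of
+                    // activations, interleaved 4 rows at a time; adding the two
+                    // adjacent chunks gives the per-row sum over a full
+                    // 32-element sub-block, which is then weighted by the
+                    // corresponding min.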
+                    vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16);
+                    vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16);
+                    vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16);
+                    vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16);
+
+                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
+                                a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4],
+                                __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
+                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
+                                a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5],
+                                __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
+                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
+                                a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6],
+                                __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
+                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
+                                a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7],
+                                __riscv_vget_v_i16m4_i16m1(mins, 0), 16);
+                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
+                                a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4],
+                                __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
+                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
+                                a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5],
+                                __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
+                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
+                                a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6],
+                                __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
+                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
+                                a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7],
+                                __riscv_vget_v_i16m4_i16m1(mins, 1), 16);
+                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
+                                a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4],
+                                __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
+                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
+                                a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5],
+                                __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
+                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
+                                a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6],
+                                __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
+                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
+                                a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7],
+                                __riscv_vget_v_i16m4_i16m1(mins, 2), 16);
+                    bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0,
+                                a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4],
+                                __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
+                    bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1,
+                                a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5],
+                                __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
+                    bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2,
+                                a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6],
+                                __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
+                    bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3,
+                                a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7],
+                                __riscv_vget_v_i16m4_i16m1(mins, 3), 16);
+
+                    const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], 16);
+                    const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], 16);
+                    const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], 16);
+                    const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], 16);
+
+                    sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16);
+                    sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16);
+                    sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16);
+                    sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16);
+
+                    // Accumulation for the first 2 of the 4 sub-blocks processed
+                    // in this iteration (scales/mins indices 0 and 1).
+                    //
+                    // Summing all 32 products of a sub-block in i16 could
+                    // overflow, so each one is accumulated in two passes of 16
+                    // and widened into the i32 accumulators after every pass.
+                    for (int k = 0; k < 2; k++) {
+                        // 4x16 integer accumulators
+                        vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+
+                        for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
+                            // Load `b_ptr`.
+                            const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16);
+                            const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
+                            const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
+
+                            sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16);
+                            sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16);
+                            sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16);
+                            sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16);
+
+                            sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16);
+                            sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16);
+                            sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16);
+                            sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16);
+                        }
+
+                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
+                                    sumi_0_s_0_16, 16);
+                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
+                                    sumi_0_s_1_16, 16);
+                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
+                                    sumi_1_s_0_16, 16);
+                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
+                                    sumi_1_s_1_16, 16);
+                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
+                                    sumi_2_s_0_16, 16);
+                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
+                                    sumi_2_s_1_16, 16);
+                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)),
+                                    sumi_3_s_0_16, 16);
+                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)),
+                                    sumi_3_s_1_16, 16);
+                    }
+                    // Accumulation for the remaining 2 sub-blocks of this
+                    // iteration (scales/mins indices 2 and 3), again split into
+                    // two passes to keep the i16 accumulators from overflowing.
+                    for (int k = 0; k < 2; k++) {
+                        // 4x16 integer accumulators
+                        vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+                        vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0, 16);
+
+                        for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) {
+                            // Load `b_ptr`.
+                            const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16);
+                            const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16));
+                            const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16));
+
+                            sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16);
+                            sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16);
+                            sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16);
+                            sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16);
+
+                            sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16);
+                            sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16);
+                            sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16);
+                            sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16);
+                        }
+
+                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
+                                    sumi_0_s_0_16, 16);
+                        sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
+                                    sumi_0_s_1_16, 16);
+                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
+                                    sumi_1_s_0_16, 16);
+                        sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
+                                    sumi_1_s_1_16, 16);
+                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
+                                    sumi_2_s_0_16, 16);
+                        sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
+                                    sumi_2_s_1_16, 16);
+                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)),
+                                    sumi_3_s_0_16, 16);
+                        sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3,
+                                    __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)),
+                                    sumi_3_s_1_16, 16);
+                    }
+                }
+
+                const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16);
+                const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16);
+                const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16);
+                const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16);
+                const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16);
+
+                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
+                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
+                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
+                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
+            }
+
+            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
+        }
+    }
+    return;
+#endif
+    ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
+
+            // 4x16 Accumulators
+            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+            for (int l = 0; l < nb; l++) {
+                // 4x16 integer accumulators
+                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);
+
+                // Accumulation loop.
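+                // Each packed byte holds two 4-bit indices; `vrgather` maps them
+                // through the `kvalues_iq4nl` table to the signed iq4_nl values.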
+                for (int i = 0; i < QK4_NL / 2; i++) {
+                    // Load `b_ptr`.
+                    const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
+                    const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
+                    const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
+                    // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
+                    // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
+
+                    const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16);
+                    const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16);
+                    const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16);
+                    const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16);
+
+                    const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16);
+                    const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16);
+                    const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16);
+                    const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16);
+
+                    sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16);
+                    sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16);
+                    sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16);
+                    sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16);
+                }
+
+                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
+                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
+                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
+                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
+                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
+
+                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
+                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
+                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
+                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
+            }
+
+            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
+        }
+    }
+    return;
+#endif
+    ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);
+
+            // 4x16 Accumulators
+            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+            for (int l = 0; l < nb; l++) {
+                // 4x16 Integer Accumulators
+                vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16);
+                vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16);
+
+                // Accumulation loop.
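+                // Multiply the 16 interleaved column values by the broadcast
+                // activation of each row and widen-accumulate straight into the
+                // i32 accumulators on every iteration.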
+                for (int i = 0; i < QK8_0; i++) {
+                    // Load `b_ptr`.
+                    const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16);
+                    // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16);
+
+                    sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16);
+                    sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16);
+                    sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16);
+                    sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16);
+                }
+
+                const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16);
+                const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
+                const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
+                const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
+                const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
+
+                sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
+                sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);
+                sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16);
+                sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16);
+            }
+
+            __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16);
+            __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16);
+        }
+    }
+    return;
+#endif
+    ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    assert(n % QK_K == 0);
+    const int num_k_blocks = n / QK_K;
+    const int N_ROWS_TILE = 4;
+    const int N_COLS_TILE = 16;
+    assert(nr % N_ROWS_TILE == 0);
+    assert(nc % N_COLS_TILE == 0);
+
+    const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE);
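+    // Note: vl only reaches the full 16-column tile width when VLEN >= 256
+    // (e32, LMUL=2).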
+    // --- Tiling Loops ---
+#pragma GCC unroll 1
+    for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) {
+#pragma GCC unroll 1
+        for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) {
+            // Base Pointers
+            const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks;
+            const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks;
+
+            // Persistent Float Accumulators
+            vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
+            vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
+            vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
+            vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
+
+            // --- Super-Block Loop (K=0..255) ---
+#pragma GCC unroll 1
+            for (int k_block = 0; k_block < num_k_blocks; ++k_block) {
+                const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block];
+                const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block];
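+
+                // q2_K dequantizes an element as d * sc * q - dmin * m, so each
+                // sub-block's contribution splits into d * sc * sum(q * a),
+                // accumulated in the integer accumulators below, minus
+                // dmin * m * sum(a), where sum(a) is the precomputed activation
+                // bsum handled in the correction step (C).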
+
+                // 1. Load Global Min Scales (Keep as F16/LMUL=1 to save registers)
+                vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);
+                vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);
+
+                // 2. Initialize Integer Accumulators
+                vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl);
+                vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl);
+                vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl);
+                vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl);
+
+                const uint8_t* rhs_qs_ptr = rhs_current->qs;
+                const uint8_t* rhs_sc_ptr = rhs_current->scales;
+                const int8_t*  lhs_qs_ptr = lhs_current->qs;
+
+                // --- Phase Loop (4 phases x 64 elements) ---
+#pragma GCC unroll 1
+                for (int phase = 0; phase < 4; ++phase) {
+
+                    // A. Load Scales/Mins for the 4 interleaved sub-blocks
+                    vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3;
+                    vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3;
+
+                    // Unrolled Load Logic
+                    {
+                        vuint8mf2_t v_raw;
+                        // Sub-block 0
+                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);
+                        v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
+                        v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+
+                        // Sub-block 1
+                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);
+                        v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
+                        v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+
+                        // Sub-block 2
+                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);
+                        v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
+                        v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+
+                        // Sub-block 3
+                        v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);
+                        v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
+                        v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+
+                        rhs_sc_ptr += 64;
+                    }
+
+                    int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);
+                    int k_offsets[4] = {0, 32, 64, 96};
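+                    // Each packed RHS byte carries four 2-bit values belonging
+                    // to four different 16-element sub-blocks spaced 32
+                    // activations apart (k_offsets); phases 0/1 cover
+                    // activations 0..127 and phases 2/3 cover 128..255.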
+
+                    // B. Inner Dot Product Loop
+#pragma GCC unroll 1
+                    for (int l = 0; l < 16; ++l) {
+                        vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl);
+                        rhs_qs_ptr += 16;
+
+                        // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase)
+
+                        // --- Sub-block 0 ---
+                        {
+                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl);
+                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
+                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
+                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl);
+
+                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[0] + l) * 4];
+                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
+                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
+                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
+                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
+                        }
+                        // --- Sub-block 1 ---
+                        {
+                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl);
+                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
+                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
+                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl);
+
+                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[1] + l) * 4];
+                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
+                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
+                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
+                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
+                        }
+                        // --- Sub-block 2 ---
+                        {
+                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl);
+                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
+                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
+                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl);
+
+                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[2] + l) * 4];
+                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
+                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
+                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
+                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
+                        }
+                        // --- Sub-block 3 ---
+                        {
+                            vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl);
+                            vint16m1_t v_w = __riscv_vmul_vv_i16m1(
+                                __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)),
+                                __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl);
+
+                            const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[3] + l) * 4];
+                            v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl);
+                            v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl);
+                            v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl);
+                            v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl);
+                        }
+                    }
+
+                    // C. Correction: for each of the 4 sub-blocks in this phase,
+                    //    subtract d(row) * dmin(col) * min(sb) * bsum(row, sb).
+                    int sb_base_abs = base_k_phase / 16;
+
+                    // --- Correction Sub-block 0 ---
+                    {
+                        int sb_abs = sb_base_abs + (k_offsets[0] / 16);
+                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0);
+
+                        // Row 0
+                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
+                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
+                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);
+
+                        // Row 1
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);
+
+                        // Row 2
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);
+
+                        // Row 3
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
+                    }
+
+                    // --- Correction Sub-block 1 ---
+                    {
+                        int sb_abs = sb_base_abs + (k_offsets[1] / 16);
+                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1);
+
+                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
+                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
+                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
+                    }
+
+                    // --- Correction Sub-block 2 ---
+                    {
+                        int sb_abs = sb_base_abs + (k_offsets[2] / 16);
+                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2);
+
+                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
+                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
+                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
+                    }
+
+                    // --- Correction Sub-block 3 ---
+                    {
+                        int sb_abs = sb_base_abs + (k_offsets[3] / 16);
+                        vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3);
+
+                        vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl);
+                        vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl);
+                        vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl);
+
+                        v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl);
+                        v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl);
+                        vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl);
+                        v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl);
+                    }
+
+                } // End Phase Loop
+
+                // --- Apply Main Scales ---
+                vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl);
+                vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl);
+
+                {
+                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[0], vl);
+                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl);
+                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
+                    v_sumf_0 = __riscv_vfadd_vv_f32m2(v_sumf_0, v_sum, vl);
+                }
+                // Row 1
+                {
+                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[1], vl);
+                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl);
+                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
+                    v_sumf_1 = __riscv_vfadd_vv_f32m2(v_sumf_1, v_sum, vl);
+                }
+                // Row 2
+                {
+                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[2], vl);
+                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl);
+                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
+                    v_sumf_2 = __riscv_vfadd_vv_f32m2(v_sumf_2, v_sum, vl);
+                }
+                // Row 3
+                {
+                    vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[3], vl);
+                    vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl);
+                    v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl);
+                    v_sumf_3 = __riscv_vfadd_vv_f32m2(v_sumf_3, v_sum, vl);
+                }
+
+            } // End K-Block
+
+            __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl);
+            __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl);
+            __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl);
+            __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl);
+        }
+    }
+}
diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c
index 19d225a4..34184ed8 100644
--- a/ggml/src/ggml-cpu/arch/s390/quants.c
+++ b/ggml/src/ggml-cpu/arch/s390/quants.c
@@ -181,11 +181,11 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
 
         const int16x8_t v_xylso = vec_mulo(v_xls, v_yl);
-        const int16x8_t v_xylse = vec_mule(v_xls, v_yl);
+        const int16x8_t v_xyl = vec_meadd(v_xls, v_yl, v_xylso);
         const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh);
-        const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh);
+        const int16x8_t v_xyh = vec_meadd(v_xhs, v_yh, v_xyhso);
 
-        int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
+        int16x8_t v_xy_ = v_xyl + v_xyh; v_xy_ += vec_reve(v_xy_);
 
         const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_));
         const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
@@ -890,8 +890,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
 
         const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
-        const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
-        const int32x4_t v_mins = v_minso + v_minse;
+        const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minso);
         sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
 
         const uint8_t * scales = (const uint8_t *)utmp;
@@ -1004,8 +1003,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
 
         const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
-        const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
-        const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
+        const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minsho);
         const int32_t mins = vec_hsum_i32x4(v_mins);
 
         const uint8_t * scales = (const uint8_t *)utmp;
@@ -1110,10 +1108,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const int16x8_t v_scaleh = vec_unpackl(v_scale);
 
         const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
-        const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
+        const int32x4_t v_minsl = vec_meadd(v_ysumsl, v_scalel, v_minslo);
         const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
-        const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
-        const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
+        const int32x4_t v_minsh = vec_meadd(v_ysumsh, v_scaleh, v_minsho);
+        const int32x4_t v_mins = vec_add(v_minsl, v_minsh);
 
         const int32_t mins = vec_hsum_i32x4(v_mins);
 
diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c
index cb49320a..74d699f6 100644
--- a/ggml/src/ggml-cpu/arch/x86/quants.c
+++ b/ggml/src/ggml-cpu/arch/x86/quants.c
@@ -268,9 +268,9 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
 
-static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
-    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
-                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
 #endif
 #elif defined(__SSSE3__)
@@ -782,6 +782,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
     __m256 accum1 = _mm256_setzero_ps();
     __m256 accum2 = _mm256_setzero_ps();
+
     for (; ib + 1 < nb; ib += 2) {
         const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
         const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
@@ -795,10 +796,10 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
         const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
         const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
         const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
-        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
-                _mm256_cvtepi32_ps(p_1), accum1);
-        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
-                _mm256_cvtepi32_ps(p_2), accum2);
+        const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e));
+        const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e));
+        accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2);
     }
 
     sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
@@ -830,7 +831,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
 #endif
     for (; ib < nb; ++ib) {
-        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e);
         int sumi1 = 0;
         int sumi2 = 0;
         for (int j = 0; j < QK_MXFP4/2; ++j) {
@@ -3817,4 +3818,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
     ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
 #endif
 }
-
diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp
index 7dda9eea..33c6cb65 100644
--- a/ggml/src/ggml-cpu/arch/x86/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp
@@ -423,7 +423,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
             quants_interleaved[j] = i0;
         }
 
-        // Masks to shuffle the quants of corresonding sub blocks for rearraning quants for vectorized bsums computation
+        // Masks to shuffle the quants of corresponding sub blocks for rearranging quants for vectorized bsums computation
         __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
         shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
         __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
@@ -522,7 +522,8 @@ template
 static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
     static_assert(
             std::is_same_v ||
-            std::is_same_v,
+            std::is_same_v ||
+            std::is_same_v,
             "Unsupported block type");
 
     const int qk = QK8_0;
@@ -580,6 +581,18 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
+                } else if constexpr (std::is_same_v) {
+                    // Load 8 E8M0 exponents and convert to float via LUT
+                    // Rearranged to match changemask order: 0,4,1,5,2,6,3,7
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Load and convert to FP32 scale from block_q8_0
@@ -612,7 +625,7 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                 iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170));
                 iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255));
 
-                // Accumulated values multipled with appropriate scales
+                // Accumulated values multiplied with appropriate scales
                 acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
             }
 
@@ -628,7 +641,8 @@ template
 static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
     static_assert(
             std::is_same_v ||
-            std::is_same_v,
+            std::is_same_v ||
+            std::is_same_v,
             "Unsupported block type");
 
     const int qk = QK8_0;
@@ -749,6 +763,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                } else if constexpr (std::is_same_v) {
+                    //TODO: simd-ify
+                    col_scale_f32 = _mm512_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
                 }
 
                 // Process LHS in pairs of rows
@@ -835,7 +868,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                     const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
                     const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
 
-                    // Multiply with appropiate scales and accumulate
+                    // Multiply with appropriate scales and accumulate
                     acc_rows[rp * 4]     = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[rp * 4]);
                     acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[rp * 4 + 1]);
                     acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
@@ -941,6 +974,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                } else if constexpr (std::is_same_v) {
+                    //TODO: simd-ify
+                    col_scale_f32 = _mm512_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
                 }
 
                 // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
@@ -1024,7 +1076,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                 const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
                 const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
 
-                // Multiply with appropiate scales and accumulate
+                // Multiply with appropriate scales and accumulate
                 acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[0]);
                 acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[1]);
                 acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
@@ -1123,6 +1175,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                } else if constexpr (std::is_same_v) {
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Process LHS in groups of four
@@ -1195,7 +1257,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                     // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
                     const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
 
-                    // Multiply with appropiate scales and accumulate
+                    // Multiply with appropriate scales and accumulate
                     acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
                     acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
                     acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
@@ -1283,6 +1345,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v ||
                         std::is_same_v) {
                     col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                } else if constexpr (std::is_same_v) {
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
@@ -1356,7 +1428,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                 // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes
                 const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
 
-                // Multiply with appropiate scales and accumulate
+                // Multiply with appropriate scales and accumulate
                 acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
                 acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
                 acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
@@ -1540,7 +1612,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
 
                     // Dot product done within 32 bit lanes and accumulated in the same vector
-                    // First done for first sub block and thenn for second sub block in each sb
+                    // First done for first sub block and then for second sub block in each sb
                     // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
                     // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
                     // ...........................................................................
@@ -1625,6 +1697,19 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__)
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+    gemv_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+    return;
+#endif
+
+    ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -2337,7 +2422,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
                         const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
 
-                        // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                        // Multiply with appropriate scales and accumulate (for both d and dmin) below
                         acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
                         acc_rows[rp * 4  + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
                         acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
@@ -2700,7 +2785,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
                     const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
 
-                    // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                    // Multiply with appropriate scales and accumulate (for both d and dmin) below
                     acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
                     acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
                     acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
@@ -2717,7 +2802,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
                 }
             }
-            // Store accumlated values
+            // Store accumulated values
             for (int i = 0; i < 4; i++) {
                 _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
             }
@@ -3045,7 +3130,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
                         const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
 
-                        // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                        // Multiply with appropriate scales and accumulate (for both d and dmin) below
                         acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
                         acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
                         acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
@@ -3375,7 +3460,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
                     const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
 
-                    // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                    // Multiply with appropriate scales and accumulate (for both d and dmin) below
                     acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
                     acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
                     acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
@@ -3423,6 +3508,21 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__) || defined(__AVX512F__)
+    {
+        __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
+        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+        gemm_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+        return;
+    }
+#endif // defined(__AVX2__) || defined(__AVX512F__)
+
+    ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -4168,7 +4268,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
                         const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1);
 
-                        // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                        // Multiply with appropriate scales and accumulate (for both d and dmin) below
                         acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
                         acc_rows[rp * 4  + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
                         acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
@@ -4935,7 +5035,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]);
                 }
             }
-            // Store accumlated values
+            // Store accumulated values
             for (int i = 0; i < 4; i++) {
                 _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i]));
             }
@@ -5577,7 +5677,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d);
                         const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
 
-                        // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                        // Multiply with appropriate scales and accumulate (for both d and dmin) below
                         acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
                         acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
                         acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
@@ -6249,7 +6349,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d);
                     const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);
 
-                    // Multiply with appropiate scales and accumulate (for both d and dmin) below
+                    // Multiply with appropriate scales and accumulate (for both d and dmin) below
                     acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
                     acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
                     acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
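A note on the scalar E8M0 gather paths added in this file: _mm256_set_ps lists its arguments from the highest lane down to lane 0, which is why e[7] appears first while e[i] still lands in element i. A self-contained sketch (the e/lut parameters are illustrative stand-ins for the repacked block's exponent array and ggml_table_f32_e8m0_half):

    #include <immintrin.h>
    #include <cstdint>

    // Element i of the result equals lut[e[i]]; arguments are written high lane first.
    static inline __m256 load_e8m0_scales_ref(const uint8_t e[8], const float lut[256]) {
        return _mm256_set_ps(lut[e[7]], lut[e[6]], lut[e[5]], lut[e[4]],
                             lut[e[3]], lut[e[2]], lut[e[1]], lut[e[0]]);
    }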
diff --git a/ggml/src/ggml-cpu/binary-ops.cpp b/ggml/src/ggml-cpu/binary-ops.cpp
index 14f5b43a..75e38290 100644
--- a/ggml/src/ggml-cpu/binary-ops.cpp
+++ b/ggml/src/ggml-cpu/binary-ops.cpp
@@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
     GGML_ASSERT(nb00 == sizeof(src0_t));
 
     const auto [ir0, ir1] = get_thread_range(params, src0);
-    const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
-
-    if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
-        GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    }
+    const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
 
 #ifdef GGML_USE_ACCELERATE
     vDSP_fn_t vDSP_op = nullptr;
@@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
         const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
         const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
-        if (is_src1_contiguous) {
+        if (is_src1_contiguous_rows) {
             // src1 is broadcastable across src0 and dst in i1, i2, i3
             const int64_t nr0 = ne00 / ne10;
 
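With the relaxed check, src1 only needs contiguous rows for the broadcast fast path instead of being fully contiguous. A hedged sketch of the per-row tiling that path performs (names are illustrative and float-only; the real kernel is templated over the operand types):

    #include <cstdint>

    // One dst row of n0 elements is built by tiling the n1-element src1 row
    // (nr0 = ne00 / ne10 in the code above), applying op element-wise.
    static void broadcast_row_op_ref(float * dst, const float * a, const float * b,
                                     int64_t n0, int64_t n1, float (*op)(float, float)) {
        for (int64_t r = 0; r < n0 / n1; ++r) {
            for (int64_t i = 0; i < n1; ++i) {
                dst[r * n1 + i] = op(a[r * n1 + i], b[i]);
            }
        }
    }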
diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h
index 6adca543..abbadc35 100644
--- a/ggml/src/ggml-cpu/common.h
+++ b/ggml/src/ggml-cpu/common.h
@@ -6,6 +6,9 @@
 #include "ggml-impl.h"
 #include "simd-mappings.h"
 
+#define GGML_FA_TILE_Q  64
+#define GGML_FA_TILE_KV 64
+
 #ifdef __cplusplus
 
 #include 
@@ -84,4 +87,9 @@ static std::pair get_thread_range(const struct ggml_compute_pa
     return {ir0, ir1};
 }
 
+struct ggml_fa_tile_config {
+    static constexpr size_t Q  = GGML_FA_TILE_Q;
+    static constexpr size_t KV = GGML_FA_TILE_KV;
+};
+
 #endif
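A hedged sketch of how a tiled flash-attention kernel is assumed to consume these constants; it simply mirrors the per-thread prefill scratch formula added to ggml_graph_plan below (DK/DV are the K/V head sizes, and ggml_fa_tile_config from above is assumed to be in scope):

    #include <cstddef>

    // Number of f32 scratch elements one thread needs per Q-tile x KV-tile step:
    // Q tile, KQ + mask tiles, VKQ accumulator, plus V and K tiles converted to f32.
    static size_t fa_tile_scratch_floats_ref(size_t DK, size_t DV) {
        return ggml_fa_tile_config::Q  * DK
             + 2 * ggml_fa_tile_config::Q * ggml_fa_tile_config::KV
             + ggml_fa_tile_config::Q  * DV
             + ggml_fa_tile_config::KV * DV
             + ggml_fa_tile_config::KV * DK;
    }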
diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h
index 0e8dd0ae..88a9c9ec 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -24,6 +24,9 @@ struct ggml_compute_params {
     void * wdata;
 
     struct ggml_threadpool * threadpool;
+
+    // use reference implementation
+    bool use_ref;
 };
 
 
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f7ba1fe3..8b323bd9 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -5,7 +5,6 @@
 #include "ggml-backend.h"
 #include "traits.h"
 #include "ggml-cpu-impl.h"
-#include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "quants.h"
 #include "ggml-threading.h"
@@ -14,6 +13,7 @@
 #include "vec.h"
 #include "ops.h"
 #include "ggml.h"
+#include "common.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -75,6 +75,9 @@
 // precomputed f32 table for f16 (256 KB) (simd-mappings.h)
 float ggml_table_f32_f16[1 << 16];
 
+// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
+float ggml_table_f32_e8m0_half[1 << 8];
+
 #if defined(__ARM_ARCH)
 struct ggml_arm_arch_features_type {
     int sve_cnt;
@@ -267,6 +270,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_0,
         .nrows                    = 1,
     },
+    [GGML_TYPE_NVFP4] = {
+        .from_float               = quantize_row_nvfp4,
+        .vec_dot                  = ggml_vec_dot_nvfp4_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q2_K] = {
         .from_float               = quantize_row_q2_K,
         .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
@@ -2018,6 +2027,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_solve_tri(params, tensor);
             } break;
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                ggml_compute_forward_gated_delta_net(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2197,6 +2210,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_COUNT_EQUAL:
         case GGML_OP_SOLVE_TRI:
+        case GGML_OP_GATED_DELTA_NET:
             {
                 n_tasks = n_threads;
             } break;
@@ -2474,7 +2488,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
 
     if (prio != GGML_SCHED_PRIO_LOW) {
         // Tell Windows that this thread should not be throttled (needs its own CPU core).
-        // Newer Windows 11 versions aggresively park (offline) CPU cores and often place
+        // Newer Windows 11 versions aggressively park (offline) CPU cores and often place
         // all our threads onto the first 4 cores which results in terrible performance with
         // n_threads > 4
         #if _WIN32_WINNT >= 0x0602
@@ -2866,10 +2880,20 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
-                        const int64_t ne10 = node->src[1]->ne[0]; // DK
-                        const int64_t ne20 = node->src[2]->ne[0]; // DV
+                        const int64_t neq2 = node->src[0]->ne[2]; // number of query heads
+                        const int64_t DK = node->src[1]->ne[0];
+                        const int64_t DV = node->src[2]->ne[0];
 
-                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
+                        // Tiled flash attention scratch (tile sizes defined in common.h)
+                        // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
+                        size_t prefill  = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
+
+                        // Decode path: n_kv_chunks = n_tasks (one chunk per thread)
+                        // Per-thread: VKQ accumulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
+                        size_t n_chunks = n_tasks;
+                        size_t decode   = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));
+
+                        cur += MAX(prefill, decode);
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {
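As a worked example of the prefill term (assuming head sizes DK = DV = 128 and the 64x64 tiles from common.h): 64*128 + 2*64*64 + 64*128 + 64*128 + 64*128 = 40960 floats, i.e. 160 KiB of scratch per task, so a 16-thread plan reserves roughly 2.5 MiB for this path.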
@@ -2892,6 +2916,11 @@ struct ggml_cplan ggml_graph_plan(
                     {
                         cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                     } break;
+                case GGML_OP_GATED_DELTA_NET:
+                    {
+                        const int64_t S_v = node->src[2]->ne[0];
+                        cur = S_v * sizeof(float) * n_tasks;
+                    } break;
                 case GGML_OP_COUNT:
                     {
                         GGML_ABORT("fatal error");
@@ -2926,14 +2955,19 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     set_numa_thread_affinity(state->ith);
 
     struct ggml_compute_params params = {
-        /*.ith       =*/ state->ith,
-        /*.nth       =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
-        /*.wsize     =*/ cplan->work_size,
-        /*.wdata     =*/ cplan->work_data,
-        /*.threadpool=*/ tp,
+        /*.ith        =*/ state->ith,
+        /*.nth        =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
+        /*.wsize      =*/ cplan->work_size,
+        /*.wdata      =*/ cplan->work_data,
+        /*.threadpool =*/ tp,
+        /*.use_ref    =*/ cplan->use_ref,
     };
 
-    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
+#ifdef GGML_USE_OPENMP
+    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
+#else
+    GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
+#endif
 
     for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
         struct ggml_tensor * node = cgraph->nodes[node_n];
@@ -2943,6 +2977,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         ggml_compute_forward(¶ms, node);
 
         if (state->ith == 0 && cplan->abort_callback &&
@@ -2956,7 +2994,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         }
     }
 
-    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
+#ifdef GGML_USE_OPENMP
+    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
+#else
+    GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
+#endif
 
     ggml_barrier(state->threadpool);
 
@@ -3666,6 +3708,11 @@ void ggml_cpu_init(void) {
                 ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f));
             }
 
+            // initialize E8M0 half table (256 entries)
+            for (int i = 0; i < (1 << 8); ++i) {
+                ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
+            }
+
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
             GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp
index f4713a42..ddf1737a 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -105,6 +105,8 @@ struct ggml_backend_cpu_context {
 
     ggml_abort_callback abort_callback;
     void *              abort_callback_data;
+
+    bool                use_ref;  // use reference implementation
 };
 
 static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
@@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend
 
     cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
     cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+    cpu_plan->cplan.use_ref             = cpu_ctx->use_ref;
 
     return cpu_plan;
 }
@@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
 
     cplan.abort_callback      = cpu_ctx->abort_callback;
     cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+    cplan.use_ref             = cpu_ctx->use_ref;
 
     return ggml_graph_compute(cgraph, &cplan);
 }
@@ -223,6 +227,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     ctx->work_size           = 0;
     ctx->abort_callback      = NULL;
     ctx->abort_callback_data = NULL;
+    ctx->use_ref             = false;
 
     ggml_backend_t cpu_backend = new ggml_backend {
         /* .guid    = */ ggml_backend_cpu_guid(),
@@ -270,6 +275,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->use_ref = use_ref;
+}
+
 // CPU backend - device
 
 struct ggml_backend_cpu_device_context {
@@ -646,6 +658,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch
     if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) {
         return (void *)ggml_is_numa;
     }
+    if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) {
+        return (void *)ggml_backend_cpu_set_use_ref;
+    }
 
     // threadpool - TODO:  move to ggml-base
     if (strcmp(name, "ggml_threadpool_new") == 0) {
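A hedged usage sketch for the new toggle from outside the CPU backend: it is resolved through the registry entry added above. The function-pointer typedef name is illustrative; the signature matches ggml_backend_cpu_set_use_ref.

    #include "ggml-backend.h"

    typedef void (*set_use_ref_t)(ggml_backend_t backend_cpu, bool use_ref);

    static void enable_cpu_reference_kernels(ggml_backend_t cpu_backend, ggml_backend_reg_t cpu_reg) {
        auto set_use_ref = (set_use_ref_t)
            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_cpu_set_use_ref");
        if (set_use_ref) {
            set_use_ref(cpu_backend, /*use_ref=*/true); // route supported ops through the reference paths
        }
    }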
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index d114f2d4..8c4d7bc9 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates 
+// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates 
 // SPDX-License-Identifier: MIT
 //
 
@@ -9,7 +9,6 @@
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
-#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
 #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
@@ -20,6 +19,7 @@
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
+#include "kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.h"
 
 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
@@ -31,6 +31,7 @@
 #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
 #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
 #include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
+#include "kai_lhs_pack_f16pmrx2_f32_neon.h"
 
 #include "kai_common.h"
 
@@ -309,24 +310,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
     {
         /* SME GEMM */
         /* .kern_info = */ {
-            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3,
-            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3,
-            /* .run_kernel_ex         = */ &kernel_run_fn11,
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3,
+            /* .run_kernel_ex         = */ &kernel_run_fn11,
         },
 
         /* .gemm_lhs_info = */ {
-            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
-            /* .get_packed_offset_ex  = */ &lhs_offs_fn6,
-            /* .packed_size_ex        = */ &lhs_ps_fn6,
-            /* .pack_func_ex          = */ &lhs_pack_float_fn10,
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_f16pmrx2_f32_neon,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6,
+            /* .packed_size_ex        = */ &lhs_ps_fn6,
+            /* .pack_func_ex          = */ &lhs_pack_void_fn10,
         },
         /* SME GEMV */
         /* .kern_info = */ {
@@ -519,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4,
             /* .pack_func_ex          = */ &rhs_pack_fn12,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q4_0,
         /* .op_type            = */ GGML_TYPE_F32,
@@ -630,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4,
             /* .pack_func_ex          = */ &rhs_pack_fn12,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q4_0,
         /* .op_type            = */ GGML_TYPE_F32,
@@ -800,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4,
             /* .pack_func_ex          = */ &rhs_pack_scale_fn12,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q8_0,
         /* .op_type            = */ GGML_TYPE_F32,
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index ad23e731..9bcc18d4 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -1,20 +1,31 @@
-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates 
+// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates 
 // SPDX-License-Identifier: MIT
 //
 #include 
 #include 
+#include 
 #include 
 #include 
-#include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
 #if defined(__linux__)
 #include 
 #include 
+#include 
+#include 
+#include 
 #elif defined(__APPLE__)
 #include 
 #include 
@@ -39,11 +50,18 @@
 #define GGML_COMMON_DECL_CPP
 #include "ggml-common.h"
 
+static constexpr int      GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
+static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC       = 0x4b4c4149; // "KLAI"
+static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION     = 1;
+static constexpr size_t   GGML_KLEIDIAI_PACK_ALIGN       = 64;
+
 struct ggml_kleidiai_context {
     cpu_feature features;
     ggml_kleidiai_kernels * kernels_q4;
     ggml_kleidiai_kernels * kernels_q8;
-} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
+    int sme_thread_cap; // <= 0 means "SME disabled/unknown"
+    int thread_hint;    // <= 0 means "no hint"
+} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
 
 static const char* cpu_feature_to_string(cpu_feature f) {
     if (f == CPU_FEATURE_NONE) {
@@ -63,41 +81,335 @@ static const char* cpu_feature_to_string(cpu_feature f) {
     }
 }
 
-static void init_kleidiai_context(void) {
+static size_t detect_num_smcus() {
+    if (!ggml_cpu_has_sme()) {
+        return 0;
+    }
 
+#if defined(__linux__) && defined(__aarch64__)
+    // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
+    size_t num_private = 0;
+    std::set<uint32_t> shared_ids;
+
+    for (size_t cpu = 0;; ++cpu) {
+        const std::string path =
+            "/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
+            "/regs/identification/smidr_el1";
+
+        std::ifstream file(path);
+        if (!file.is_open()) {
+            break;
+        }
+
+        uint64_t smidr = 0;
+        if (!(file >> std::hex >> smidr)) {
+            continue;
+        }
+
+        // Arm ARM: SMIDR_EL1
+        const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
+        // Build an "affinity-like" identifier for shared SMCUs.
+        // Keep the original packing logic, but isolate it here.
+        const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
+
+        switch (sh) {
+            case 0b10: // private SMCU
+                ++num_private;
+                break;
+            case 0b11: // shared SMCU
+                shared_ids.emplace(id);
+                break;
+            case 0b00:
+                // Ambiguous / implementation-defined. Be conservative:
+                // treat id==0 as private, otherwise as shared.
+                if (id == 0) ++num_private;
+                else shared_ids.emplace(id);
+                break;
+            default:
+                break;
+        }
+    }
+
+    return num_private + shared_ids.size();
+
+#elif defined(__APPLE__) && defined(__aarch64__)
+    // Table of known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
+    char chip_name[256] = {};
+    size_t size = sizeof(chip_name);
+
+    if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
+        const std::string brand(chip_name);
+
+        struct ModelSMCU { const char *match; size_t smcus; };
+        static const ModelSMCU table[] = {
+            { "M4 Ultra", 2 },
+            { "M4 Max",   2 },
+            { "M4 Pro",   2 },
+            { "M4",       1 },
+        };
+
+        for (const auto &e : table) {
+            if (brand.find(e.match) != std::string::npos) {
+                return e.smcus;
+            }
+        }
+    }
+    return 1;
+
+#else
+    return 1;
+#endif
+}
+
+static int parse_uint_env(const char *s, const char *name, bool *ok) {
+    if (!s) { *ok = false; return 0; }
+    char *end = nullptr;
+    long v = strtol(s, &end, 10);
+    if (end == s || *end != '\0') {
+        GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
+        *ok = false;
+        return 0;
+    }
+    if (v < 0 || v > INT_MAX) {
+        GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
+        *ok = false;
+        return 0;
+    }
+    *ok = true;
+    return (int)v;
+}
+
+static void init_kleidiai_context(void) {
     ggml_critical_section_start();
     static bool initialized = false;
 
     if (!initialized) {
         initialized = true;
-        const char *env_var = getenv("GGML_KLEIDIAI_SME");
-        int sme_enabled = 0;
+
+        const char *env_sme     = getenv("GGML_KLEIDIAI_SME");
+        const char *env_threads = getenv("GGML_TOTAL_THREADS");
+
+        const bool cpu_has_sme = ggml_cpu_has_sme();
+        size_t detected_smcus = 0;
 
         ctx.features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                         (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
                         ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
 
-        if (env_var) {
-            sme_enabled = atoi(env_var);
+        if (env_threads) {
+            bool ok = false;
+            int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
+            if (ok && hint > 0) {
+                ctx.thread_hint = hint;
+            }
         }
 
-        if (sme_enabled != 0) {
-            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+        // SME policy:
+        // - If CPU doesn't support SME: SME always off.
+        // - Else:
+        //   - env unset => auto-detect cores; enable if detected > 0.
+        //   - env=0     => force off.
+        //   - env>0     => force N cores (skip detection).
+        int sme_cores = 0;
+        bool sme_env_ok = false;
+        bool sme_env_set = (env_sme != nullptr);
+
+        if (!cpu_has_sme) {
+            if (sme_env_set) {
+                bool ok = false;
+                int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
+                if (ok && req > 0) {
+                    GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
+                }
+            }
+            sme_cores = 0;
+        } else {
+            if (sme_env_set) {
+                bool ok = false;
+                int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
+                sme_env_ok = ok;
+
+                if (!ok) {
+                    GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
+                    detected_smcus = detect_num_smcus();
+                    sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
+                } else if (v == 0) {
+                    sme_cores = 0;
+                } else {
+                    sme_cores = v;
+                }
+            } else {
+                detected_smcus = detect_num_smcus();
+                sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
+            }
+
+            if (!sme_env_set && sme_cores == 0) {
+                GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
+            }
+
+            if (sme_cores > 0) {
+                ctx.features |= CPU_FEATURE_SME;
+            }
         }
+
+        // Kernel selection
         ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
         ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
-#ifndef NDEBUG
-        if (ctx.kernels_q4) {
-            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+
+        if (!ctx.kernels_q4) {
+            GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
+        } else {
+            GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
         }
-        if (ctx.kernels_q8) {
-            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
+
+        if (!ctx.kernels_q8) {
+            GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
+        } else {
+            GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
+        }
+
+        ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
+
+        if (ctx.features & CPU_FEATURE_SME) {
+            if (sme_env_set && sme_env_ok && sme_cores > 0) {
+                GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
+            } else {
+                GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
+            }
+        } else {
+            GGML_LOG_INFO("kleidiai: SME disabled\n");
         }
-#endif
     }
+
     ggml_critical_section_end();
 }
 
+static inline int kleidiai_sme_thread_cap() {
+    return ctx.sme_thread_cap;
+}
+
+static inline size_t align_up(size_t value, size_t alignment) {
+    if (alignment == 0) {
+        return value;
+    }
+    const size_t remainder = value % alignment;
+    return remainder == 0 ? value : value + (alignment - remainder);
+}
+
+static inline bool kleidiai_pack_fallback_allowed() {
+    if (ctx.sme_thread_cap <= 0) {
+        return false;
+    }
+    if (ctx.thread_hint <= 0) {
+        return true;
+    }
+    return ctx.thread_hint > ctx.sme_thread_cap;
+}
+
+struct kleidiai_weight_header {
+    uint32_t magic;
+    uint16_t version;
+    uint16_t slot_count;
+    uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+    uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+};
+
+static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
+    return reinterpret_cast<kleidiai_weight_header *>(data);
+}
+
+static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
+    return reinterpret_cast<const kleidiai_weight_header *>(data);
+}
+
+static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
+    if (!header) {
+        return false;
+    }
+    if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
+        return false;
+    }
+    if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
+        return false;
+    }
+    return true;
+}
+
+static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
+    if (!kleidiai_is_weight_header_valid(header)) {
+        return nullptr;
+    }
+    if (slot < 0 || slot >= header->slot_count) {
+        return nullptr;
+    }
+    return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
+}
+
+static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
+    if (!kleidiai_is_weight_header_valid(header)) {
+        return nullptr;
+    }
+    if (slot < 0 || slot >= header->slot_count) {
+        return nullptr;
+    }
+    return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
+}
+
+static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
+    return ctx.kernels_q4;
+}
+
+static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
+    return ctx.kernels_q8;
+}
+
+template <typename SelectFallback>
+static int kleidiai_collect_kernel_chain_common(
+        ggml_kleidiai_kernels * primary,
+        cpu_feature features,
+        std::array & out,
+        SelectFallback select_fallback) {
+    int count = 0;
+    if (!primary) {
+        return 0;
+    }
+    out[count++] = primary;
+
+    if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+        const cpu_feature fallback_mask = static_cast(features & ~CPU_FEATURE_SME);
+        if (fallback_mask != CPU_FEATURE_NONE) {
+            ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
+            if (fallback && fallback != primary &&
+                fallback->lhs_type == primary->lhs_type &&
+                fallback->rhs_type == primary->rhs_type &&
+                fallback->op_type  == primary->op_type) {
+                out[count++] = fallback;
+            }
+        }
+    }
+
+    return count;
+}
+
+static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
+        std::array & out) {
+    ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
+}
+
+static int kleidiai_collect_q4_chain(std::array & out) {
+    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
+}
+
+static int kleidiai_collect_q8_chain(std::array & out) {
+    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
+}
+
 static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
     return tensor->ne[dim];
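To make the packed-weight framing concrete, a hedged sketch of how a two-slot blob is assumed to be laid out: header first, each slot aligned to GGML_KLEIDIAI_PACK_ALIGN, offsets relative to the header start as kleidiai_weight_slot_ptr expects. Whether slot 0 begins exactly at the aligned header size is an assumption.

    // Fills in the header and returns the total byte size for header + two slots.
    // Slot 0 = primary (e.g. SME) packing, slot 1 = fallback (e.g. i8mm/dotprod) packing.
    static size_t kleidiai_layout_two_slots_ref(size_t slot0_size, size_t slot1_size,
                                                kleidiai_weight_header & hdr) {
        hdr.magic      = GGML_KLEIDIAI_PACK_MAGIC;
        hdr.version    = GGML_KLEIDIAI_PACK_VERSION;
        hdr.slot_count = 2;

        size_t cursor  = align_up(sizeof(kleidiai_weight_header), GGML_KLEIDIAI_PACK_ALIGN);
        hdr.offsets[0] = cursor;  hdr.sizes[0] = slot0_size;
        cursor         = align_up(cursor + slot0_size, GGML_KLEIDIAI_PACK_ALIGN);
        hdr.offsets[1] = cursor;  hdr.sizes[1] = slot1_size;

        return cursor + slot1_size;
    }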
@@ -126,49 +438,108 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (op->op != GGML_OP_MUL_MAT) {
             return false;
         }
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
-        if (!kernels) {
+
+        std::array kernel_chain;
+        const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
+        if (slot_count == 0) {
             return false;
         }
-        bool is_gemv = op->src[1]->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
-        size_t k = op->src[0]->ne[0];
-        size_t n = op->src[0]->ne[1];
-        size_t m = op->src[1]->ne[1];
+        const bool is_gemv = op->src[1]->ne[1] == 1;
+        const size_t k = op->src[0]->ne[0];
+        const size_t n = op->src[0]->ne[1];
+        const size_t m = op->src[1]->ne[1];
 
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
+        if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
+            const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
 
-        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
-            if (!lhs_info->packed_size_ex) return false;
-            size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
-        } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
-            if (!lhs_info->packed_size_ex) return false;
-            size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
-        } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
+            size_t cursor = 0;
+            bool any_slot = false;
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+                kernel_info * kernel        = is_gemv ? &kernels->gemv : &kernels->gemm;
+
+                if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
+                    return false;
+                }
+
+                const size_t mr = kernel->get_mr();
+                const size_t kr = kernel->get_kr();
+                const size_t sr = kernel->get_sr();
+
+                const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
+
+                cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += packed;
+                any_slot = true;
+            }
+
+            if (!any_slot) {
+                return false;
+            }
+
+            size = cursor;
+            return true;
+        }
+
+        if (op->src[0]->type == GGML_TYPE_F16) {
             const int64_t lhs_batch_size0 = op->src[1]->ne[2];
             const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+            GGML_ASSERT(rhs_batch_size0 > 0);
             const int64_t r = lhs_batch_size0 / rhs_batch_size0;
-            size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
-                   kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
-                   k * n * sizeof(float) + n * sizeof(float);
-        } else {
-            return false;
+
+            size_t cursor = 0;
+            bool any_slot = false;
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+                kernel_info * kernel        = is_gemv ? &kernels->gemv : &kernels->gemm;
+                if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
+                    return false;
+                }
+
+                const size_t mr = kernel->get_mr();
+                const size_t kr = kernel->get_kr();
+                const size_t sr = kernel->get_sr();
+
+                cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
+                any_slot = true;
+            }
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+                if (!kernel || !kernels->rhs_info.packed_size_ex) {
+                    return false;
+                }
+                cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
+            }
+
+            cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            cursor += k * n * sizeof(float);
+            cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            cursor += n * sizeof(float);
+
+            if (!any_slot) {
+                return false;
+            }
+
+            size = cursor;
+            return true;
         }
 
-        return true;
+        return false;
     }
 
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
-            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
-                return compute_forward_q4_0(params, dst);
-            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
-                return compute_forward_q8_0(params, dst);
+            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
+                return compute_forward_qx(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_fp16(params, dst);
             }
@@ -331,204 +702,457 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
-    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+    bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
 
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        if (!kernels) {
-            return false;
+        const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
+        const bool has_header = kleidiai_is_weight_header_valid(header);
+        const bool is_gemv = src1->ne[1] == 1;
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
+
+        auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
+            if (slot_index < 0 || slot_index >= slot_total) {
+                return nullptr;
+            }
+            if (has_header) {
+                if (slot_index < header->slot_count) {
+                    size_out = static_cast<size_t>(header->sizes[slot_index]);
+                    return kleidiai_weight_slot_ptr(header, slot_index);
+                }
+                return nullptr;
+            }
+            if (slot_index == 0) {
+                size_out = ggml_nbytes(src0);
+                return static_cast<const uint8_t *>(src0->data);
+            }
+            return nullptr;
+        };
+
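+        // Per-slot execution state: kernel/packing entry points plus the thread range
+        // [thread_begin, thread_end) and the output-column range (n_offset, n_cols) assigned
+        // to the slot.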
+        struct runtime_slot {
+            int slot_index;
+            ggml_kleidiai_kernels * kernels;
+            kernel_info * kernel;
+            lhs_packing_info * lhs_info;
+            size_t mr;
+            size_t nr;
+            size_t kr;
+            size_t sr;
+            size_t n_step;
+            size_t lhs_packed_size;
+            size_t lhs_offset;
+            size_t n_offset;
+            size_t n_cols;
+            int assigned_threads;
+            int thread_begin;
+            int thread_end;
+            const uint8_t * rhs_base;
+        };
+
+        std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
+        int runtime_count = 0;
+
+        for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
+            ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+            kernel_info * kinfo      = is_gemv ? &kernels->gemv : &kernels->gemm;
+            lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
+                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
+                continue;
+            }
+
+            size_t rhs_size = 0;
+            const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
+            if (!rhs_ptr || rhs_size == 0) {
+                continue;
+            }
+
+            runtime[runtime_count] = {
+                slot,
+                kernels,
+                kinfo,
+                linfo,
+                kinfo->get_mr(),
+                kinfo->get_nr(),
+                kinfo->get_kr(),
+                kinfo->get_sr(),
+                kinfo->get_n_step(),
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                rhs_ptr
+            };
+            ++runtime_count;
         }
 
-        bool is_gemv = src1->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
-
-        GGML_ASSERT(kernel);
-        if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
-            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
-            return false;
+        if (runtime_count == 0) {
+            ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
+            if (!fallback) {
+                return false;
+            }
+            kernel_info * kinfo      = is_gemv ? &fallback->gemv : &fallback->gemm;
+            lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
+            rhs_packing_info * rinfo = &fallback->rhs_info;
+            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
+                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
+                !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
+                return false;
+            }
+            kernel_chain[0] = fallback;
+            runtime[0] = {
+                0,
+                fallback,
+                kinfo,
+                linfo,
+                kinfo->get_mr(),
+                kinfo->get_nr(),
+                kinfo->get_kr(),
+                kinfo->get_sr(),
+                kinfo->get_n_step(),
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                nullptr
+            };
+            size_t rhs_size_fallback = 0;
+            const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
+            if (!rhs_base) {
+                rhs_base = static_cast<const uint8_t *>(src0->data);
+            }
+            runtime[0].rhs_base = rhs_base;
+            runtime_count = 1;
         }
 
-        const int ith = params->ith;
-        const int nth_raw = params->nth;
-        const int nth = nth_raw > 0 ? nth_raw : 1;
+        const int nth_total = params->nth > 0 ? params->nth : 1;
+        const int ith_total = params->ith;
+
+        int sme_slot = -1;
+        for (int i = 0; i < runtime_count; ++i) {
+            if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+                sme_slot = i;
+                break;
+            }
+        }
+
+        const int sme_cap_limit = ctx.sme_thread_cap;
+        const bool use_hybrid = sme_cap_limit > 0 &&
+                                 runtime_count > 1 &&
+                                 nth_total > sme_cap_limit;
+        // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
+        // If rows are small or average columns per thread are small, keep single-slot.
+        size_t min_cols_per_thread = 0;
+        if (runtime_count > 0 && nth_total > 0) {
+            min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
+        }
+        const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
+
+        const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
+
+        if (!hybrid_enabled) {
+            int chosen_slot = 0;
+            if (too_small_for_hybrid && sme_slot != -1) {
+                chosen_slot = sme_slot;
+            } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
+                chosen_slot = 1;
+            }
+            if (chosen_slot != 0 && chosen_slot < runtime_count) {
+                runtime[0] = runtime[chosen_slot];
+            }
+            runtime_count = runtime_count > 0 ? 1 : 0;
+
+            // Recompute SME slot based on the collapsed runtime[0]
+            sme_slot = -1;
+            if (runtime_count > 0 &&
+                (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+                sme_slot = 0;
+            }
+        }
+
+        int sme_cap = kleidiai_sme_thread_cap();
+        if (sme_cap < 0) {
+            sme_cap = nth_total;
+        }
+        sme_cap = std::min(sme_cap, nth_total);
+
+        int threads_remaining = nth_total;
+        if (sme_slot != -1) {
+            int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
+            runtime[sme_slot].assigned_threads = sme_threads;
+            threads_remaining -= sme_threads;
+        }
+
+        int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+        int fallback_count = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            if (i == sme_slot) {
+                continue;
+            }
+            fallback_indices[fallback_count++] = i;
+        }
+
+        for (int fi = 0; fi < fallback_count; ++fi) {
+            if (threads_remaining <= 0) {
+                break;
+            }
+            const int slot_index = fallback_indices[fi];
+            const int slots_left = fallback_count - fi;
+            int share = (threads_remaining + slots_left - 1) / slots_left;
+            share     = std::min(share, threads_remaining);
+            runtime[slot_index].assigned_threads = share;
+            threads_remaining -= share;
+        }
+
+        if (threads_remaining > 0) {
+            const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
+            runtime[fallback_slot].assigned_threads += threads_remaining;
+            threads_remaining = 0;
+        }
+
+        int thread_cursor = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            runtime[i].thread_begin = thread_cursor;
+            thread_cursor += runtime[i].assigned_threads;
+            runtime[i].thread_end = thread_cursor;
+        }
+
+        if (thread_cursor < nth_total && runtime_count > 0) {
+            runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
+            runtime[runtime_count - 1].thread_end = nth_total;
+        }
+
+        int local_slot = -1;
+        int local_ith  = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
+                local_slot = i;
+                local_ith  = ith_total - runtime[i].thread_begin;
+                break;
+            }
+        }
+        if (local_slot == -1) {
+            return false;
+        }
 
         const size_t k = ne00;
         const size_t m = ne11;
         const size_t n = ne01;
 
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
+        size_t cursor = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
+            const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                              slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
+            runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
+            cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            runtime[i].lhs_offset = cursor;
+            cursor += runtime[i].lhs_packed_size;
+        }
 
-        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
-        uint8_t * lhs_packed       = (uint8_t*)params->wdata;
-        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+        GGML_ASSERT(cursor <= params->wsize);
+        uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
 
-        const size_t n_step = kernel->get_n_step();
-        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-        const size_t n_start = ith * num_n_per_thread;
-
-        size_t n_to_process = 0;
-        if (n_start < n) {
-            n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
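+        // Column split: when an SME slot and a fallback are both active, the SME slot is weighted
+        // above its raw thread share. Illustrative numbers (assuming sme_thread_cap = 2 and
+        // nth_total = 8): weighted_total = 2*4 + 6*1 = 14, so the SME slot is targeted at roughly
+        // 8/14 of the output columns and the fallback at the remainder.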
+        size_t assigned_cols = 0;
+        uint64_t weighted_total = 0;
+        if (runtime_count > 1 && sme_slot != -1) {
+            for (int i = 0; i < runtime_count; ++i) {
+                const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
+                weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
             }
         }
-
-        // Calculate number of columns to be processed per thread
-        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
-        const size_t m_start = ith * num_m_per_thread;
-        size_t m_to_process = num_m_per_thread;
-        if ((m_start + m_to_process) > m) {
-            m_to_process = m - m_start;
+        for (int i = 0; i < runtime_count; ++i) {
+            runtime[i].n_offset = assigned_cols;
+            if (runtime[i].assigned_threads == 0) {
+                runtime[i].n_cols = 0;
+                continue;
+            }
+            const size_t remaining_cols = n - assigned_cols;
+            if (remaining_cols == 0) {
+                runtime[i].n_cols = 0;
+                continue;
+            }
+            const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
+            size_t target      = 0;
+            if (weighted_total > 0) {
+                const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
+                target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
+            } else {
+                target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
+            }
+            target             = std::min(target, remaining_cols);
+            size_t aligned     = round_down(target, step);
+            if (aligned == 0 && remaining_cols >= step) {
+                aligned = step;
+            }
+            runtime[i].n_cols = aligned;
+            assigned_cols += aligned;
         }
 
-        if (m_start < m) {
-            // Transform LHS
-            const size_t src_stride        = src1->nb[1];
-            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
-            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
-
-            // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
-            lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
-        }
-
-        ggml_barrier(params->threadpool);
-
-        // Perform the operation
-        const size_t dst_stride        = dst->nb[1];
-        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
-        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
-        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
-        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-        const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
-        float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
-
-        if (n_to_process > 0) {
-            kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                               sizeof(float), -FLT_MAX, FLT_MAX);
-        }
-
-        return true;
-    }
-
-    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
-
-        const ggml_tensor * src0 = dst->src[0];
-        const ggml_tensor * src1 = dst->src[1];
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        if (!kernels) {
-            return false;
-        }
-
-        bool is_gemv = src1->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
-
-        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
-            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
-            return false;
-        }
-
-        const int ith = params->ith;
-        const int nth_raw = params->nth;
-        const int nth = nth_raw > 0 ? nth_raw : 1;
-
-        const size_t k = ne00;
-        const size_t m = ne11;
-        const size_t n = ne01;
-
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
-
-        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
-        uint8_t * lhs_packed       = static_cast<uint8_t *>(params->wdata);
-        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
-
-        const size_t n_step = kernel->get_n_step();
-        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-        const size_t n_start = ith * num_n_per_thread;
-
-        size_t n_to_process = 0;
-        if (n_start < n) {
-            n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
+        if (assigned_cols < n) {
+            for (int i = runtime_count - 1; i >= 0; --i) {
+                if (runtime[i].assigned_threads > 0) {
+                    runtime[i].n_cols += n - assigned_cols;
+                    break;
+                }
             }
         }
+        const size_t dst_stride = dst->nb[1];
 
-        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
-        const size_t m_start = ith * num_m_per_thread;
-        size_t m_to_process = num_m_per_thread;
-        if ((m_start + m_to_process) > m) {
-            m_to_process = m - m_start;
-        }
+        for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
+            const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
 
-        if (m_start < m) {
-            const size_t src_stride        = src1->nb[1];
-            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
-            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
+            if (runtime[local_slot].assigned_threads > 0) {
+                runtime_slot & slot = runtime[local_slot];
+                const ggml_type slot_rhs_type = slot.kernels->rhs_type;
+                const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                 slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
+                int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
+                max_threads = std::max<int64_t>(1, max_threads);
+                const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
 
-            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
-        }
+                if (local_ith < use_threads) {
+                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
+                    const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
 
-        ggml_barrier(params->threadpool);
+                    const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
+                    const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
 
-        const size_t dst_stride        = dst->nb[1];
-        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
-        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
-        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
-        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-        const void * lhs_ptr           = static_cast<const void *>(lhs_packed + lhs_packed_offset);
-        float * dst_ptr                = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+                    const size_t base_packed_off  = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                    const size_t next_block_off   = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                    const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
 
-        if (n_to_process > 0) {
-            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                                  sizeof(float), -FLT_MAX, FLT_MAX);
+                    int64_t remaining = m_count;
+                    int64_t cur       = m_start;
+
+                    uint8_t * lhs_packed = scratch + slot.lhs_offset;
+                    while (remaining > 0) {
+                        const int64_t row_in_group = cur;
+                        const int64_t avail        = (int64_t)m - row_in_group;
+                        const int64_t take         = std::min(avail, remaining);
+
+                        const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
+                        const void * src_ptr = lhs_batch_base + src_off;
+                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+                        void * dst_ptr       = lhs_packed + dst_off;
+
+                        slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
+
+                        cur       += take;
+                        remaining -= take;
+                    }
+                }
+            }
+
+            ggml_barrier(params->threadpool);
+
+            runtime_slot & slot = runtime[local_slot];
+            if (slot.n_cols > 0 && slot.assigned_threads > 0) {
+                int64_t active_threads = slot.assigned_threads;
+                const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
+                if (max_threads > 0) {
+                    active_threads = std::min(active_threads, std::max<int64_t>(1, max_threads));
+                }
+                active_threads = std::max<int64_t>(1, active_threads);
+
+                if (local_ith < active_threads) {
+                    const size_t step = slot.n_step ? slot.n_step : 1;
+                    const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
+                    const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
+                    const size_t local_start = (size_t)local_ith * chunk0;
+                    const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
+
+                    if (cols > 0) {
+                        const ggml_type slot_rhs_type = slot.kernels->rhs_type;
+                        const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                         slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                        const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                          slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                        const size_t global_start = slot.n_offset + local_start;
+                        const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                        const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
+                        const size_t dst_offset        = slot.kernel->get_dst_offset(0, global_start, dst_stride);
+
+                        const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
+                        const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
+                        float * dst_ptr         = reinterpret_cast<float *>(dst_batch_base + dst_offset);
+
+                        slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
+                                                   lhs_ptr,
+                                                   rhs_ptr,
+                                                   dst_ptr,
+                                                   dst_stride,
+                                                   sizeof(float),
+                                                   -FLT_MAX,
+                                                   FLT_MAX);
+                    }
+                }
+            }
+
+            if (batch_idx != ne12 - 1) {
+                ggml_barrier(params->threadpool);
+            }
         }
 
         return true;
     }
 
     bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-        ggml_kleidiai_kernels * kernels = nullptr;
-        size_t block_len = 0;
-        size_t num_bytes_multiplier = 0;
+        const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
+        const bool has_header = kleidiai_is_weight_header_valid(header);
 
-        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
-            if (!ctx.kernels_q4) {
-                return false;
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
+        const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                        : kleidiai_collect_q4_chain(kernel_chain);
+
+        ggml_kleidiai_kernels * kernels = nullptr;
+    const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
+
+        if (has_header && chain_count > 0) {
+            int select_slot = 0;
+            if (select_slot >= header->slot_count) {
+                select_slot = header->slot_count - 1;
             }
-            kernels = ctx.kernels_q4;
-            block_len = QK4_0;
-            num_bytes_multiplier = sizeof(uint16_t);
-        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
-            if (!ctx.kernels_q8) {
-                return false;
+            if (select_slot >= 0 && select_slot < chain_count) {
+                kernels = kernel_chain[select_slot];
+                const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
+                if (slot_ptr) {
+                    packed_base = slot_ptr;
+                }
             }
-            kernels = ctx.kernels_q8;
-            block_len = QK8_0;
-            num_bytes_multiplier = sizeof(float);
-        } else {
+        }
+
+        if (!kernels && chain_count > 0) {
+            kernels = kernel_chain[0];
+            if (has_header) {
+                const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
+                if (slot_ptr) {
+                    packed_base = slot_ptr;
+                }
+            }
+        }
+
+        if (!kernels) {
             return false;
         }
 
@@ -541,6 +1165,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t nc     = ne00;
         const int64_t nr     = ggml_nelements(src1);
 
+        const ggml_type rhs_type = kernels->rhs_type;
+        size_t block_len = 0;
+        size_t num_bytes_multiplier = 0;
+        if (rhs_type == GGML_TYPE_Q4_0) {
+            block_len = QK4_0;
+            num_bytes_multiplier = sizeof(uint16_t);
+        } else if (rhs_type == GGML_TYPE_Q8_0) {
+            block_len = QK8_0;
+            num_bytes_multiplier = sizeof(float);
+        } else {
+            return false;
+        }
+
         const size_t block_rows = kernel->get_nr();
         const size_t kr         = kernel->get_kr();
 
@@ -559,7 +1196,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
 
             float *out = (float *)((char *)dst->data + i * nb1);
-            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
+            rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
         }
 
         return true;
@@ -567,36 +1204,39 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
 
-        if (tensor->type == GGML_TYPE_Q4_0) {
-            if (!ctx.kernels_q4) {
-                return -1;
-            }
-            size_t nr = ctx.kernels_q4->gemm.get_nr();
-            size_t kr = ctx.kernels_q4->gemm.get_kr();
-            size_t sr = ctx.kernels_q4->gemm.get_sr();
+        kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
+        if (!header) {
+            return -1;
+        }
 
-            struct kai_rhs_pack_qs4cxs1s0_param params;
-            params.lhs_zero_point = 1;
-            params.rhs_zero_point = 8;
-            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
-                                                  static_cast<const uint8_t *>(data),
-                                                  nullptr, nullptr, tensor->data, 0, &params);
-            GGML_UNUSED(data_size);
-            return 0;
-        } else if (tensor->type == GGML_TYPE_Q8_0) {
-            if (!ctx.kernels_q8) {
-                return -1;
-            }
+        header->magic      = GGML_KLEIDIAI_PACK_MAGIC;
+        header->version    = GGML_KLEIDIAI_PACK_VERSION;
+        header->slot_count = 0;
+
+        uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
+        size_t cursor = sizeof(kleidiai_weight_header);
+        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
+        const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                       : kleidiai_collect_q4_chain(kernel_chain);
+        const bool allow_fallback = kleidiai_pack_fallback_allowed();
+
+        std::vector<int8_t> qdata;
+        std::vector<float>  scales;
+
+        if (want_q8 && slot_total > 0) {
+            qdata.resize(n * k, 0);
+            scales.resize(n, 0.0f);
 
             const size_t row_stride = tensor->nb[1];
             const size_t k_blocks   = (k + QK8_0 - 1) / QK8_0;
 
-            std::vector<int8_t> qdata(n * k, 0);
-            std::vector<float> scales(n, 0.0f);
-
             for (size_t row = 0; row < n; ++row) {
                 const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
                     static_cast<const uint8_t *>(data) + row * row_stride);
@@ -610,7 +1250,7 @@ public:
                         if (linear_idx >= k) {
                             break;
                         }
-                        const float value = d * blk.qs[l];
+                        const float value = d * static_cast<float>(blk.qs[l]);
                         max_abs = std::max(max_abs, std::fabs(value));
                     }
                 }
@@ -627,31 +1267,73 @@ public:
                         if (linear_idx >= k) {
                             break;
                         }
-                        const float value = d * blk.qs[l];
+                        const float value = d * static_cast<float>(blk.qs[l]);
+                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
+                        q = std::clamp(q, -127, 127);
                         qdata[row * k + linear_idx] = static_cast<int8_t>(q);
                     }
                 }
             }
-
-            size_t nr = ctx.kernels_q8->gemm.get_nr();
-            size_t kr = ctx.kernels_q8->gemm.get_kr();
-            size_t sr = ctx.kernels_q8->gemm.get_sr();
-
-            struct kai_rhs_pack_qsi8cx_params params;
-            params.lhs_zero_point = 1;
-            params.scale_multiplier = 1.0f;
-
-            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
-                                                  qdata.data(), nullptr, scales.data(),
-                                                  tensor->data, 0, &params);
-            GGML_UNUSED(data_size);
-            return 0;
         }
 
-        GGML_UNUSED(data_size);
-        return -1;
+        for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
+            if (!allow_fallback && slot > 0) {
+                break;
+            }
+            ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+            kernel_info * kernel = &kernels->gemm;
+            rhs_packing_info * rhs_info = &kernels->rhs_info;
+            if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
+                continue;
+            }
+
+            const size_t nr = kernel->get_nr();
+            const size_t kr = kernel->get_kr();
+            const size_t sr = kernel->get_sr();
+            const ggml_type rhs_type = kernels->rhs_type;
+            const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
+                                     rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
+            if (block_len == 0) {
+                continue;
+            }
+
+            const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
+            const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+            uint8_t * dst_ptr = base_ptr + aligned_cursor;
+
+            if (rhs_type == GGML_TYPE_Q4_0) {
+                struct kai_rhs_pack_qs4cxs1s0_param params;
+                params.lhs_zero_point = 1;
+                params.rhs_zero_point = 8;
+                rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+                                       static_cast<const uint8_t *>(data), nullptr, nullptr,
+                                       dst_ptr, 0, &params);
+            } else if (rhs_type == GGML_TYPE_Q8_0) {
+                struct kai_rhs_pack_qsi8cx_params params;
+                params.lhs_zero_point = 1;
+                params.scale_multiplier = 1.0f;
+                rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+                                       qdata.data(), nullptr, scales.data(),
+                                       dst_ptr, 0, &params);
+            } else {
+                continue;
+            }
+
+            header->offsets[header->slot_count] = aligned_cursor;
+            header->sizes[header->slot_count]   = packed_size;
+            ++header->slot_count;
+
+            cursor = aligned_cursor + packed_size;
+        }
+
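+        // Nothing could be packed (no usable rhs pack/packed_size hooks in the chain): invalidate
+        // the header and keep a plain copy of the original tensor bytes, so readers that check
+        // kleidiai_is_weight_header_valid() treat the buffer as unpacked data.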
+        if (header->slot_count == 0) {
+            header->magic   = 0;
+            header->version = 0;
+            memcpy(tensor->data, data, data_size);
+        }
+
+        return 0;
     }
 };
 
@@ -681,9 +1363,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu
 }
 
 static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    return "CPU_KLEIDIAI";
-
     GGML_UNUSED(buft);
+    return "CPU_KLEIDIAI";
 }
 
 static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -702,49 +1383,78 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(
 }
 
 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return TENSOR_ALIGNMENT;
-
     GGML_UNUSED(buft);
+    return TENSOR_ALIGNMENT;
 }
 
 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     GGML_UNUSED(buft);
 
+    if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
+        return ggml_nbytes(tensor);
+    }
+
     const size_t n = tensor->ne[1];
     const size_t k = tensor->ne[0];
 
-    ggml_kleidiai_kernels * kernels = nullptr;
-    size_t block_len = 0;
+    size_t cursor = sizeof(kleidiai_weight_header);
+    cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
 
-    if (tensor->type == GGML_TYPE_Q4_0) {
-        GGML_ASSERT(ctx.kernels_q4);
-        kernels = ctx.kernels_q4;
-        block_len = QK4_0;
-    } else if (tensor->type == GGML_TYPE_Q8_0) {
-        GGML_ASSERT(ctx.kernels_q8);
-        kernels = ctx.kernels_q8;
-        block_len = QK8_0;
-    } else {
-        return 0;
+    std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+    const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
+    const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                   : kleidiai_collect_q4_chain(kernel_chain);
+    const bool allow_fallback = kleidiai_pack_fallback_allowed();
+
+    size_t slot_count = 0;
+    for (int slot = 0; slot < slot_total; ++slot) {
+        if (!allow_fallback && slot > 0) {
+            break;
+        }
+        ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+        if (!kernels) {
+            continue;
+        }
+        kernel_info * kernel = &kernels->gemm;
+        rhs_packing_info * rhs_info = &kernels->rhs_info;
+        if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
+            continue;
+        }
+
+        const ggml_type rhs_type = kernels->rhs_type;
+        const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                 rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
+        if (block_len == 0) {
+            continue;
+        }
+
+        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+        cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
+        ++slot_count;
     }
 
-    const size_t nr = kernels->gemm.get_nr();
-    const size_t kr = kernels->gemm.get_kr();
-    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
-    const size_t raw     = ggml_nbytes(tensor);
+    if (slot_count == 0) {
+        return ggml_nbytes(tensor);
+    }
 
-    return packed > raw ? packed : raw;
+    return std::max(cursor, ggml_nbytes(tensor));
 }
 
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
         if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
-            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
+            slot_total > 0) {
+            if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
+                return false;
+            }
+            if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
                 return false;
             }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
@@ -762,14 +1472,17 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
         if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
-            }
-            else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
-                if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
-                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
-                    return nullptr;
+            } else {
+                std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+                const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
+                const bool has_kernel = slot_total > 0;
+                if (has_kernel && op->src[1]->ne[1] > 1) {
+                    if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+                        (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
+                        return nullptr;
+                    }
+                    return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
                 }
-
-                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
             }
         }
         return nullptr;
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h b/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h
deleted file mode 100644
index a7078687..00000000
--- a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h
+++ /dev/null
@@ -1,333 +0,0 @@
-#pragma once
-
-typedef vector unsigned char vec_t;
-typedef __vector_quad acc_t;
-
-template <typename TA>
-class tinyBLAS_Q0_PPC {
-  public:
-    tinyBLAS_Q0_PPC(int64_t k,
-                    const TA *A, int64_t lda,
-                    const block_q8_0 *B, int64_t ldb,
-                    float *C, int64_t ldc,
-                    int ith, int nth);
-
-    void matmul(int64_t m, int64_t n);
-    void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
-        vec_t A_pack[mc*kc*2];
-        vec_t B_pack[nc*kc*2];
-        int comparray[mc*kc];
-        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
-        int64_t ytiles = m / mc;
-        int64_t xtiles = n / nc;
-        int64_t tiles  = xtiles * ytiles;
-        int64_t duty = (tiles + nth - 1) / nth;
-        int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles) {
-            end = tiles;
-        }
-        for (int64_t job = start; job < end; ++job) {
-            int64_t ii = (job / xtiles) * mc;
-            int64_t jj = (job % xtiles) * nc;
-            for (int64_t kk = 0; kk < k; kk += kc) {
-                if constexpr(is_Ablock_q4) {
-                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
-                } else {
-                    packNormal_large(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
-                }
-                packNormal_large(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
-                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
-            }
-        }
-    }
-
-  private:
-    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
-        for (int I = 0; I < RM; I++) {
-            for (int J = 0; J < RN; J++) {
-                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
-            }
-        }
-    }
-
-    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
-        for (int I = 0; I < RM; I++) {
-            for (int J = 0; J < RN; J++) {
-                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
-                *c_ptr += *((float*)&fin_res[idx+I]+J);
-            }
-        }
-    }
-
-    template<typename ArrayType>
-    inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
-        vector signed int vec_C[4];
-        vector float CA[4] = {0};
-        vector float res[4] = {0};
-        __builtin_mma_disassemble_acc(vec_C, ACC);
-        for (int i = 0; i < 4; i++) {
-            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
-            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
-            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
-        }
-    }
-
-    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
-        const vector signed char lowMask = vec_splats((signed char)0xF);
-        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
-        const vector signed char v8 = vec_splats((signed char)0x8);
-        vector signed int vsum = {0};
-        vector signed int vsum2 = {0};
-        c[0] = vec_and(c[1], lowMask);
-        c[1] = vec_sr(c[1], v4);
-        c[0] = vec_sub(c[0], v8);
-        c[1] = vec_sub(c[1], v8);
-        vsum = vec_sum4s(c[0], vsum);
-        vsum2 = vec_sum4s(c[1], vsum2);
-        vsum = vec_add(vsum, vsum2);
-        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-    }
-
-    template <typename V1, typename V2>
-    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
-        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
-        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
-        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
-        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
-        V2 t1, t2, t3, t4, t5, t6, t7, t8;
-        vector unsigned char xor_vector;
-        uint8_t flip_vec = 0x80;
-        xor_vector = vec_splats(flip_vec);
-        t1 = vec_perm(s1, s2, swiz1);
-        t2 = vec_perm(s1, s2, swiz2);
-        t3 = vec_perm(s3, s4, swiz1);
-        t4 = vec_perm(s3, s4, swiz2);
-        t5 = vec_perm(t1, t3, swiz3);
-        t6 = vec_perm(t1, t3, swiz4);
-        t7 = vec_perm(t2, t4, swiz3);
-        t8 = vec_perm(t2, t4, swiz4);
-        if (flip == true) {
-            t5 = vec_xor(t5, xor_vector);
-            t6 = vec_xor(t6, xor_vector);
-            t7 = vec_xor(t7, xor_vector);
-            t8 = vec_xor(t8, xor_vector);
-        }
-        vec_xst(t5, 0, vecOffset);
-        vec_xst(t6, 0, vecOffset+16);
-        vec_xst(t7, 0, vecOffset+32);
-        vec_xst(t8, 0, vecOffset+48);
-    }
-
-    template<int RM, int RN>
-    inline void kernel(int64_t ii, int64_t jj) {
-        if constexpr(RM == 4 && RN == 8) {
-            KERNEL_4x8(ii,jj);
-        } else if constexpr(RM == 8 && RN == 4) {
-            KERNEL_8x4(ii,jj);
-        } else if constexpr(RM == 8 && RN == 8) {
-            KERNEL_8x8(ii,jj);
-        } else {
-            assert(false && "RN/RM values not supported");
-        }
-    }
-    template<int size>
-    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
-    template<typename VA, typename VB>
-    void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
-    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
-    void KERNEL_4x8(int64_t ii, int64_t jj);
-    void KERNEL_8x4(int64_t ii, int64_t jj);
-    void KERNEL_8x8(int64_t ii, int64_t jj);
-    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
-    template <int RM, int RN>
-    void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);
-
-    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
-        for (int I = 0; I<8; I++) {
-            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
-            for (int J = 0; J<4; J++) {
-                *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
-                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
-             }
-         }
-    }
-
-    inline void process_q8_elements(const int8_t *qs, int *ca) {
-        vector signed char c1 = vec_xl(0, qs);
-        vector signed char c2 = vec_xl(16, qs);
-        vector signed int vsum1 = {0};
-        vector signed int vsum2 = {0};
-        vsum1 = vec_sum4s(c1, vsum1);
-        vsum2 = vec_sum4s(c2, vsum2);
-        vector signed int vsum = vec_add(vsum1, vsum2);
-        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
-    }
-
-    template<typename VA, typename VB>
-    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
-        int64_t i, j;
-        block_q8_0 *aoffset = NULL;
-        VA *vecOffset = NULL;
-        block_q8_0* aoffsets[8];
-        __vector_pair arr[8];
-        VB c[8][2] = {0};
-        VB c1[8] = {0}; VB c2[8] = {0};
-        aoffset = const_cast<block_q8_0 *>(a);
-        vecOffset = vec;
-        j = (rows >> 3);
-        int index = 0;
-        if (j > 0) {
-            do {
-                for (int it = 0; it < 8; it++)
-                    aoffsets[it] = aoffset + it*lda;
-                aoffset += 8 * lda;
-                for (int blk = 0; blk < kc; blk++) {
-                    for (int it = 0; it < 8; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
-                        c1[it] = c[it][0];
-                        c2[it] = c[it][1];
-                        if (comparray){
-                            process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
-                        }
-                    }
-                    vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
-                    vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
-                    vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
-                    vecOffset += 256;
-                }
-                j--;
-                index += 8*kc;
-            } while(j > 0);
-        }
-
-    }
-
-    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
-        int64_t i, j;
-        TA *aoffset = NULL;
-        int8_t *vecOffset = NULL;
-        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
-        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-        aoffset = const_cast<TA *>(a);
-        vecOffset = vec;
-        int index = 0;
-        j = (rows >> 3);
-        if (j > 0) {
-            do {
-                aoffset1 = aoffset;
-                aoffset2 = aoffset1 + lda;
-                aoffset3 = aoffset2 + lda;
-                aoffset4 = aoffset3 + lda;
-                aoffset5 = aoffset4 + lda;
-                aoffset6 = aoffset5 + lda;
-                aoffset7 = aoffset6 + lda;
-                aoffset8 = aoffset7 + lda;
-                aoffset += 8 * lda;
-                for (int blk = 0; blk < kc; blk++) {
-                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
-                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
-                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
-                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
-                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
-                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
-                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
-                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));
-
-                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
-                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
-                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
-                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
-                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
-                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
-                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
-                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
-                    vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
-                    vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
-                    vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
-                    vecOffset += 256;
-                }
-                j--;
-                index += 8*kc;
-            } while (j > 0);
-        }
-    }
-
-    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
-        acc_t acc[8];
-        for (int i = 0; i < mc ; i += 8) {
-            for (int j = 0; j < nc; j += 8) {
-                vector float fin_res[16] = {0};
-                vector float vs[16] = {0};
-                for (int64_t kk = 0; kk < kc; kk+=2) {
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xxsetaccz(&acc[x]);
-                    }
-                    int A_block_idx = (i/8)*(16*kc) + kk*16;
-                    int B_block_idx = (j/8)*(16*kc)+ kk*16;
-                    vec_t *A_block = &vec_A[A_block_idx];
-                    vec_t *B_block = &vec_B[B_block_idx];
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x],     B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x],     B_block[x+8]);
-                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8],   B_block[x+8]);
-                    }
-                    compute_scale(ii+i, jj+j, l+kk, vs);
-                    int c_index = (i/8)*(8*kc)+ kk*8;
-                    int* c_block = &comparray[c_index];
-                    compute(&acc[0], 0,  0,  c_block, vs, fin_res);
-                    compute(&acc[1], 4,  4,  c_block, vs, fin_res);
-                    compute(&acc[2], 0,  8,  c_block, vs, fin_res);
-                    compute(&acc[3], 4, 12,  c_block, vs, fin_res);
-
-                    A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
-                    B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
-                    A_block = &vec_A[A_block_idx];
-                    B_block = &vec_B[B_block_idx];
-                    for (int x = 0; x < 8; x++) {
-                        __builtin_mma_xvi8ger4pp(&acc[4], A_block[x],     B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
-                        __builtin_mma_xvi8ger4pp(&acc[6], A_block[x],     B_block[x+8]);
-                        __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8],   B_block[x+8]);
-                    }
-                    compute_scale(ii+i, jj+j, l+kk+1, vs);
-                    c_index = (i/8)*(8*kc)+ (kk+1)*8;
-                    c_block = &comparray[c_index];
-                    compute(&acc[4], 0,  0,  c_block, vs, fin_res);
-                    compute(&acc[5], 4,  4,  c_block, vs, fin_res);
-                    compute(&acc[6], 0,  8,  c_block, vs, fin_res);
-                    compute(&acc[7], 4, 12,  c_block, vs, fin_res);
-
-                }
-                if (l == 0) {
-                    save_res(ii+i,   jj+j,    0,  fin_res);
-                    save_res(ii+i+4, jj+j,    4,  fin_res);
-                    save_res(ii+i,   jj+j+4,  8,  fin_res);
-                    save_res(ii+i+4, jj+j+4, 12,  fin_res);
-                } else {
-                    add_save_res(ii+i,   jj+j,    0,  fin_res);
-                    add_save_res(ii+i+4, jj+j,    4,  fin_res);
-                    add_save_res(ii+i,   jj+j+4,  8,  fin_res);
-                    add_save_res(ii+i+4, jj+j+4, 12,  fin_res);
-                }
-            }
-        }
-    }
-
-    const TA *const A;
-    const block_q8_0 *const B;
-    float *C;
-    const int64_t k;
-    int64_t kc;
-    const int64_t lda;
-    const int64_t ldb;
-    const int64_t ldc;
-    const int ith;
-    const int nth;
-};
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 7dc36d4f..c89e5076 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
 #endif
 
 #if defined(__MMA__)
-#include "sgemm-ppc.h"
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
 #endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD
@@ -532,7 +533,7 @@ class tinyBLAS {
         if constexpr (RN > 1) {
             return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
         } else {
-            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+            GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
             GGML_ASSERT(false); // we have miss something.
         }
     }
@@ -710,7 +711,7 @@ class tinyBLAS_RVV {
         if constexpr (RN > 1) {
             return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
         } else {
-            GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
+            GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N);
             GGML_ASSERT(false); // we have miss something.
         }
     }
@@ -1797,10 +1798,27 @@ class tinyBLAS_Q0_AVX {
       } \
    } \
 
+template<typename TA>
+struct mma_instr;
+
+template<>
+struct mma_instr<ggml_bf16_t> {
+    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
+        __builtin_mma_xvbf16ger2pp(acc, a, b);
+    }
+};
+
+template<>
+struct mma_instr<ggml_fp16_t> {
+    static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
+        __builtin_mma_xvf16ger2pp(acc, a, b);
+    }
+};
+
 template <typename TA, typename TB, typename TC>
-class tinyBLAS_BF16_PPC {
+class tinyBLAS_HP16_PPC {
   public:
-    tinyBLAS_BF16_PPC(int64_t k,
+    tinyBLAS_HP16_PPC(int64_t k,
                 const TA *A, int64_t lda,
                 const TB *B, int64_t ldb,
                 TC *C, int64_t ldc,
@@ -2118,8 +2136,8 @@ class tinyBLAS_BF16_PPC {
             packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
             packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
             for (int x = 0; x < 4; x++) {
-                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
             }
         }
         SAVE_ACC(&acc_0, ii, jj);
@@ -2135,8 +2153,8 @@ class tinyBLAS_BF16_PPC {
             packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
             packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
             for (int x = 0; x < 4; x++) {
-                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
             }
         }
         SAVE_ACC(&acc_0, ii, jj);
@@ -2155,10 +2173,10 @@ class tinyBLAS_BF16_PPC {
             packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
             packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
             for (int x = 0; x < 4; x++) {
-                __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
-                __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
-                __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
+                mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
+                mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
+                mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
             }
         }
 
@@ -2189,7 +2207,7 @@ class tinyBLAS_BF16_PPC {
                 packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
                 packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
                 for (int x = 0; x<2; x++) {
-                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
+                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
                 }
             }
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -2224,8 +2242,8 @@ class tinyBLAS_BF16_PPC {
                 packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
                 packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
                 for (int x = 0; x<4; x++) {
-                    __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
-                    __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
+                    mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
+                    mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
                 }
             }
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
@@ -2284,43 +2302,299 @@ class tinyBLAS_BF16_PPC {
     const int nth;
 };
 
-    template <typename TA>
-    tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
-        const TA *A, int64_t lda,
-        const block_q8_0 *B, int64_t ldb,
-        float *C, int64_t ldc,
-        int ith, int nth)
+template <typename TA>
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+             const TA * A, int64_t lda,
+             const block_q8_0 * B, int64_t ldb,
+             float * C, int64_t ldc,
+             int ith, int nth)
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-                kc = 64;
     }
 
-    template<typename TA>
-    void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
-        int mc = 64; int nc = 64;
-        if (n % 8 == 0 && n < nc) {
-                nc = n;
-                mc = 32 ;
-                kc = 32;
+    void matmul(int64_t m, int64_t n) {
+        const int64_t mc = 64;
+        const int64_t kc = 64;
+        int64_t nc = 64;
+        int64_t n_aligned = 0;
+        if (n % 64 == 0) {
+            n_aligned = n;
+        } else if (n == 4) {
+            n_aligned = 4;
+        } else if (n < 64) {
+            n_aligned = (n / 8) * 8;
+        } else {
+            n_aligned = (n / 64) * 64;
         }
-        const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
-        if (is_aligned) {
-            this->matmul_tiled_q0(m, n, mc, nc, kc);
+
+        if (n_aligned > 0) {
+            if (n_aligned % 64 == 0)      nc = 64;
+            else if (n_aligned == n)      nc = n;
+            else if (n_aligned % 32 == 0) nc = 32;
+            else if (n_aligned % 24 == 0) nc = 24;
+            else if (n_aligned % 16 == 0) nc = 16;
+            else                          nc = 8;
+        }
+        bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
+        if (can_use_tiled) {
+            matmul_tiled(m, n_aligned, mc, nc, kc);
+            if (n > n_aligned) {
+                mnpack(0, m, n_aligned, n);
+            }
         } else {
             mnpack(0, m, 0, n);
         }
     }
 
-   template<typename TA>
-   template<int size>
-   void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
+  private:
+    inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
+        for (int I = 0; I < RM; I++) {
+            for (int J = 0; J < RN; J++) {
+                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
+            }
+        }
+    }
+
+    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+        vec_t vec_C[4];
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int I = 0; I < 4; I++) {
+            for (int J = 0; J < 4; J++) {
+                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
+            }
+        }
+    }
+
+    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
+        vec_t vec_C[4];
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int I = 0; I < 4; I++) {
+            for (int J = 0; J < 4; J++) {
+                float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
+                *c_ptr += *((float *)&vec_C[I] + J);
+            }
+        }
+    }
+
+    template<typename ArrayType>
+    inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
+        vector signed int vec_C[4];
+        vector float CA[4] = {0};
+        vector float res[4] = {0};
+        __builtin_mma_disassemble_acc(vec_C, ACC);
+        for (int i = 0; i < 4; i++) {
+            CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
+            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+            fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
+        }
+    }
+
+    inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
+        const vector signed char lowMask = vec_splats((signed char)0xF);
+        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+        const vector signed char v8 = vec_splats((signed char)0x8);
+        vector signed int vsum = {0};
+        vector signed int vsum2 = {0};
+        c[0] = vec_and(c[1], lowMask);
+        c[1] = vec_sr(c[1], v4);
+        c[0] = vec_sub(c[0], v8);
+        c[1] = vec_sub(c[1], v8);
+        vsum = vec_sum4s(c[0], vsum);
+        vsum2 = vec_sum4s(c[1], vsum2);
+        vsum = vec_add(vsum, vsum2);
+        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
+    }
+
+    template <typename V1, typename V2>
+    inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+        V2 t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        t1 = vec_perm(s1, s2, swiz1);
+        t2 = vec_perm(s1, s2, swiz2);
+        t3 = vec_perm(s3, s4, swiz1);
+        t4 = vec_perm(s3, s4, swiz2);
+        t5 = vec_perm(t1, t3, swiz3);
+        t6 = vec_perm(t1, t3, swiz4);
+        t7 = vec_perm(t2, t4, swiz3);
+        t8 = vec_perm(t2, t4, swiz4);
+        if (flip == true) {
+            t5 = vec_xor(t5, xor_vector);
+            t6 = vec_xor(t6, xor_vector);
+            t7 = vec_xor(t7, xor_vector);
+            t8 = vec_xor(t8, xor_vector);
+        }
+        vec_xst(t5, 0, vecOffset);
+        vec_xst(t6, 0, vecOffset + 16);
+        vec_xst(t7, 0, vecOffset + 32);
+        vec_xst(t8, 0, vecOffset + 48);
+    }
+
+    inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
+        const vector signed char lowMask = vec_splats((signed char)0x0F);
+        const vector signed char v8      = vec_splats((signed char)0x08);
+        const vector unsigned char v4    = vec_splats((unsigned char)4);
+        lo = vec_and(packed, lowMask);
+        hi = vec_sr(packed, v4);
+        lo = vec_sub(lo, v8);
+        hi = vec_sub(hi, v8);
+    }
+
+    inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
+        vec_t t[8], s[8];
+        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        for (int i = 0; i < 4; i += 2) {
+            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
+            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
+        }
+        for (int i = 4; i < 8; i += 2) {
+            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
+            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
+        }
+        s[0] = vec_perm(t[0], t[2], swiz3);
+        s[1] = vec_perm(t[0], t[2], swiz4);
+        s[2] = vec_perm(t[1], t[3], swiz3);
+        s[3] = vec_perm(t[1], t[3], swiz4);
+        s[4] = vec_perm(t[4], t[6], swiz3);
+        s[5] = vec_perm(t[4], t[6], swiz4);
+        s[6] = vec_perm(t[5], t[7], swiz3);
+        s[7] = vec_perm(t[5], t[7], swiz4);
+        for (int i = 0; i < 8; ++i) {
+            vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
+        }
+    }
+
+    static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
+        vector signed short i16_hi = vec_unpackh(raw);
+        vector signed short i16_lo = vec_unpackl(raw);
+
+        vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
+        vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
+        vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
+        vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
+        out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
+        out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
+    }
+
+    void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        unsigned char * vecOffset = vec;
+        for (int i = 0; i < rows; i += 8) {
+            const block_q4_0 * rows_base[8];
+            for (int r = 0; r < 8; r++) {
+                rows_base[r] = a + (i + r) * lda;
+            }
+            for (int blk = 0; blk < blocks; blk++) {
+                vector unsigned short hp_res[8][4];
+                for (int r = 0; r < 8; r++) {
+                    const block_q4_0 * current_blk = rows_base[r] + blk;
+                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
+                    vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs);
+                    vector signed char c1, c2;
+                    unpack_q4_to_q8(v_qs, c1, c2);
+                    convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
+                    convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
+                }
+                for (int c = 0; c < 4; c++) {
+                    vector unsigned char c_arr[8];
+                    for (int r = 0; r < 8; r++) {
+                        c_arr[r] = (vector unsigned char)hp_res[r][c];
+                    }
+                    vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
+                    vecOffset += 128;
+                }
+            }
+        }
+    }
+
+    template <int chunk_size>
+    static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        unsigned char * vecOffset = vec;
+        const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
+        const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+        const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+
+        for (int i = 0; i < rows; i += chunk_size) {
+            const block_q8_0 * rows_base[chunk_size];
+            for (int r = 0; r < chunk_size; r++) {
+                rows_base[r] = a + (i + r) * lda;
+            }
+            for (int blk = 0; blk < blocks; blk++) {
+                vector unsigned short hp_res[chunk_size][4];
+                for (int r = 0; r < chunk_size; r++) {
+                    const block_q8_0 * b = rows_base[r] + blk;
+                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
+                    vector signed char c[2];
+                    __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
+                    __builtin_vsx_disassemble_pair(c, & pair);
+                    convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
+                    convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
+                }
+                for (int col = 0; col < 4; col++) {
+                    if constexpr (chunk_size == 8) {
+                        vec_t t[8];
+                        t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
+                        t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
+                        t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
+                        t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
+                        t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
+                        t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
+                        t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
+                        t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);
+
+                        vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
+                        vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
+                        vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
+                        vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
+                        vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
+                        vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
+                        vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
+                        vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
+                        vecOffset += 128;
+                    } else {
+                        vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
+                        vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
+                        vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
+                        vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
+
+                        vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
+                        vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
+                        vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
+                        vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
+                        vecOffset += 64;
+                    }
+                }
+            }
+        }
+    }
+
+    void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
+        if (rows == 4) {
+            pack_q8_block<4>(a, lda, rows, blocks, vec);
+        } else {
+            pack_q8_block<8>(a, lda, rows, blocks, vec);
+        }
+    }
+
+    template<int size>
+    void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
         int64_t i, j;
-        TA *aoffset = NULL;
-        int8_t *vecOffset = NULL;
-        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        TA * aoffset = NULL;
+        int8_t * vecOffset = NULL;
+        TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
+        TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
         vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
         vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-        aoffset = const_cast<TA*>(a);
+        aoffset = const_cast<TA *>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
@@ -2337,27 +2611,27 @@ class tinyBLAS_BF16_PPC {
                 i = (cols >> 2);
                 if (i > 0) {
                     do {
-                        c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
-                        c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
-                        c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
-                        c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
-                        c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs));
-                        c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs));
-                        c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
-                        c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));
+                        c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
+                        c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
+                        c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
+                        c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
+                        c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs);
+                        c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs);
+                        c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs);
+                        c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs);
 
-                        process_q4_elements(c1, &comparray[0]);
-                        process_q4_elements(c2, &comparray[1]);
-                        process_q4_elements(c3, &comparray[2]);
-                        process_q4_elements(c4, &comparray[3]);
-                        process_q4_elements(c5, &comparray[4]);
-                        process_q4_elements(c6, &comparray[5]);
-                        process_q4_elements(c7, &comparray[6]);
-                        process_q4_elements(c8, &comparray[7]);
+                        process_q4_elements(c1, & comparray[0]);
+                        process_q4_elements(c2, & comparray[1]);
+                        process_q4_elements(c3, & comparray[2]);
+                        process_q4_elements(c4, & comparray[3]);
+                        process_q4_elements(c5, & comparray[4]);
+                        process_q4_elements(c6, & comparray[5]);
+                        process_q4_elements(c7, & comparray[6]);
+                        process_q4_elements(c8, & comparray[7]);
                         vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                        vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
-                        vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
-                        vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
+                        vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
+                        vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
+                        vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
                         aoffset1 += lda;
                         aoffset2 += lda;
                         aoffset3 += lda;
@@ -2383,17 +2657,17 @@ class tinyBLAS_BF16_PPC {
             i = (cols >> 2);
             if (i > 0) {
                 do {
-                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
-                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
-                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
-                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));
+                    c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
+                    c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
+                    c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
+                    c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs);
 
-                    process_q4_elements(c1, &comparray[0]);
-                    process_q4_elements(c2, &comparray[1]);
-                    process_q4_elements(c3, &comparray[2]);
-                    process_q4_elements(c4, &comparray[3]);
+                    process_q4_elements(c1, & comparray[0]);
+                    process_q4_elements(c2, & comparray[1]);
+                    process_q4_elements(c3, & comparray[2]);
+                    process_q4_elements(c4, & comparray[3]);
                     vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
                     aoffset1 += lda;
                     aoffset2 += lda;
                     aoffset3 += lda;
@@ -2412,17 +2686,17 @@ class tinyBLAS_BF16_PPC {
             if (i > 0) {
                 do {
                     switch(rows) {
-                        case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
-                        case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs));
-                        case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
+                        case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs);
+                        case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs);
+                        case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs);
                             break;
                     }
-                    process_q4_elements(c1, &comparray[0]);
-                    process_q4_elements(c2, &comparray[1]);
-                    process_q4_elements(c3, &comparray[2]);
-                    process_q4_elements(c4, &comparray[3]);
+                    process_q4_elements(c1, & comparray[0]);
+                    process_q4_elements(c2, & comparray[1]);
+                    process_q4_elements(c3, & comparray[2]);
+                    process_q4_elements(c4, & comparray[3]);
                     vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+                    vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
                     aoffset1 += lda;
                     aoffset2 += lda;
                     aoffset3 += lda;
@@ -2433,39 +2707,38 @@ class tinyBLAS_BF16_PPC {
         }
     }
 
-    template<typename TA>
     template<typename VA, typename VB>
-    void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+    void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
         int64_t i, j;
-        block_q8_0 *aoffset = NULL;
-        VA *vecOffset = NULL;
-        block_q8_0* aoffsets[8];
+        block_q8_0 * aoffset = NULL;
+        VA * vecOffset = NULL;
+        block_q8_0 * aoffsets[8];
         __vector_pair arr[8];
         VB c[8][2] = {0};
         VB c1[8] = {0}; VB c2[8] = {0};
-        aoffset = const_cast<block_q8_0*>(a);
+        aoffset = const_cast<block_q8_0 *>(a);
         vecOffset = vec;
         j = (rows >> 3);
         if (j > 0) {
             do {
                 aoffsets[0] = aoffset;
                 for (int it = 1; it < 8; it++)
-                    aoffsets[it] = aoffsets[it-1] + lda;
+                    aoffsets[it] = aoffsets[it - 1] + lda;
                 aoffset += 8 * lda;
 
                 i = (cols >> 3);
                 if (i > 0) {
                 do {
                     for (int it = 0; it < 8; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
+                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);
                         c1[it] = c[it][0];
                         c2[it] = c[it][1];
                     }
                     vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
-                    vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
-                    vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
+                    vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
+                    vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
                     for (int it = 0; it < 8; it++)
                         aoffsets[it] += lda;
                     vecOffset += 256;
@@ -2484,13 +2757,13 @@ class tinyBLAS_BF16_PPC {
             if (i > 0) {
                do {
                     for (int it = 0; it < 4; it++) {
-                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
-                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
+                        __builtin_vsx_disassemble_pair(c[it], & arr[it]);
                         c1[it] = c[it][0];
                         c2[it] = c[it][1];
                     }
                     vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
                     for (int it = 0; it < 4; it++) {
                         aoffsets[it] += lda;
                     }
@@ -2503,24 +2776,24 @@ class tinyBLAS_BF16_PPC {
         if (rows & 3) {
             aoffsets[0]  = aoffset;
             for (int it = 1; it < 3; it++ )
-                aoffsets[it] = aoffsets[it-1] + lda;
+                aoffsets[it] = aoffsets[it - 1] + lda;
             i = (cols >> 3);
             if (i > 0) {
                 do {
                     switch(rows) {
-                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
-                                __builtin_vsx_disassemble_pair(c[2], &arr[2]);
+                        case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
+                                __builtin_vsx_disassemble_pair(c[2], & arr[2]);
                                 c1[2] = c[2][0]; c2[2] = c[2][1];
-                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
-                                __builtin_vsx_disassemble_pair(c[1], &arr[1]);
+                        case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
+                                __builtin_vsx_disassemble_pair(c[1], & arr[1]);
                                 c1[1] = c[1][0]; c2[1] = c[1][1];
-                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
-                                __builtin_vsx_disassemble_pair(c[0], &arr[0]);
+                        case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
+                                __builtin_vsx_disassemble_pair(c[0], & arr[0]);
                                 c1[0] = c[0][0]; c2[0] = c[0][1];
                                 break;
                     }
                     vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                    vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
                     for (int it = 0; it < 3; it++)
                          aoffsets[it] += lda;
                     vecOffset += 128;
@@ -2530,8 +2803,7 @@ class tinyBLAS_BF16_PPC {
         }
     }
 
-    template<typename TA>
-    void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int m_rem = MIN(m - m0, 16);
         int n_rem = MIN(n - n0, 16);
 
@@ -2568,8 +2840,7 @@ class tinyBLAS_BF16_PPC {
     }
 
 
-    template<typename TA>
-    void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
         vec_t vec_A[8], vec_B[16] = {0};
         acc_t acc_0, acc_1;
         std::array<int, 4> comparray {};
@@ -2577,26 +2848,26 @@ class tinyBLAS_BF16_PPC {
         vector float vs[8] = {0};
         bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
-            __builtin_mma_xxsetaccz(&acc_0);
-            __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
             if (std::is_same_v<TA, block_q4_0>) {
-               packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
             } else {
-               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
             }
-            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
             }
             for (int I = 0; I<4; I++) {
                 for (int J = 0; J<4; J++) {
-                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
-                    *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
+                    *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-                auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 4; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -2607,15 +2878,14 @@ class tinyBLAS_BF16_PPC {
                     aoffset += lda;
                 }
             }
-            compute(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute(&acc_1, 0, 4, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 0, 4, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-        save_res(ii, jj+4, 4, fin_res);
+        save_res(ii, jj + 4, 4, fin_res);
     }
 
-    template<typename TA>
-    void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[8] = {0};
         acc_t acc_0, acc_1;
         std::array<int, 8> comparray {};
@@ -2623,25 +2893,25 @@ class tinyBLAS_BF16_PPC {
         vector float vs[8] = {0};
         bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
-            __builtin_mma_xxsetaccz(&acc_0);
-            __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
             if (std::is_same_v<TA, block_q4_0>) {
-               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
             } else {
-               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
             }
-            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
+            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
             }
-            for (int I = 0; I<8; I++) {
-                for (int J = 0; J<4; J++) {
-                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+            for (int I = 0; I < 8; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-                auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 8; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -2652,15 +2922,14 @@ class tinyBLAS_BF16_PPC {
                     aoffset += lda;
                 }
             }
-            compute(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 4, 4, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-        save_res(ii+4, jj, 4, fin_res);
+        save_res(ii + 4, jj, 4, fin_res);
     }
 
-    template<typename TA>
-    void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
         vec_t vec_A[16], vec_B[16] = {0};
         acc_t acc_0, acc_1, acc_2, acc_3;
         acc_t acc_4, acc_5, acc_6, acc_7;
@@ -2669,30 +2938,30 @@ class tinyBLAS_BF16_PPC {
         vector float vs[16] = {0};
         bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
         for (int l = 0; l < k; l++) {
-            __builtin_mma_xxsetaccz(&acc_0);
-            __builtin_mma_xxsetaccz(&acc_1);
-            __builtin_mma_xxsetaccz(&acc_2);
-            __builtin_mma_xxsetaccz(&acc_3);
+            __builtin_mma_xxsetaccz(& acc_0);
+            __builtin_mma_xxsetaccz(& acc_1);
+            __builtin_mma_xxsetaccz(& acc_2);
+            __builtin_mma_xxsetaccz(& acc_3);
             if (std::is_same_v<TA, block_q4_0>) {
-               packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
+               packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
             } else {
-               packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+               packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
             }
-            packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
             for(int x = 0; x < 8; x++) {
-                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
-                __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
-                __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
+                __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
             }
-            for (int I = 0; I<8; I++) {
-                for (int J = 0; J<4; J++) {
-                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
-                    *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+            for (int I = 0; I < 8 ; I++) {
+                for (int J = 0; J < 4; J++) {
+                    *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
+                    *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
                 }
             }
             if (!isAblock_q4) {
-                auto aoffset = A+(ii*lda)+l;
+                auto aoffset = A + (ii * lda) + l;
                 for (int i = 0; i < 8; i++) {
                     comparray[i] = 0;
                     int ca = 0;
@@ -2703,19 +2972,99 @@ class tinyBLAS_BF16_PPC {
                     aoffset += lda;
                 }
             }
-            compute(&acc_0, 0, 0, comparray, vs, fin_res);
-            compute(&acc_1, 4, 4, comparray, vs, fin_res);
-            compute(&acc_2, 0, 8, comparray, vs, fin_res);
-            compute(&acc_3, 4, 12, comparray, vs, fin_res);
+            compute(& acc_0, 0, 0, comparray, vs, fin_res);
+            compute(& acc_1, 4, 4, comparray, vs, fin_res);
+            compute(& acc_2, 0, 8, comparray, vs, fin_res);
+            compute(& acc_3, 4, 12, comparray, vs, fin_res);
         }
         save_res(ii, jj, 0, fin_res);
-        save_res(ii+4, jj, 4, fin_res);
-        save_res(ii, jj+4, 8, fin_res);
-        save_res(ii+4, jj+4, 12, fin_res);
+        save_res(ii + 4, jj, 4, fin_res);
+        save_res(ii, jj + 4, 8, fin_res);
+        save_res(ii + 4, jj + 4, 12, fin_res);
     }
 
-    template<typename TA>
-    void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
+        acc_t acc[8];
+        for (int i = 0; i < mc ; i += 16) {
+            for (int j = 0; j < nc; j += 8) {
+                int A0_base = (i / 16) * (2 * 32 * kc);
+                int B0_base = (j / 8) * (32 * kc);
+                for (int x = 0; x < 8; x++) {
+                     __builtin_mma_xxsetaccz(&acc[x]);
+                }
+                for (int64_t kk = 0; kk < kc; kk++) {
+                    int A0_block_idx = A0_base + kk * 32;
+                    int B0_block_idx = B0_base + kk * 32;
+                    int A1_block_idx = A0_block_idx + 32 * kc;
+                    int B1_block_idx = B0_block_idx + 32 * kc;
+                    vec_t * A0_block = & vec_A[A0_block_idx];
+                    vec_t * B0_block = & vec_B[B0_block_idx];
+                    vec_t * A1_block = & vec_A[A1_block_idx];
+                    for (int it = 0; it < 4; it++) {
+                        for (int x = 0; x < 4; x++) {
+                            __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
+                            __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
+                            __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
+                        }
+                    }
+                }
+                if (l == 0) {
+                    save_acc(& acc[0], ii + i, jj + j);
+                    save_acc(& acc[1], ii + i, jj + j + 4);
+                    save_acc(& acc[2], ii + i + 4, jj + j);
+                    save_acc(& acc[3], ii + i + 4, jj + j + 4);
+                    save_acc(& acc[4], ii + i + 8, jj + j);
+                    save_acc(& acc[5], ii + i + 8, jj + j + 4);
+                    save_acc(& acc[6], ii + i + 12, jj + j);
+                    save_acc(& acc[7], ii + i + 12, jj + j + 4);
+                } else {
+                    add_save_acc(& acc[0], ii + i, jj + j);
+                    add_save_acc(& acc[1], ii + i, jj + j + 4);
+                    add_save_acc(& acc[2], ii + i + 4, jj + j);
+                    add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
+                    add_save_acc(& acc[4], ii + i + 8, jj + j);
+                    add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
+                    add_save_acc(& acc[6], ii + i + 12, jj + j);
+                    add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
+                }
+            }
+        }
+    }
+
+    void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
+        vec_t A_pack[mc * kc * 4];
+        vec_t B_pack[nc * kc * 4];
+        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
+        int64_t ytiles = m / mc;
+        int64_t xtiles = n / nc;
+        int64_t tiles  = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles) {
+            end = tiles;
+        }
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = (job / xtiles) * mc;
+            int64_t jj = (job % xtiles) * nc;
+            for (int64_t kk = 0; kk < k; kk += kc) {
+                if constexpr(is_Ablock_q4) {
+                    packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
+                } else {
+                    packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
+                }
+                packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
+                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
+            }
+        }
+    }
+
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
         int64_t tiles = xtiles * ytiles;
@@ -2737,32 +3086,32 @@ class tinyBLAS_BF16_PPC {
             vector float fin_res[4] = {0};
             vector float vs[4] = {0};
             vector float CA[4] = {0};
-            __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
-            __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
             for (int l = 0; l < k; l++) {
-                __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
-                __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
-                __builtin_mma_xxsetaccz(&acc_0);
+                __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_mma_xxsetaccz(& acc_0);
                 if (isAblock_q4) {
-                   packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
+                    packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
                 } else {
-                   packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                    packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
                 }
-                packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
-                for(int x = 0; x < 8; x+=4) {
-                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
-                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
-                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
-                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
+                packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
+                for (int x = 0; x < 8; x += 4) {
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
+                    __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
                 }
-                for (int I = 0; I<RM; I++) {
-                    for (int J = 0; J<RN; J++) {
-                        *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                for (int I = 0; I < RM; I++) {
+                    for (int J = 0; J < RN; J++) {
+                        *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
                     }
                 }
-                __builtin_mma_disassemble_acc(vec_C, &acc_0);
+                __builtin_mma_disassemble_acc(vec_C, & acc_0);
                 if (!isAblock_q4) {
-                    auto aoffset = A+(ii*lda)+l;
+                    auto aoffset = A + (ii * lda) + l;
                     for (int i = 0; i < RM; i++) {
                         comparray[i] = 0;
                         int ca = 0;
@@ -2783,9 +3132,21 @@ class tinyBLAS_BF16_PPC {
         }
     }
 
-    template<typename TA>
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+        if constexpr(RM == 4 && RN == 8) {
+            KERNEL_4x8(ii,jj);
+        } else if constexpr(RM == 8 && RN == 4) {
+            KERNEL_8x4(ii,jj);
+        } else if constexpr(RM == 8 && RN == 8) {
+            KERNEL_8x8(ii,jj);
+        } else {
+            assert(false && "RN/RM values not supported");
+        }
+    }
+
     template <int RM, int RN>
-    NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;
         int64_t xtiles = (n - n0) / RN;
         int64_t tiles = xtiles * ytiles;
@@ -2797,12 +3158,20 @@ class tinyBLAS_BF16_PPC {
         for (int64_t job = start; job < end; ++job) {
             int64_t ii = m0 + job / xtiles * RM;
             int64_t jj = n0 + job % xtiles * RN;
-            this->kernel<RM, RN>(ii, jj);
+            kernel<RM, RN>(ii, jj);
         }
     }
-
-template class tinyBLAS_Q0_PPC<block_q4_0>;
-template class tinyBLAS_Q0_PPC<block_q8_0>;
+    const TA * const A;
+    const block_q8_0 * const B;
+    float * C;
+    const int64_t k;
+    int64_t kc;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
 
 class tinyBLAS_PPC {
   public:
@@ -3418,16 +3787,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             return tb.matmul(m, n);
         }
 #elif defined(__MMA__)
-        if ((k % 8))
-                return false;
-        if(Btype == GGML_TYPE_BF16) {
-           tinyBLAS_BF16_PPC tb{ k,
-            (const ggml_bf16_t *)A, lda,
-            (const ggml_bf16_t *)B, ldb,
-            (float *)C, ldc,
-            params->ith, params->nth};
-        tb.matmul(m, n);
-        return true;
+        if (k % 8) {
+            return false;
+        }
+
+        if (Btype == GGML_TYPE_BF16) {
+            tinyBLAS_HP16_PPC tb{ k,
+                (const ggml_bf16_t *)A, lda,
+                (const ggml_bf16_t *)B, ldb,
+                (float *)C, ldc,
+                params->ith, params->nth };
+
+            tb.matmul(m, n);
+            return true;
         }
 #elif defined(__riscv_zvfbfwma)
         #if LMUL == 1
@@ -3516,6 +3888,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
         #endif
             return tb.matmul(m, n);
         }
+#elif defined(__MMA__)
+        if (k % 8) {
+            return false;
+        }
+
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS_HP16_PPC tb{ k,
+                (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc,
+                params->ith, params->nth };
+
+            tb.matmul(m, n);
+            return true;
+        }
 #endif
         return false;
     }
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 30327839..3f85e531 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -3,14 +3,14 @@
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "binary-ops.h"
+#include "simd-gemm.h"
 #include "ggml.h"
 #include "unary-ops.h"
 #include "vec.h"
 
-#include 
 #include 
+#include 
 #include 
-#include 
 
 // ggml_compute_forward_dup
 
@@ -375,7 +375,7 @@ static void ggml_compute_forward_dup_bytes(
         const size_t rs = ne00 * type_size;
 
         if (nb00 == type_size) {
-            // src0 is contigous on first dimension, copy by rows
+            // src0 is contiguous on first dimension, copy by rows
             for (int64_t i03 = 0; i03 < ne03; i03++) {
                 for (int64_t i02 = 0; i02 < ne02; i02++) {
                     id += rs * ir0;
@@ -670,6 +670,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1119,6 +1120,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1247,6 +1249,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1795,7 +1798,7 @@ void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, dst);
             } break;
-        // TODO: templateify the implemenation and support for I64
+        // TODO: templateify the implementation and support for I64
         //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
         //case GGML_TYPE_I64:
         //    {
@@ -2097,10 +2100,14 @@ static void ggml_compute_forward_gelu_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2114,19 +2121,23 @@ static void ggml_compute_forward_gelu_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
             GGML_UNUSED(x);
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
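
The reworked activation kernels above relax the contiguity requirement to contiguous rows and therefore recover (i1, i2, i3) from the flat row index `ir` instead of assuming a packed 2-D layout. A small standalone sketch of that index arithmetic with assumed example extents, not tied to any ggml tensor:

#include <cassert>
#include <cstdint>

int main() {
    // Example extents: ne01 rows per matrix, ne02 matrices per batch, ne03 batches.
    const int64_t ne01 = 3, ne02 = 2, ne03 = 4;
    const int64_t nr = ne01 * ne02 * ne03;

    for (int64_t ir = 0; ir < nr; ++ir) {
        // Same decomposition used in the loops above.
        const int64_t i3 = ir / (ne02 * ne01);
        const int64_t i2 = (ir - i3 * ne02 * ne01) / ne01;
        const int64_t i1 =  ir - i3 * ne02 * ne01 - i2 * ne01;

        // The coordinates must recompose to the flat index and stay in range.
        assert(ir == (i3 * ne02 + i2) * ne01 + i1);
        assert(i1 < ne01 && i2 < ne02 && i3 < ne03);
    }
    return 0;
}
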
 
@@ -2136,10 +2147,14 @@ static void ggml_compute_forward_gelu_f16(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2153,20 +2168,24 @@ static void ggml_compute_forward_gelu_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
             const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2277,10 +2296,14 @@ static void ggml_compute_forward_gelu_erf_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2294,19 +2317,23 @@ static void ggml_compute_forward_gelu_erf_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_erf_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
             GGML_UNUSED(x);
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2316,10 +2343,14 @@ static void ggml_compute_forward_gelu_erf_f16(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2333,20 +2364,24 @@ static void ggml_compute_forward_gelu_erf_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_erf_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
             const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2380,10 +2415,14 @@ static void ggml_compute_forward_gelu_quick_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2397,19 +2436,23 @@ static void ggml_compute_forward_gelu_quick_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_quick_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
             GGML_UNUSED(x);
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2419,10 +2462,14 @@ static void ggml_compute_forward_gelu_quick_f16(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2436,20 +2483,24 @@ static void ggml_compute_forward_gelu_quick_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_gelu_quick_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
             const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2483,10 +2534,14 @@ static void ggml_compute_forward_silu_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2500,19 +2555,23 @@ static void ggml_compute_forward_silu_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_silu_f32(nc,
-                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k];
             GGML_UNUSED(x);
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2522,10 +2581,14 @@ static void ggml_compute_forward_silu_f16(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_is_contiguous_rows(src0));
     assert(ggml_are_same_shape(src0, dst));
 
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -2539,20 +2602,24 @@ static void ggml_compute_forward_silu_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int i3 = ir/(ne02*ne01);
+        const int i2 = (ir - i3*ne02*ne01)/ne01;
+        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
+
         ggml_vec_silu_f16(nc,
-                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
+                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k];
             const float v = GGML_CPU_FP16_TO_FP32(x);
             GGML_UNUSED(v);
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2702,7 +2769,7 @@ static void ggml_compute_forward_silu_back_f32(
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2738,7 +2805,7 @@ static void ggml_compute_forward_silu_back_f16(
                 (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
                 (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
 
-    #ifndef NDEBUG
+#ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
             const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
             const float v = GGML_CPU_FP16_TO_FP32(x);
@@ -2746,7 +2813,7 @@ static void ggml_compute_forward_silu_back_f16(
             assert(!isnan(v));
             assert(!isinf(v));
         }
-    #endif
+#endif // NDEBUG
     }
 }
 
@@ -2829,7 +2896,7 @@ static void ggml_compute_forward_reglu_f32(
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2889,7 +2956,7 @@ static void ggml_compute_forward_reglu_f16(
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -2972,7 +3039,7 @@ static void ggml_compute_forward_geglu_f32(
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -3032,7 +3099,7 @@ static void ggml_compute_forward_geglu_f16(
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -3115,7 +3182,7 @@ static void ggml_compute_forward_swiglu_f32(
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -3175,7 +3242,7 @@ static void ggml_compute_forward_swiglu_f16(
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -3266,7 +3333,7 @@ static void ggml_compute_forward_swiglu_oai_f32(
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -3345,7 +3412,7 @@ static void ggml_compute_forward_geglu_erf_f32(
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -3405,7 +3472,7 @@ static void ggml_compute_forward_geglu_erf_f16(
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -3488,7 +3555,7 @@ static void ggml_compute_forward_geglu_quick_f32(
             assert(!isnan(x));
             assert(!isinf(x));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -3548,7 +3615,7 @@ static void ggml_compute_forward_geglu_quick_f16(
             assert(!isnan(v));
             assert(!isinf(v));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -4270,6 +4337,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4545,6 +4613,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4767,6 +4836,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5239,7 +5309,7 @@ static void ggml_compute_forward_soft_max_f32(
                     //printf("p[%d] = %f\n", i, p[i]);
                     assert(!isnan(wp[i]));
                 }
-#endif
+#endif // NDEBUG
 
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
@@ -5264,7 +5334,7 @@ static void ggml_compute_forward_soft_max_f32(
                     assert(!isnan(dp[i]));
                     assert(!isinf(dp[i]));
                 }
-#endif
+#endif // NDEBUG
             }
         }
     }
@@ -5338,7 +5408,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
             assert(!isnan(dy[i]));
             assert(!isnan(y[i]));
         }
-#endif
+#endif // NDEBUG
         // Jii = yi - yi*yi
         // Jij = -yi*yj
         // J = diag(y)-y.T*y
@@ -5371,7 +5441,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32(
             assert(!isnan(dx[i]));
             assert(!isinf(dx[i]));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
@@ -5491,6 +5561,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5739,28 +5810,33 @@ static void ggml_compute_forward_rope_flt(
 
     const int32_t * pos = (const int32_t *) src1->data;
 
+    int64_t last_i2 = -1;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
         for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!mrope_used) {
-                const int64_t p = pos[i2];
-                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
             for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
-                if (ir++ < ir0) continue;
+                if (ir++ < ir0) continue; // skip rows mapped to other threads
                 if (ir   > ir1) break;
 
+                float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+                if (last_i2 != i2) {
+                    if (!mrope_used) {
+                        const int64_t p = pos[i2];
+                        ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+                    }
+                    else {
+                        const int64_t p_t = pos[i2];
+                        const int64_t p_h = pos[i2 + ne2];
+                        const int64_t p_w = pos[i2 + ne2 * 2];
+                        const int64_t p_e = pos[i2 + ne2 * 3];
+                        ggml_mrope_cache_init(
+                            p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
+                            freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+                    }
+
+                    last_i2 = i2;
+                }
+
                 T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
                 T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1);
 
@@ -6129,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
     const ggml_tensor * src1 = dst->src[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F16);
 
     GGML_TENSOR_BINARY_OP_LOCALS;
@@ -6160,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
     int ofs1 = is_2D ? nb12 : nb11;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
     // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
     {
@@ -6173,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
 
                         // micro kernel
                         ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+                        const float * const src_data_f32 = src1->type == GGML_TYPE_F32
+                            ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
+                            : nullptr; // [IH, IW]
+                        const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
+                            ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
+                            : nullptr; // [IH, IW]
 
                         for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
                             for (int64_t ikw = 0; ikw < KW; ikw++) {
@@ -6183,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
                                 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                                     dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
                                 } else {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                    if (src_data_f32 != nullptr) {
+                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
+                                    } else {
+                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
+                                    }
                                 }
                             }
                         }
@@ -7110,12 +7195,13 @@ void ggml_compute_forward_conv_2d_dw(
     }
 }
 
-// ggml_compute_forward_pool_1d_sk_p0
-
-static void ggml_compute_forward_pool_1d_sk_p0(
+// ggml_compute_forward_pool_1d_ksp
+static void ggml_compute_forward_pool_1d_ksp(
         const ggml_compute_params * params,
         const ggml_op_pool op,
         const int k,
+        const int s,
+        const int p,
         ggml_tensor * dst) {
 
     const ggml_tensor * src = dst->src[0];
@@ -7126,39 +7212,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
         return;
     }
 
-    const char * cdata = (const char *)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-    float * drow = (float *)dst->data;
+    const int64_t IW = src->ne[0];
+    const int64_t OW = dst->ne[0];
 
-    const int64_t rs = dst->ne[0];
+    const int64_t nr = ggml_nrows(src);
 
-    while (cdata < data_end) {
-        const void * srow = (const void *)cdata;
-        int j = 0;
-        for (int64_t i = 0; i < rs; ++i) {
+    for (int64_t ir = 0; ir < nr; ++ir) {
+        const char * srow_bytes =            (const char *) src->data + ir * src->nb[1];
+        float      * drow       = (float *) ((      char *) dst->data + ir * dst->nb[1]);
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            float res = 0;
             switch (op) {
-                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
-                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_AVG: res = 0.0f;     break;
+                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
+
+            int count = 0;
+            const int base = (int) ow * s - p;
+
             for (int ki = 0; ki < k; ++ki) {
-                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
-                switch (op) {
-                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
-                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
-                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
+                const int j = base + ki;
+                if (j < 0 || j >= (int) IW) {
+                    continue;
                 }
-                ++j;
+
+                float v;
+                if (src->type == GGML_TYPE_F32) {
+                    v = ((const float *) srow_bytes)[j];
+                } else {
+                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
+                }
+
+                switch (op) {
+                    case GGML_OP_POOL_AVG: res += v;                break;
+                    case GGML_OP_POOL_MAX: res =  std::max(v, res); break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+
+                ++count;
             }
+
             switch (op) {
-                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
-                case GGML_OP_POOL_MAX:                       break;
+                case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
+                case GGML_OP_POOL_MAX:                                           break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
-        }
 
-        cdata += src->nb[1];
-        drow  += rs;
+            drow[ow] = res;
+        }
     }
 }
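
The generalized 1-D pooling above starts each output window at `ow*s - p`, skips out-of-range taps, and for AVG divides by the number of in-bounds taps rather than by `k`. A short standalone sketch of that window arithmetic with assumed toy sizes:

#include <cstdio>
#include <vector>

int main() {
    // Assumed toy input: IW = 5, kernel k = 3, stride s = 2, padding p = 1.
    const std::vector<float> src = {1, 2, 3, 4, 5};
    const int IW = (int) src.size(), k = 3, s = 2, p = 1;
    const int OW = (IW + 2*p - k)/s + 1; // = 3

    for (int ow = 0; ow < OW; ++ow) {
        float sum = 0.0f;
        int count = 0;
        const int base = ow*s - p; // first input index covered by this window
        for (int ki = 0; ki < k; ++ki) {
            const int j = base + ki;
            if (j < 0 || j >= IW) continue; // padded positions contribute nothing
            sum += src[j];
            ++count;
        }
        std::printf("avg[%d] = %f\n", ow, count > 0 ? sum / count : 0.0f);
    }
    return 0;
}
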
 
@@ -7173,10 +7276,8 @@ void ggml_compute_forward_pool_1d(
     const int k0 = opts[1];
     const int s0 = opts[2];
     const int p0 = opts[3];
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(k0 == s0); // only s = k supported
 
-    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
 }
 
 // ggml_compute_forward_pool_2d
@@ -7194,6 +7295,7 @@ void ggml_compute_forward_pool_2d(
     }
 
     const int32_t * opts = (const int32_t *)dst->op_params;
+
     ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int k1 = opts[2];
@@ -7217,11 +7319,13 @@ void ggml_compute_forward_pool_2d(
     while (cdata < data_end) {
         for (int oy = 0; oy < py; ++oy) {
             float * const drow = dplane + oy * px;
+            float * const out  = drow;
+
             for (int ox = 0; ox < px; ++ox) {
-                float * const out =  drow + ox;
+                float res = 0;
                 switch (op) {
-                    case GGML_OP_POOL_AVG:     *out = 0;        break;
-                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_AVG: res = 0;        break;
+                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                     case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
 
@@ -7229,24 +7333,32 @@ void ggml_compute_forward_pool_2d(
                 const int iy = offset1 + oy * s1;
 
                 for (int ky = 0; ky < k1; ++ky) {
-                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {
+                        continue;
+                    }
+
                     const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
                     for (int kx = 0; kx < k0; ++kx) {
                         int j = ix + kx;
-                        if (j < 0 || j >= src->ne[0]) continue;
+                        if (j < 0 || j >= src->ne[0]) {
+                            continue;
+                        }
+
                         const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
                         switch (op) {
-                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
-                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
+                            case GGML_OP_POOL_AVG: res += srow_j;                break;
+                            case GGML_OP_POOL_MAX: res =  std::max(srow_j, res); break;
                             case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
                         }
                     }
                 }
                 switch (op) {
-                    case GGML_OP_POOL_AVG:           *out /= ka; break;
-                    case GGML_OP_POOL_MAX:                       break;
+                    case GGML_OP_POOL_AVG:           res /= ka; break;
+                    case GGML_OP_POOL_MAX:                      break;
                     case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
+
+                out[ox] = res;
             }
         }
 
@@ -7603,8 +7715,7 @@ static void ggml_compute_forward_pad_f32(
 
     const ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(src0->nb[0] == sizeof(float));
-    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    assert(dst->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8016,12 +8127,14 @@ void ggml_compute_forward_top_k(
     }
 }
 
-// ggml_compute_forward_flash_attn_ext
-
 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
         ggml_tensor * dst,
-        int ir0, int ir1) {
+        int ir0, int ir1,
+        int64_t ic_start, int64_t ic_end,
+        float * partials, int64_t partial_stride) {
+
+    const bool write_partials = (partials != nullptr);
     const ggml_tensor * q     = dst->src[0];
     const ggml_tensor * k     = dst->src[1];
     const ggml_tensor * v     = dst->src[2];
@@ -8098,7 +8211,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 
     int ith = params->ith;
 
-    // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
         const int iq3 = ir/(neq2*neq1);
@@ -8138,7 +8250,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
-        for (int64_t ic = 0; ic < nek1; ++ic) {
+
+        for (int64_t ic = ic_start; ic < ic_end; ++ic) {
             const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
@@ -8211,8 +8324,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
             }
         }
 
-        // sinks
-        if (sinks) {
+        // sinks - apply only on the first kv-chunk
+        if (sinks && ic_start == 0) {
             const float s = ((float *)((char *) sinks->data))[h];
 
             float ms = 1.0f;
@@ -8220,6 +8333,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 
             if (s > M) {
                 ms = expf(M - s);
+                M = s;
                 ggml_vec_scale_f32(DV, VKQ32, ms);
             } else {
                 vs = expf(s - M);
@@ -8228,30 +8342,38 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
             S = S*ms + vs;
         }
 
-        // V /= S
-        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
-        ggml_vec_scale_f32(DV, VKQ32, S_inv);
+        if (write_partials) {
+            // Write M, S, VKQ to partials for later reduction
+            // partials layout: [M, S, VKQ[DV]] per query head
+            float * partial = partials + ir * partial_stride;
+            partial[0] = M;
+            partial[1] = S;
+            memcpy(partial + 2, VKQ32, DV * sizeof(float));
+        } else {
+            // V /= S
+            const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
+            ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
 
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+        }
     }
 }
 
-static void ggml_compute_forward_flash_attn_ext_f16(
+static void ggml_compute_forward_flash_attn_ext_tiled(
         const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
+        ggml_tensor * dst,
+        int ir0, int ir1) {
     const ggml_tensor * q     = dst->src[0];
     const ggml_tensor * k     = dst->src[1];
     const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
 
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
@@ -8286,47 +8408,451 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // parallelize by q rows using ggml_vec_dot_f32
+    GGML_ASSERT(k->type == v->type);
+    const ggml_type kv_type = k->type;
 
-    // total rows in q
-    const int64_t nr = neq1*neq2*neq3;
 
-    // rows per thread
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    int ith = params->ith;
+
+    static constexpr int Q_TILE_SZ  = ggml_fa_tile_config::Q;
+    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
+
+    int ir = ir0;
+    while (ir < ir1) {
+        // q indices for the start of this tile
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        // Number of valid rows in this tile:
+        // - limited by tile size (Q_TILE_SZ)
+        // - limited by chunk boundary (ir1 - ir)
+        // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
+        const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
+        GGML_ASSERT(tile_rows > 0);
+
+        const uint32_t h = iq2; // head index
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float S[Q_TILE_SZ];
+        float M[Q_TILE_SZ];
+
+        for (int i = 0; i < Q_TILE_SZ; ++i) {
+            S[i] = 0.0f;
+            M[i] = -INFINITY;
+        }
+
+        // Per-thread scratch layout:
+        // Q_q:    Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
+        // KQ:     Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
+        // mask:   Q_TILE_SZ * KV_TILE_SZ (mask in float)
+        // VKQ32:  Q_TILE_SZ * DV (FP32 output accumulator)
+        // V32:    KV_TILE_SZ * DV (F32 buffer for V tile)
+        // K_f32:  KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
+        float * base  = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
+
+        void  * Q_q    = base;
+        float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
+        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
+        float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;
+        float * V32    = VKQ32 + Q_TILE_SZ * DV;
+        float * K_f32  = V32 + KV_TILE_SZ * DV;
+
+        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
+        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        {
+            float * Q_f32 = (float *)Q_q;
+            for (int tq = 0; tq < tile_rows; tq++) {
+                const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
+                memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
+            }
+            for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
+                memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
+            }
+        }
+
+        memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
+        memset(V32,   0, KV_TILE_SZ * DV * sizeof(float));
+
+        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
+            const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
+
+            // skip the tile entirely if all the masks are -inf
+            if (mask) {
+                bool can_skip = true;
+                for (int tq = 0; tq < tile_rows; tq++) {
+                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
+                    for (int tk = 0; tk < kv_tile; tk++) {
+                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
+                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
+                            can_skip = false;
+                        }
+                    }
+                    // Pad remaining mask entries with -inf
+                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
+                        mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
+                    }
+                }
+
+                if (can_skip) {
+                    continue;
+                }
+            }
+
+            // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
+            // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
+            for (int tk = 0; tk < kv_tile; tk++) {
+                const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
+                if (kv_type == GGML_TYPE_F16) {
+                    const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
+                    for (int64_t dk = 0; dk < DK; dk++) {
+                        K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
+                    }
+                } else {
+                    const float * k_f32_src = (const float *)k_data;
+                    for (int64_t dk = 0; dk < DK; dk++) {
+                        K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
+                    }
+                }
+            }
+            memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
+            simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
+            ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
+
+            // Set padded KQ entries to -inf so softmax gives them zero weight
+            if (kv_tile < KV_TILE_SZ) {
+                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                    for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
+                        KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
+                    }
+                }
+            }
+
+            if (logit_softcap != 0.0f) {
+                ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
+                ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
+            }
+
+            if (mask) {
+                ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
+            }
+
+            bool skip[Q_TILE_SZ] = {};
+
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                float * kq_row = KQ + tq * KV_TILE_SZ;
+
+                float tile_max;
+                ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
+
+                if (tile_max == -INFINITY) {
+                    skip[tq] = true;
+                    continue;
+                }
+
+                const float Mold = M[tq];
+                const float Mnew = fmaxf(Mold, tile_max);
+
+                if (Mnew > Mold) {
+                    const float ms = expf(Mold - Mnew);
+                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+                    S[tq] *= ms;
+                }
+                M[tq] = Mnew;
+
+                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
+            }
+
+            // V accumulation: VKQ32 += softmax(KQ) * V
+            // Pack V tile to contiguous F32, zero-padded
+            for (int tk = 0; tk < kv_tile; tk++) {
+                const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
+                if (kv_type == GGML_TYPE_F16) {
+                    ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
+                } else {
+                    memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
+                }
+            }
+            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+                if (skip[tq]) {
+                    memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
+                }
+            }
+            simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
+        }
+
+        // sinks (apply only to valid rows in the tile)
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            for (int tq = 0; tq < tile_rows; tq++) {
+                float ms = 1.0f;
+                float vs = 1.0f;
+
+                if (s > M[tq]) {
+                    ms = expf(M[tq] - s);
+                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+                } else {
+                    vs = expf(s - M[tq]);
+                }
+
+                S[tq] = S[tq] * ms + vs;
+            }
+        }
+
+        for (int tq = 0; tq < tile_rows; tq++) {
+            // V /= S
+            const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
+            ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
+
+            // dst indices
+            const int i1 = iq1 + tq;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
+        }
+
+        ir += tile_rows;
+    }
+}
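
The tiled kernel above carves each thread's scratch out of `params->wdata` following the layout comment (Q tile, KQ scores, mask, output accumulator, V tile, K tile, plus a cache-line pad). A small sketch of the resulting per-thread float count, using assumed example sizes rather than the real `ggml_fa_tile_config` values:

#include <cstddef>
#include <cstdio>

int main() {
    // Assumed example sizes; the real ones come from ggml_fa_tile_config and the head dims.
    const int Q_TILE = 4, KV_TILE = 64, DK = 128, DV = 128, CACHE_LINE_F32 = 16;

    // Same ordering as the kernel's scratch layout: Q_q, KQ, mask, VKQ32, V32, K_f32, pad.
    const int per_thread = Q_TILE * DK        // Q_q
                         + Q_TILE * KV_TILE   // KQ
                         + Q_TILE * KV_TILE   // mask
                         + Q_TILE * DV        // VKQ32
                         + KV_TILE * DV       // V32
                         + KV_TILE * DK       // K_f32
                         + CACHE_LINE_F32;    // padding between threads

    std::printf("per-thread scratch: %d floats (%zu bytes)\n",
                per_thread, (size_t) per_thread * sizeof(float));
    return 0;
}
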
+
+// Reduction function: combines partial results across KV chunks
+// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
+static void ggml_flash_attn_ext_reduce_partials(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        const int64_t n_chunks,
+        const int64_t chunk_size) {
+
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+
+    const int64_t DK        = k->ne[0];
+    const int64_t DV        = v->ne[0];
+    const int64_t nek1      = k->ne[1];
+    const int64_t n_q_heads = q->ne[2];
+
     const int ith = params->ith;
     const int nth = params->nth;
 
-    // disable for NUMA
-    const bool disable_chunking = ggml_is_numa();
+    const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
+    float *       thread_wdata     = (float *) params->wdata + ith * wdata_per_thread;
 
-    // 4x chunks per thread
-    int nth_scaled = nth * 4;
-    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
-    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+    const int64_t partials_offset  = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
+    const int64_t partial_size     = 2 + DV;
+    const float * partials_base    = (const float *) params->wdata + partials_offset;
 
-    if (nth == 1 || nchunk < nth || disable_chunking) {
-        nchunk = nth;
+    // Output layout
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const size_t  nb1 = dst->nb[1];
+
+    // Each thread reduces a subset of query heads
+    for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
+        float   M_final   = -INFINITY;
+        float   S_final   = 0.0f;
+        float * VKQ_final = thread_wdata;
+        memset(VKQ_final, 0, DV * sizeof(float));
+
+        // Combine partials from all chunks
+        for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
+            const int64_t ic_start = chunk_idx * chunk_size;
+            if (ic_start >= nek1) continue;
+
+            const float * partial   = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
+            const float   M_chunk   = partial[0];
+            const float   S_chunk   = partial[1];
+            const float * VKQ_chunk = partial + 2;
+
+            if (S_chunk == 0.0f) continue;
+
+            const float M_new     = fmaxf(M_final, M_chunk);
+            const float scale_old = expf(M_final - M_new);
+            const float scale_new = expf(M_chunk - M_new);
+
+            for (int64_t d = 0; d < DV; ++d) {
+                VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
+            }
+            S_final = S_final * scale_old + S_chunk * scale_new;
+            M_final = M_new;
+        }
+
+        // Normalize and write to output
+        if (S_final != 0.0f) {
+            const float S_inv = 1.0f / S_final;
+            ggml_vec_scale_f32(DV, VKQ_final, S_inv);
+        }
+        // iq1=0, iq3=0 for decode
+        memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
     }
+}
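
The reduction above folds per-chunk partials (M = running max, S = exp-sum, VKQ = weighted value sum) together by rescaling each side with `exp(M - M_new)` before adding, the usual streaming-softmax merge. A tiny numeric check of that identity for a single value dimension, illustrative only and independent of the ggml buffers:

#include <cmath>
#include <cstdio>

// Softmax partial state: running max, sum of exp, and exp-weighted value sum.
struct Partial { float M, S, v; };

static Partial merge(Partial a, Partial b) {
    const float M  = std::fmax(a.M, b.M);
    const float sa = std::exp(a.M - M);
    const float sb = std::exp(b.M - M);
    return { M, a.S * sa + b.S * sb, a.v * sa + b.v * sb };
}

int main() {
    // Chunk A holds scores {1,2} with values {10,20}; chunk B holds {3,4} with {30,40}.
    Partial a = { 2.0f, std::exp(-1.0f) + 1.0f, std::exp(-1.0f) * 10.0f + 20.0f };
    Partial b = { 4.0f, std::exp(-1.0f) + 1.0f, std::exp(-1.0f) * 30.0f + 40.0f };
    Partial m = merge(a, b);

    // Reference: one-pass softmax over all four scores.
    const float s[4] = {1, 2, 3, 4}, v[4] = {10, 20, 30, 40};
    float S = 0.0f, out = 0.0f;
    for (int i = 0; i < 4; ++i) {
        S   += std::exp(s[i] - 4.0f);
        out += std::exp(s[i] - 4.0f) * v[i];
    }

    std::printf("merged %f vs reference %f\n", m.v / m.S, out / S); // both ~34.93
    return 0;
}
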
 
-    if (ith == 0) {
-        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
-        ggml_threadpool_chunk_set(params->threadpool, nth);
-    }
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
 
-    ggml_barrier(params->threadpool);
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
 
-    // The number of elements in each chunk
-    const int64_t dr = (nr + nchunk - 1) / nchunk;
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
-    // The first chunk comes from our thread_id, the rest will get auto-assigned.
-    int current_chunk = ith;
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
 
-    while (current_chunk < nchunk) {
-        const int64_t ir0 = dr * current_chunk;
-        const int64_t ir1 = MIN(ir0 + dr, nr);
 
-        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
 
-        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
+    const bool use_ref = params->use_ref;
+
+    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
+    const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
+
+    if (use_split_kv_path) {
+        const int64_t chunk_size = (nek1 + nth - 1) / nth;
+
+        // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
+        const int64_t partial_size  = 2 + DV;
+        float *       partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
+
+        const int64_t ic_start = ith * chunk_size;
+        const int64_t ic_end   = std::min(ic_start + chunk_size, nek1);
+
+        const int64_t partial_stride = nth * partial_size;
+        float *       chunk_partials = partials_base + ith * partial_size;
+
+        if (ic_start < nek1) {
+            for (int64_t q_head = 0; q_head < neq2; q_head++) {
+                ggml_compute_forward_flash_attn_ext_f16_one_chunk(
+                    params, dst, q_head, q_head + 1, ic_start, ic_end,
+                    chunk_partials, partial_stride);
+            }
+        } else {
+            for (int64_t q_head = 0; q_head < neq2; q_head++) {
+                float * q_partials = chunk_partials + q_head * partial_stride;
+                q_partials[0] = -INFINITY;  // M
+                q_partials[1] = 0.0f;       // S
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+        ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
+    } else {
+
+        // total rows in q
+        const int64_t nr = neq1*neq2*neq3;
+
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
+
+        // 4x chunks per thread
+        int nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
+        }
+
+        if (ith == 0) {
+            ggml_threadpool_chunk_set(params->threadpool, nth);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+        static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
+        bool use_tiled = !use_ref &&
+                               (q->type == GGML_TYPE_F32 &&
+                                kv_is_f32_or_f16 &&
+                                k->type == v->type &&
+                                neq1 >= Q_TILE_SZ);
+#ifdef GGML_SIMD
+        use_tiled &= (DV % GGML_F32_EPR == 0);
+#endif
+        int current_chunk = ith;
+
+        while (current_chunk < nchunk) {
+            const int64_t ir0 = dr * current_chunk;
+            const int64_t ir1 = MIN(ir0 + dr, nr);
+
+            if (use_tiled) {
+                ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
+            } else {
+                ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
+            }
+
+            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+        }
     }
 }
 
@@ -9107,7 +9633,7 @@ void ggml_compute_forward_win_unpart(
     }
 }
 
-//gmml_compute_forward_unary
+//ggml_compute_forward_unary
 
 void ggml_compute_forward_unary(
         const ggml_compute_params * params,
@@ -9870,6 +10396,195 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
     }
 }
 
+// ggml_compute_forward_gated_delta_net
+static void ggml_compute_forward_gated_delta_net_one_chunk(
+    const ggml_compute_params * params,
+    ggml_tensor * dst,
+    int64_t ir0,
+    int64_t ir1) {
+
+    ggml_tensor * src_q     = dst->src[0];
+    ggml_tensor * src_k     = dst->src[1];
+    ggml_tensor * src_v     = dst->src[2];
+    ggml_tensor * src_g     = dst->src[3];
+    ggml_tensor * src_beta  = dst->src[4];
+    ggml_tensor * src_state = dst->src[5];
+
+    const int64_t S_v      = src_v->ne[0];
+    const int64_t H        = src_v->ne[1];
+    const int64_t n_tokens = src_v->ne[2];
+    const int64_t n_seqs   = src_v->ne[3];
+
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_is_contiguous(src_g));
+    GGML_ASSERT(ggml_is_contiguous(src_beta));
+    GGML_ASSERT(ggml_is_contiguous(src_state));
+
+    GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
+    GGML_ASSERT(src_beta->ne[0] == 1);
+
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbk, src_k, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbg, src_g, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);
+
+    const bool kda = (neg0 == S_v);
+
+    // scratch layout per thread: [delta(S_v)]
+    const int64_t scratch_per_thread = S_v;
+    const int ith = params->ith;
+
+    float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
+
+    // output layout: [attn_scores | new_states]
+    // attn_scores: S_v * H * n_tokens * n_seqs floats
+    // new_states:  S_v * S_v * H * n_seqs floats
+    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+    float * attn_out_base  = (float *)dst->data;
+    float * state_out_base = (float *)dst->data + attn_score_elems;
+
+    const float * state_in_base = (const float *)src_state->data;
+
+  //const int64_t rq1 = nev1 / neq1;
+  //const int64_t rk1 = nev1 / nek1;
+    const int64_t rq3 = nev3 / neq3;
+    const int64_t rk3 = nev3 / nek3;
+
+    const float scale = 1.0f / sqrtf((float) S_v);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t iv1 = ir % H; // head_index
+        const int64_t iv3 = ir / H; // sequence
+
+        const int64_t iq1 = iv1 % neq1;
+        const int64_t ik1 = iv1 % nek1;
+
+        const int64_t iq3 = iv3 / rq3;
+        const int64_t ik3 = iv3 / rk3;
+
+        float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
+
+        // copy input state into output buffer and operate in-place
+        const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
+        memcpy(s_out, s_in, S_v * S_v * sizeof(float));
+
+        // attn output pointer for first token of this (head, seq)
+        float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
+
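+        // per-token gated delta rule (state stored transposed, see the per-step comments below):
+        //   S[i][:] *= exp(g[i])            (a single scalar gate when src_g has one channel per head)
+        //   delta    = beta * (v - S^T k)
+        //   S       += k * delta^T
+        //   out      = scale * (S^T q)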
+        for (int64_t t = 0; t < n_tokens; t++) {
+            const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
+            const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
+            const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
+
+            const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
+            const float * g_d    =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
+
+            // state is stored transposed: s_out[j*S_v + i] = S[i][j]
+            // so row j of s_out = column j of S (contiguous access)
+
+            if (kda) {
+                // precompute exp(g) into delta scratch (reused below)
+                for (int64_t i = 0; i < S_v; ++i) {
+                    delta[i] = expf(g_d[i]);
+                }
+                // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
+                for (int64_t j = 0; j < S_v; ++j) {
+                    ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
+                }
+            } else {
+                ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
+            }
+
+            // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
+            for (int64_t j = 0; j < S_v; ++j) {
+                float sum = 0.0f;
+                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
+                delta[j] = (v_d[j] - sum) * beta_val;
+            }
+
+            // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
+            for (int64_t j = 0; j < S_v; ++j) {
+                ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
+            }
+
+            // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
+            for (int64_t j = 0; j < S_v; ++j) {
+                float sum = 0.0f;
+                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
+                attn_data[j] = sum * scale;
+            }
+
+            attn_data += S_v * H; // advance to next token
+        }
+    }
+}
+
+static void ggml_compute_forward_gated_delta_net_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    ggml_tensor * V = dst->src[2];
+    int64_t nr = V->ne[1] * V->ne[3];
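+    // one work item per (head, sequence) pair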
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
+
+    int nth = params->nth;
+    int ith = params->ith;
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
+
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
+void ggml_compute_forward_gated_delta_net(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gated_delta_net_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_rwkv_wkv7
 
 static void ggml_compute_forward_rwkv_wkv7_f32(
@@ -10195,7 +10910,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
         }
-#endif
+#endif // NDEBUG
 
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
@@ -10214,7 +10929,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
             assert(!isnan(st[i]));
             assert(!isinf(st[i]));
         }
-#endif
+#endif // NDEBUG
     }
     sums[ith] = sum_thread;
     ggml_barrier(params->threadpool);
@@ -10287,7 +11002,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
         }
-#endif
+#endif // NDEBUG
 
         // soft_max
         float max = -INFINITY;
@@ -10305,7 +11020,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
             assert(!isnan(ds0[i]));
             assert(!isinf(ds0[i]));
         }
-#endif
+#endif // NDEBUG
     }
 }
 
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 0fdfee79..3fa1443a 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c
index 365cb36d..7ebbb9c6 100644
--- a/ggml/src/ggml-cpu/quants.c
+++ b/ggml/src/ggml-cpu/quants.c
@@ -50,6 +50,10 @@ void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
     quantize_row_mxfp4_ref(x, y, k);
 }
 
+void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    quantize_row_nvfp4_ref(x, y, k);
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -216,6 +220,42 @@ void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     *s = sumf;
 }
 
+// NVFP4: super-block of 64 elements = 4 sub-blocks of 16 = 2 q8_0 blocks
+void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_NVFP4 == 0);
+
+    const block_nvfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_NVFP4;
+
+    float sumf = 0;
+
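+    // each super-block: 4 sub-blocks of 16 values, each with its own UE4M3 scale;
+    // the 4-bit elements are E2M1 and reuse the MXFP4 lookup table (kvalues_mxfp4)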
+    for (int ib = 0; ib < nb; ++ib) {
+        for (int s_idx = 0; s_idx < 4; ++s_idx) {
+            const float d = ggml_ue4m3_to_fp32(x[ib].d[s_idx]);
+            const int q8_block = s_idx / 2;
+            const int q8_off   = (s_idx % 2) * QK_NVFP4_SUB;
+            const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
+
+            int sumi_lo = 0, sumi_hi = 0;
+            for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
+                const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
+                sumi_lo += y[2*ib + q8_block].qs[q8_off + j +               0] * kvalues_mxfp4[qv & 0xf];
+                sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >>  4];
+            }
+
+            sumf += dy * d * (sumi_lo + sumi_hi);
+        }
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h
index d83eb1b1..3584aaa4 100644
--- a/ggml/src/ggml-cpu/quants.h
+++ b/ggml/src/ggml-cpu/quants.h
@@ -20,6 +20,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -42,6 +43,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -73,6 +75,7 @@ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
 void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index fbf7ed94..6b76ab3b 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -48,6 +48,90 @@ static inline int nearest_int(float fval) {
 
 extern "C" {
 
+#if defined __riscv_zvfh
+void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
+
+    // scalar
+    const int blck_size_interleave = 1;
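+    // interleave width 1: the 4 source rows are woven element-by-element into each output block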
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+}
+
+void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK_K == 256);
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
+
+    const int blck_size_interleave = 1;
+    float srcv[4][QK_K];
+    float iscale[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+            float max = 0;
+
+            for (int j = 0; j < QK_K; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
+                // Update the maximum value of the corresponding super block
+                if (amax < fabsf(srcv[row_iter][j])) {
+                    amax = fabsf(srcv[row_iter][j]);
+                    max = srcv[row_iter][j];
+                }
+            }
+
+            iscale[row_iter] = amax ? -127.f/max : 0;
+            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
+        }
+
+        for (int j = 0; j < QK_K / 4; j++) {
+            y[i].bsums[j] = 0;
+        }
+        for (int j = 0; j < QK_K * 4; j++) {
+            int src_id = j % 4;
+            int src_offset = j / 4;
+            int index = ((j >> 6) << 2) + (j & 3);
+
+            float x0 = srcv[src_id][src_offset] * iscale[src_id];
+            y[i].qs[j] = nearest_int(x0);
+            y[i].bsums[index] += y[i].qs[j];
+        }
+    }
+}
+#endif // __riscv_zvfh
+
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -124,7 +208,6 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
     }
 }
 
-
 void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
@@ -256,6 +339,416 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR
     ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
 }
 
+#if defined __riscv_zvfh
+template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row);
+}
+
+template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_K_4x1(x, vy, n_per_row);
+}
+#endif // __riscv_zvfh
+
+template <int M, int N>
+static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int                        n,
+                                                 float * GGML_RESTRICT      s,
+                                                 size_t                     bs,
+                                                 const void * GGML_RESTRICT vx,
+                                                 const void * GGML_RESTRICT vy,
+                                                 int                        nr,
+                                                 int                        nc) {
+    constexpr int blocklen          = M;
+    constexpr int ncols_interleaved = N;
+    const int     qk                = QK_K;
+    const int     nb                = n / qk;
+    const int     blocks_per_half   = 64 / blocklen;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
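+    // q6_K values: low 4 bits in ql, high 2 bits in qh; each value dequantizes as ((hi_2 << 4) | lo_4) - 32
+    // and is weighted by the per-16-element int8 scale of its interleaved column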
+    float sumf[8];
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0f;
+        }
+
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
+                const int base_h = base_l + 64;
+
+                const int scale_idx_l = base_l / 16;
+                const int scale_idx_h = base_h / 16;
+
+                const int qh_shift_l = ((base_l % 128) / 32) * 2;
+                const int qh_shift_h = ((base_h % 128) / 32) * 2;
+
+                const int qh_half_l = (base_l / 128) * 32;
+                const int qh_half_h = (base_h / 128) * 32;
+
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
+                    const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
+
+                    int sumi_l = 0;
+                    int sumi_h = 0;
+
+                    for (int i = 0; i < blocklen; i++) {
+                        const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
+                        const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
+                        const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
+
+                        const int qh_idx_l    = qh_half_l + ((base_l + i) % 32);
+                        const int qh_chunk_l  = qh_idx_l / blocklen;
+                        const int qh_pos_l    = qh_idx_l % blocklen;
+                        const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
+                        const int hi_2_l      = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
+
+                        const int qh_idx_h    = qh_half_h + ((base_h + i) % 32);
+                        const int qh_chunk_h  = qh_idx_h / blocklen;
+                        const int qh_pos_h    = qh_idx_h % blocklen;
+                        const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
+                        const int hi_2_h      = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
+
+                        const int q_l = ((hi_2_l << 4) | l_4) - 32;
+                        const int q_h = ((hi_2_h << 4) | hi_4) - 32;
+
+                        const int8_t a_l = a_ptr[l].qs[base_l + i];
+                        const int8_t a_h = a_ptr[l].qs[base_h + i];
+
+                        sumi_l += q_l * a_l;
+                        sumi_h += q_h * a_h;
+                    }
+
+                    sumf[j] +=
+                        (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+        }
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+template <int M, int N>
+static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int                        n,
+                                                 float * GGML_RESTRICT      s,
+                                                 size_t                     bs,
+                                                 const void * GGML_RESTRICT vx,
+                                                 const void * GGML_RESTRICT vy,
+                                                 int                        nr,
+                                                 int                        nc) {
+    constexpr int blocklen          = M;
+    constexpr int ncols_interleaved = N;
+    const int     qk                = QK_K;
+    const int     nb                = n / qk;
+    const int     blocks_per_half   = 64 / blocklen;
+    const int     q8_half_stride    = 512;
+    const int     q8_low_high_step  = 256;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+
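+    // same q6_K unpacking as the gemv variant; the activations are interleaved 4 rows at a time
+    // (block_q8_Kx4), hence the q8_half_stride / q8_low_high_step indexing below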
+    float sumf[4][8];
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0f;
+                }
+            }
+
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
+                    const int base_h = base_l + 64;
+
+                    const int scale_idx_l = base_l / 16;
+                    const int scale_idx_h = base_h / 16;
+
+                    const int qh_shift_l = ((base_l % 128) / 32) * 2;
+                    const int qh_shift_h = ((base_h % 128) / 32) * 2;
+
+                    const int qh_half_l = (base_l / 128) * 32;
+                    const int qh_half_h = (base_h / 128) * 32;
+
+                    const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);
+
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
+                            const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
+
+                            int sumi_l = 0;
+                            int sumi_h = 0;
+
+                            for (int i = 0; i < blocklen; i++) {
+                                const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
+                                const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
+                                const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
+
+                                const int qh_idx_l   = qh_half_l + ((base_l + i) % 32);
+                                const int qh_chunk_l = qh_idx_l / blocklen;
+                                const int qh_pos_l   = qh_idx_l % blocklen;
+                                const int qh_offset_l =
+                                    qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
+                                const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
+
+                                const int qh_idx_h   = qh_half_h + ((base_h + i) % 32);
+                                const int qh_chunk_h = qh_idx_h / blocklen;
+                                const int qh_pos_h   = qh_idx_h % blocklen;
+                                const int qh_offset_h =
+                                    qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
+                                const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
+
+                                const int q_l = ((hi_2_l << 4) | l_4) - 32;
+                                const int q_h = ((hi_2_h << 4) | hi_4) - 32;
+
+                                const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];
+                                const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];
+
+                                sumi_l += q_l * q8_l;
+                                sumi_h += q_h * q8_h;
+                            }
+
+                            sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
+                                          a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
+template <int M, int N>
+static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int                        n,
+                                                 float * GGML_RESTRICT      s,
+                                                 size_t                     bs,
+                                                 const void * GGML_RESTRICT vx,
+                                                 const void * GGML_RESTRICT vy,
+                                                 int                        nr,
+                                                 int                        nc) {
+    constexpr int         blocklen          = M;
+    constexpr int         ncols_interleaved = N;
+    const int             qk                = QK_K;
+    const int             nb                = n / qk;
+    static const uint32_t kmask1            = 0x3f3f3f3f;
+    static const uint32_t kmask2            = 0x0f0f0f0f;
+    static const uint32_t kmask3            = 0x03030303;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float    sumf[ncols_interleaved];
+    float    sum_minf[ncols_interleaved];
+    uint32_t utmp[32];
+    int      sumi1;
+    int      sumi2;
+    int      sumi;
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j]     = 0.0;
+            sum_minf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
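+            // expand the packed 6-bit scales/mins of each sub-block with the usual q5_K kmask shuffle;
+            // each 16-byte group of utmp then holds 8 scale bytes followed by 8 min bytes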
+            for (int sb = 0; sb < 8; sb++) {
+                memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
+                utmp[sb * 4 + 3]      = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                utmp[sb * 4 + 2]      = uaux_0;
+                utmp[sb * 4 + 0] &= kmask1;
+            }
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                constexpr int scale_stride = 32;
+                uint8_t *     scales_0     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
+                uint8_t *     scales_1     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
+
+                const int qh_shift = (k / (32 / blocklen)) * 2;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi  = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
+
+                        const int qh_idx      = (k * blocklen + i) % 32;
+                        const int qh_chunk    = qh_idx / blocklen;
+                        const int qh_pos      = qh_idx % blocklen;
+                        const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
+
+                        const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
+                        const uint8_t h0     = (qh_val >> qh_shift) & 1;
+                        const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;
+
+                        const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
+                        const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
+
+                        const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i;
+
+                        sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
+                        sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
+                        sumi1 = sumi1 * scales_0[j];
+                        sumi2 = sumi2 * scales_1[j];
+                        sumi += sumi1 + sumi2;
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+            for (int sb = 0; sb < 8; sb++) {
+                uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
+                                   GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+}
+
+template <int M, int N>
+static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int                        n,
+                                                 float * GGML_RESTRICT      s,
+                                                 size_t                     bs,
+                                                 const void * GGML_RESTRICT vx,
+                                                 const void * GGML_RESTRICT vy,
+                                                 int                        nr,
+                                                 int                        nc) {
+    constexpr int         blocklen          = M;
+    constexpr int         ncols_interleaved = N;
+    const int             qk                = QK_K;
+    const int             nb                = n / qk;
+    static const uint32_t kmask1            = 0x3f3f3f3f;
+    static const uint32_t kmask2            = 0x0f0f0f0f;
+    static const uint32_t kmask3            = 0x03030303;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float    sumf[4][ncols_interleaved];
+    float    sum_minf[4][ncols_interleaved];
+    uint32_t utmp[32];
+    int      sumi1;
+    int      sumi2;
+    int      sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j]     = 0.0;
+                    sum_minf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int sb = 0; sb < 8; sb++) {
+                    memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
+                    utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
+                    const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
+                    utmp[sb * 4 + 1]      = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
+                    utmp[sb * 4 + 2]      = uaux_0;
+                    utmp[sb * 4 + 0] &= kmask1;
+                }
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    constexpr int scale_stride = 32;
+                    uint8_t *     scales_0     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
+                    uint8_t *     scales_1     = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
+
+                    const int qh_shift = (k / (32 / blocklen)) * 2;
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi1 = 0;
+                            sumi2 = 0;
+                            sumi  = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
+
+                                const int qh_idx   = (k * blocklen + i) % 32;
+                                const int qh_chunk = qh_idx / blocklen;
+                                const int qh_pos   = qh_idx % blocklen;
+                                const int b_qh_offset =
+                                    qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
+
+                                const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
+                                const uint8_t h0     = (qh_val >> qh_shift) & 1;
+                                const uint8_t h1     = (qh_val >> (qh_shift + 1)) & 1;
+
+                                const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
+                                const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
+
+                                const int q8_offset = (k / (32 / blocklen)) * 256 +
+                                                      (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i;
+
+                                sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
+                                sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
+                                sumi1 = sumi1 * scales_0[j];
+                                sumi2 = sumi2 * scales_1[j];
+                                sumi += sumi1 + sumi2;
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+                for (int sb = 0; sb < 8; sb++) {
+                    uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
+                    for (int m = 0; m < 4; m++) {
+                        const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
+                                              GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+                }
+            }
+        }
+    }
+}
+
 extern "C" {
 
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -474,15 +967,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     assert (n % qk == 0);
     assert (nc % ncols_interleaved == 0);
 
-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
     UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
 
     float sumf[8];
     float sum_minf[8];
@@ -616,6 +1102,23 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -692,6 +1195,82 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
+void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
 void ggml_gemv_q8_0_4x4_q8_0_generic(int                        n,
                                      float * GGML_RESTRICT      s,
                                      size_t                     bs,
@@ -786,6 +1365,294 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int                        n,
     }
 }
 
+#if defined __riscv_zvfh
+void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[16];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
+void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+    UNUSED(bs);
+    UNUSED(nr);
+    UNUSED(blocklen);
+    float sumf[16];
+    float sum_minf[16];
+    uint8_t scales[128];
+    uint8_t mins[128];
+    int sumi1;
+    int sumi2;
+    int sumi;
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0f;
+            sum_minf[j] = 0.0f;
+        }
+        for (int l = 0; l < nb; l++) {
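+            // expand the 6-bit scales and mins: the low 4 bits live in scales[0..127] (scale in the low
+            // nibble, min in the high nibble), the extra high 2 bits in scales[128..191]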
+            for (int i = 0; i < 128; i++) {
+                scales[i] = b_ptr[l].scales[i] & 0x0F;
+                mins[i] = b_ptr[l].scales[i] >> 4;
+            }
+            for (int i = 0; i < 64; i++) {
+                scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4;
+                mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2;
+                scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30);
+                mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2;
+            }
+            for (int sb = 0; sb < 8; sb++) {
+                uint8_t *min = &mins[sb * 16];
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
+                }
+            }
+            for (int sb = 0; sb < 8; sb += 2) {
+                uint8_t *scales_0 = &scales[sb * 16];
+                uint8_t *scales_1 = &scales[(sb + 1) * 16];
+                for (int i = 0; i < QK4_0; i++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi1 = 0;
+                        sumi2 = 0;
+                        sumi = 0;
+                        const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF);
+                        const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4);
+                        sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]);
+                        sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]);
+                        sumi1 = sumi1 * scales_0[j];
+                        sumi2 = sumi2 * scales_1[j];
+                        sumi += sumi1 + sumi2;
+                        sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                    }
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+}
+
+void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[16];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
+void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen          = 1;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[16];
+    int   sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / blocklen); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                        sumi += v0 * a_ptr[l].qs[k * blocklen + i];
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    assert(n % QK_K == 0);
+    assert(nr == 1);
+    assert(nc % 16 == 0);
+
+    UNUSED(bs);
+
+    const int nb = n / QK_K;
+    const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx;
+    const block_q8_K    * y = (const block_q8_K *)vy;
+
+    // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...)
+    const int sb_perm[16] = {
+        0, 4, 1, 5, 2, 6, 3, 7,  // 0-7
+        8, 12, 9, 13, 10, 14, 11, 15 // 8-15
+    };
+
+    for (int col_tile = 0; col_tile < nc; col_tile += 16) {
+        const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb;
+        const block_q8_K    * y_ptr = y;
+
+        float sumf[16] = {0};
+
+        // Loop over K-blocks
+        for (int k_block = 0; k_block < nb; ++k_block) {
+            int32_t isum[16]  = {0};
+            int32_t summs[16] = {0};
+
+            const uint8_t * qs_rhs = x_ptr[k_block].qs;
+            const uint8_t * sc_rhs = x_ptr[k_block].scales;
+            const int8_t  * qs_lhs = y_ptr[k_block].qs;
+            const int16_t * bs_lhs = y_ptr[k_block].bsums;
+
+            // Iterate over sub-blocks 0..15
+            for (int sb = 0; sb < 16; ++sb) {
+                // Correction Term
+                int16_t bsum = bs_lhs[sb];
+                int scale_offset = sb_perm[sb] * 16;
+
+                for (int col = 0; col < 16; ++col) {
+                    uint8_t sc_val = sc_rhs[scale_offset + col];
+                    summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits
+                }
+
+                // Main Dot Product
+                // Calculate base offsets for Q2 unpacking based on SB
+                int byte_base;
+                if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16;
+                else        byte_base = (sb % 2 == 0) ? 32 : 48;
+
+                int shift = ((sb / 2) % 4) * 2;
+
+                for (int col = 0; col < 16; ++col) {
+                    uint8_t sc_val = sc_rhs[scale_offset + col];
+                    int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits
+
+                    // Process 16 elements (l=0..15)
+                    for (int l = 0; l < 16; ++l) {
+                        // Q2: Interleaved by column. Byte `l` contains 4 k-values.
+                        int qs_idx = (byte_base + l) * 16 + col;
+                        uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3;
+
+                        // Q8: Linear access
+                        int k = sb * 16 + l;
+                        int8_t q8_val = qs_lhs[k];
+
+                        isum[col] += q8_val * q2_val * d_sb;
+                    }
+                }
+            }
+
+            // Finalize K-Block
+            for (int col = 0; col < 16; ++col) {
+                float d_lhs = y_ptr[k_block].d;
+                float d_rhs = GGML_CPU_FP16_TO_FP32(x_ptr[k_block].d[col]);
+                float dm_rhs = GGML_CPU_FP16_TO_FP32(x_ptr[k_block].dmin[col]);
+
+                float d_all = d_lhs * d_rhs;
+                float d_min = d_lhs * dm_rhs;
+
+                sumf[col] += (isum[col] * d_all) - (summs[col] * d_min);
+            }
+        }
+
+        for (int col = 0; col < 16; ++col) {
+            s[col_tile + col] = sumf[col];
+        }
+    }
+}
+#endif // __riscv_zvfh
+
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -1046,15 +1913,7 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     assert (nr % 4 == 0);
     assert (nc % ncols_interleaved == 0);
 
-    UNUSED(s);
     UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
 
     float sumf[4][8];
     float sum_minf[4][8];
@@ -1212,6 +2071,21 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
+}
 
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
@@ -1313,6 +2187,94 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                            }
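+                            // NOTE: assumes kvalues_mxfp4 stores the FP4 values doubled, which is why the shared E8M0 block exponent is converted with the *_HALF variant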
+                            sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
+void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][8];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
 void ggml_gemm_q8_0_4x4_q8_0_generic(int                        n,
                                      float * GGML_RESTRICT      s,
                                      size_t                     bs,
@@ -1365,6 +2327,8 @@ void ggml_gemm_q8_0_4x4_q8_0_generic(int                        n,
     }
 }
 
+
+
 void ggml_gemm_q8_0_4x8_q8_0_generic(int                        n,
                                      float * GGML_RESTRICT      s,
                                      size_t                     bs,
@@ -1417,6 +2381,342 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int                        n,
     }
 }
 
+#if defined __riscv_zvfh
+void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    float sumf[4][16];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
+void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    float sumf[4][16];
+    float sum_minf[4][16];
+    uint8_t scales[128];
+    uint8_t mins[128];
+    int sumi1;
+    int sumi2;
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                    sum_minf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
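+                // undo the packing from make_block_q4_Kx16: each byte of scales[0..127] holds a 4-bit scale (low nibble) and min (high nibble), scales[128..191] hold the remaining 2 high bits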
+                for (int i = 0; i < 128; i++) {
+                    scales[i] = b_ptr[l].scales[i] & 0x0F;
+                    mins[i] = b_ptr[l].scales[i] >> 4;
+                }
+                for (int i = 0; i < 64; i++) {
+                    scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4;
+                    mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2;
+                    scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30);
+                    mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2;
+                }
+
+                for (int sb = 0; sb < 8; sb++) {
+                    uint8_t *min = &mins[sb * 16];
+                    for (int m = 0; m < 4; m++) {
+                        const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4];
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
+                        }
+                    }
+                }
+
+                for (int sb = 0; sb < 8; sb += 2) {
+                    uint8_t *scales_0 = &scales[sb * 16];
+                    uint8_t *scales_1 = &scales[(sb + 1) * 16];
+
+                    for (int i = 0; i < QK4_0; i++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi1 = 0;
+                                sumi2 = 0;
+                                sumi = 0;
+
+                                const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF);
+                                const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4);
+                                sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]);
+                                sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]);
+                                sumi1 = sumi1 * scales_0[j];
+                                sumi2 = sumi2 * scales_1[j];
+                                sumi += sumi1 + sumi2;
+
+                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
+                            }
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+                }
+            }
+        }
+    }
+}
+
+void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][16];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4]));
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
+void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk                = QK8_0;
+    const int nb                = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen          = 1;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][16];
+    int   sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                }
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / blocklen); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+                                sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
+                            }
+                            sumf[m][j] +=
+                                sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
+void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    assert(n % QK_K == 0);
+    assert(nr % 4 == 0);
+    assert(nc % 16 == 0);
+    const int nb = n / QK_K;
+    const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx;
+    const block_q8_Kx4  * y = (const block_q8_Kx4 *)vy;
+
+    const int sb_perm[16] = {
+        0, 4, 1, 5, 2, 6, 3, 7,
+        8, 12, 9, 13, 10, 14, 11, 15
+    };
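+    // maps a sub-block index to the slot where make_block_q2_Kx16 packed its 16 scale bytes (even/odd, low/high grouping)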
+
+    // Iterate Rows in tiles of 4
+    for (int row_tile = 0; row_tile < nr; row_tile += 4) {
+        // Iterate Columns in tiles of 16
+        for (int col_tile = 0; col_tile < nc; col_tile += 16) {
+
+            const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb;
+            const block_q8_Kx4  * y_ptr = y + (row_tile / 4) * nb;
+
+            float sumf[4][16];
+            memset(sumf, 0, sizeof(sumf));
+
+            for (int k_block = 0; k_block < nb; ++k_block) {
+                int32_t isum[4][16];
+                int32_t summs[4][16];
+                memset(isum, 0, sizeof(isum));
+                memset(summs, 0, sizeof(summs));
+
+                const uint8_t * qs_rhs = x_ptr[k_block].qs;
+                const uint8_t * sc_rhs = x_ptr[k_block].scales;
+                const int8_t  * qs_lhs = y_ptr[k_block].qs;
+                const int16_t * bs_lhs = y_ptr[k_block].bsums;
+
+                for (int sb = 0; sb < 16; ++sb) {
+                    int scale_offset = sb_perm[sb] * 16;
+
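+                    // each source byte packs four 2-bit values; byte_base picks the 16-byte chunk of the column that carries this sub-block and shift selects its 2-bit lane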
+                    int byte_base;
+                    if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16;
+                    else        byte_base = (sb % 2 == 0) ? 32 : 48;
+                    int shift = ((sb / 2) % 4) * 2;
+
+                    for (int col = 0; col < 16; ++col) {
+                        uint8_t sc_val = sc_rhs[scale_offset + col];
+                        int32_t d_sb = sc_val & 0xF;
+                        int32_t m_sb = sc_val >> 4;
+
+                        // Correction Term
+                        for (int r = 0; r < 4; ++r) {
+                            int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4);
+                            summs[r][col] += bs_lhs[bsum_idx] * m_sb;
+                        }
+
+                        // Main Dot Product
+                        for (int l = 0; l < 16; ++l) {
+                            int qs_idx = (byte_base + l) * 16 + col;
+                            uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3;
+
+                            // Calculate Q8 index for this specific k and row
+                            int k = sb * 16 + l;
+                            int q8_idx = (k / 4) * 16 + (k % 4);
+
+                            for (int r = 0; r < 4; ++r) {
+                                // Add r*4 to jump to the correct row within the 4x4 chunk
+                                int8_t q8_val = qs_lhs[q8_idx + r * 4];
+                                isum[r][col] += q8_val * q2_val * d_sb;
+                            }
+                        }
+                    }
+                }
+
+                // Finalize K-Block
+                for (int col = 0; col < 16; ++col) {
+                    float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]);
+                    float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]);
+
+                    for (int r = 0; r < 4; ++r) {
+                        float d_lhs = y_ptr[k_block].d[r];
+                        float d_all = d_lhs * d_rhs;
+                        float d_min = d_lhs * dm_rhs;
+                        sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min);
+                    }
+                }
+            }
+
+            for (int r = 0; r < 4; ++r) {
+                for (int col = 0; col < 16; ++col) {
+                    s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col];
+                }
+            }
+        }
+    }
+}
+#endif
+
 } // extern "C"
 
 static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
@@ -1505,6 +2805,31 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
     return out;
 }
 
+static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) {
+    block_q4_0x16 out;
+
+    for (int i = 0; i < 16; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_0 * 8 / blck_size_interleave;
+
+    if (blck_size_interleave == 1) {
+        const uint8_t xor_mask = 0x88;
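+        // flipping bit 3 of each nibble re-biases the stored unsigned [0,15] values to signed two's-complement nibbles, so the 16x1 kernels can sign-extend with shifts instead of subtracting 8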
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 16;
+            int src_offset = i / 16;
+            int dst_offset = i;
+
+            out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask;
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
 static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
     block_q4_Kx8 out;
     //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure
@@ -1524,9 +2849,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
         int src_offset = (i / 8) * blck_size_interleave;
         int dst_offset = i * blck_size_interleave;
 
+        // buffer large enough for the max interleave block size (8 bytes)
         uint64_t elems;
-        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
-        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+        memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
+        memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
     }
 
     // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
@@ -1581,6 +2907,58 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
     return out;
 }
 
+static block_q4_Kx16 make_block_q4_Kx16(block_q4_K * in, unsigned int blck_size_interleave) {
+    block_q4_Kx16 out;
+    //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure
+    for (int i = 0; i < 16; i++) {
+        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
+    }
+
+    for (int i = 0; i < 16; i++) {
+        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
+    }
+
+    const int end = QK_K * 8 / blck_size_interleave;
+
+    if (blck_size_interleave == 1) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 16;
+            int src_offset = i / 16;
+            int dst_offset = i;
+
+            out.qs[dst_offset] = in[src_id].qs[src_offset];
+        }
+
+        // RVV repacking.
+        //
+        // Extract scales and mins for all 8 sub-blocks of each Q4_K block.
+        uint8_t s[128], m[128];
+        for (int i = 0; i < 4; i++) {
+            for (int j = 0; j < 16; j++) {
+                s[i * 16 + j] = in[j].scales[i] & 63;
+                m[i * 16 + j] = in[j].scales[i + 4] & 63;
+            }
+        }
+        for (int i = 0; i < 4; i++) {
+            for (int j = 0; j < 16; j++) {
+                s[64 + i * 16 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);
+                m[64 + i * 16 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);
+            }
+        }
+
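+        // pack the low 4 bits of each scale|min pair into scales[0..127] and the top 2 bits of four values per byte into scales[128..191]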
+        for (int i = 0; i < 128; i++) {
+            out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4);
+        }
+        for (int i = 0; i < 64; i++) {
+            out.scales[128 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[64 + i] & 48) | ((m[64 + i] & 48) << 2);
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
 static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) {
     block_q2_Kx8 out;
 
@@ -1612,8 +2990,7 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
     // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure
     // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures
 
-    for(int i = 0; i < 128; i++){
-
+    for (int i = 0; i < 128; i++) {
         // Index for selecting which q2k super block
         int src1 = (i % 16) / 2;
         // Index for selecting scale
@@ -1622,7 +2999,199 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
         out.scales[i] = in[src1].scales[src2];
     }
     return out;
+}
 
+static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
+    block_q5_Kx8 out;
+    //Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
+    }
+
+    for (int i = 0; i < 8; i++) {
+        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
+    }
+
+    const int end = QK_K * 4 / blck_size_interleave;
+
+    // Interleave Q5_K quants by taking blck_size_interleave bytes at a time
+    for (int i = 0; i < end; ++i) {
+        int src_id     = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
+    }
+
+    // Repeat for high bits with the same chunk size, since
+    // the high bits are interleaved in Q5_K and the index is
+    // qh_idx = (qs_idx % 32);
+    // qh_val = qh[qh_idx] >> (qs_idx / 32);
+    for (int i = 0; i < end / 4; ++i) {
+        int src_id     = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave);
+    }
+
+    // The below logic is copied over from Q4_K
+    // The point is to unpack all the scales and mins for each sub-block every time we load 12 bytes.
+    // The Q5_K structure has 8 scales and 8 mins packed in 12 bytes (6 bits for each value)
+    // The output Q5_Kx8 structure has 96 scale bytes
+    // Every 12 bytes are packed such that they contain the scales and mins of the corresponding sub-blocks from the Q5_K structures
+    // For eg - the first 12 bytes contain the 8 scales and 8 mins of the first sub-block from the different Q5_K structures
+    uint8_t s[8], m[8];
+
+    for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 8; j++) {
+            s[j] = in[j].scales[i] & 63;
+            m[j] = in[j].scales[i + 4] & 63;
+        }
+
+        out.scales[i * 12]      = (s[0] & 63) + ((s[4] & 48) << 2);
+        out.scales[i * 12 + 1]  = (s[1] & 63) + ((s[5] & 48) << 2);
+        out.scales[i * 12 + 2]  = (s[2] & 63) + ((s[6] & 48) << 2);
+        out.scales[i * 12 + 3]  = (s[3] & 63) + ((s[7] & 48) << 2);
+        out.scales[i * 12 + 4]  = (m[0] & 63) + ((m[4] & 48) << 2);
+        out.scales[i * 12 + 5]  = (m[1] & 63) + ((m[5] & 48) << 2);
+        out.scales[i * 12 + 6]  = (m[2] & 63) + ((m[6] & 48) << 2);
+        out.scales[i * 12 + 7]  = (m[3] & 63) + ((m[7] & 48) << 2);
+        out.scales[i * 12 + 8]  = (s[4] & 15) + ((m[4] & 15) << 4);
+        out.scales[i * 12 + 9]  = (s[5] & 15) + ((m[5] & 15) << 4);
+        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
+        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
+    }
+
+    for (int i = 0; i < 4; i++) {
+        for (int j = 0; j < 8; j++) {
+            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
+            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
+        }
+
+        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
+        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
+        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
+        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
+        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
+        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
+        out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
+        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
+        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
+        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
+        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
+        out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
+    }
+
+    return out;
+}
+
+static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) {
+    block_q6_Kx8  out;
+    constexpr int n_blocks = 8;  // Kx8
+    for (int i = 0; i < n_blocks; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end_ls = QK_K * 4 / blck_size_interleave;
+    // Interleave Q6_K quants by taking blck_size_interleave bytes at a time
+    for (int i = 0; i < end_ls; ++i) {
+        int src_id     = i % n_blocks;
+        int src_offset = (i / n_blocks) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elem_ls;
+        memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);
+        memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);
+    }
+
+    // Interleave high bits using same chunk size as low bits
+    const int end_hs = end_ls / 2;
+    for (int i = 0; i < end_hs; ++i) {
+        int src_id     = i % n_blocks;
+        int src_offset = (i / n_blocks) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elem_hs;
+        memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);
+        memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);
+    }
+
+    // The below logic is designed so as to unpack and rearrange scales in Q6_K
+    // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants
+    // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales
+    // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7]  (bl = block)
+    constexpr int n_scales = QK_K / 16;
+
+    for (int i = 0; i < n_blocks; i++) {
+        for (int j = 0; j < n_scales; j++) {
+            out.scales[j * n_blocks + i] = in[i].scales[j];
+        }
+    }
+
+    return out;
+}
+
+static block_q2_Kx16 make_block_q2_Kx16(const block_q2_K * in, unsigned int blck_size_interleave) {
+    block_q2_Kx16 out;
+    constexpr int N_COLS = 16;
+
+    // 1. Copy Super-Scales (d) and Super-Mins (dmin)
+    for (int i = 0; i < N_COLS; i++) {
+        out.d[i]    = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
+        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
+    }
+
+    // 2. Interleave Q2_K Data
+    const int bytes_per_col = 64;
+    const int total_bytes = N_COLS * bytes_per_col;
+    const int end = total_bytes / blck_size_interleave;
+
+    for (int i = 0; i < end; ++i) {
+        int src_col_id = i % N_COLS;
+        int src_offset = (i / N_COLS) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+        memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], blck_size_interleave);
+    }
+
+    // 3. Repack Scales into the Optimized "Sequential-Parallel" Layout
+    int out_idx = 0;
+
+    // Arrays define the sub-block order for each group
+    const int even_low_sbs[]  = {0, 2, 4, 6};
+    const int odd_low_sbs[]   = {1, 3, 5, 7};
+    const int even_high_sbs[] = {8, 10, 12, 14};
+    const int odd_high_sbs[]  = {9, 11, 13, 15};
+
+    // Pack Group 1: Even-Low
+    for (int sb : even_low_sbs) {
+        for (int col = 0; col < N_COLS; col++) {
+            out.scales[out_idx++] = in[col].scales[sb];
+        }
+    }
+
+    // Pack Group 2: Odd-Low
+    for (int sb : odd_low_sbs) {
+        for (int col = 0; col < N_COLS; col++) {
+            out.scales[out_idx++] = in[col].scales[sb];
+        }
+    }
+
+    // Pack Group 3: Even-High
+    for (int sb : even_high_sbs) {
+        for (int col = 0; col < N_COLS; col++) {
+            out.scales[out_idx++] = in[col].scales[sb];
+        }
+    }
+
+    // Pack Group 4: Odd-High
+    for (int sb : odd_high_sbs) {
+        for (int col = 0; col < N_COLS; col++) {
+            out.scales[out_idx++] = in[col].scales[sb];
+        }
+    }
+
+    return out;
 }
 
 static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
@@ -1687,6 +3256,36 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_UNUSED(data_size);
 }
 
+static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
+    constexpr int nrows_interleaved = 16;
+
+    block_q4_Kx16 * dst = (block_q4_Kx16*)t->data;
+    const block_q4_K * src = (const block_q4_K*) data;
+    block_q4_K dst_tmp[16];
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_Kx16(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q2_K);
     GGML_ASSERT(interleave_block == 8);
@@ -1706,7 +3305,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
 
     for (int b = 0; b < nrow; b += nrows_interleaved) {
         for (int64_t x = 0; x < nblocks; x++) {
-            for (int i  = 0; i < nrows_interleaved; i++ ) {
+            for (int i = 0; i < nrows_interleaved; i++) {
                 dst_tmp[i] = src[x + i * nblocks];
             }
             *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
@@ -1718,6 +3317,132 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
     GGML_UNUSED(data_size);
 }
 
+static int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q2_K);
+    constexpr int nrows_interleaved = 16;
+
+    block_q2_Kx16 * dst = (block_q2_Kx16*)t->data;
+    const block_q2_K * src = (const block_q2_K*) data;
+
+    block_q2_K dst_tmp[nrows_interleaved];
+
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            // This loop gathers 16 separate blocks (one from each column)
+            // that correspond to the same K-dimension chunk.
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+
+            *dst++ = make_block_q2_Kx16(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    constexpr int nrows_interleaved = 16;
+
+    block_q4_0x16 * dst = (block_q4_0x16*)t->data;
+    const block_q4_0 * src = (const block_q4_0*) data;
+    block_q4_0 dst_tmp[16];
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x16(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor *       t,
+                                    int                        interleave_block,
+                                    const void * GGML_RESTRICT data,
+                                    size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
+
+    block_q5_Kx8 *     dst = (block_q5_Kx8 *) t->data;
+    const block_q5_K * src = (const block_q5_K *) data;
+    block_q5_K         dst_tmp[8];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+}
+
+static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
+
+    block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
+    const block_q6_K * src = (const block_q6_K *) data;
+    block_q6_K dst_tmp[8];
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK_K;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+}
+
 static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
     GGML_ASSERT(interleave_block == 8);
@@ -1781,6 +3506,60 @@ static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor *       t,
     return 0;
 }
 
+static block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_interleave) {
+    block_q8_0x16 out;
+
+    for (int i = 0; i < 16; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK8_0 * 16 / blck_size_interleave;
+
+    if (blck_size_interleave == 1) {
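+        // interleave of 1: byte j of source row c lands at qs[j * 16 + c], i.e. a plain 16-way transpose of the quant bytes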
+        for (int i = 0; i < end; ++i) {
+            int src_id     = i % 16;
+            int src_offset = i / 16;
+            int dst_offset = i;
+            out.qs[dst_offset] = in[src_id].qs[src_offset];
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_q8_0_to_q8_0_16_bl(struct ggml_tensor *       t,
+                                    int                        interleave_block,
+                                    const void * GGML_RESTRICT data,
+                                    size_t                     data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q8_0);
+    constexpr int nrows_interleaved = 16;
+
+    block_q8_0x16 *     dst = (block_q8_0x16 *) t->data;
+    const block_q8_0 * src = (const block_q8_0 *) data;
+    block_q8_0         dst_tmp[16];
+    int                nrow    = ggml_nrows(t);
+    int                nblocks = t->ne[0] / QK8_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q8_0x16(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+}
+
 static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
     block_iq4_nlx4 out;
 
@@ -1906,6 +3685,177 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b
     GGML_UNUSED(data_size);
 }
 
+static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx16 out;
+
+    for (int i = 0; i < 16; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 8 / blck_size_interleave;
+
+    if (blck_size_interleave == 1) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 16;
+            int src_offset = i / 16;
+            int dst_offset = i;
+
+            out.qs[dst_offset] = in[src_id].qs[src_offset];
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    GGML_ASSERT(interleave_block == 1);
+
+    const block_iq4_nl    * src = (const block_iq4_nl   *)data;
+          block_iq4_nlx16 * dst = (      block_iq4_nlx16 *)t->data;
+
+    block_iq4_nl dst_tmp[16];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 16;
+    int nblocks = t->ne[0] / QK4_NL;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (t->ne[1] % nrows_interleaved != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx16(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) {
+    block_mxfp4x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.e[i] = in[i].e;
+    }
+
+    const int end = QK_MXFP4 * 2 / blck_size_interleave;
+
+    if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
+    GGML_ASSERT(interleave_block == 4);
+
+    const block_mxfp4   * src = (const block_mxfp4   *)data;
+          block_mxfp4x4 * dst = (      block_mxfp4x4 *)t->data;
+
+    block_mxfp4 dst_tmp[4];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK_MXFP4;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_mxfp4x4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) {
+    block_mxfp4x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.e[i] = in[i].e;
+    }
+
+    const int end = QK_MXFP4 * 4 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 8;
+            int src_offset = (i / 8) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
+    GGML_ASSERT(interleave_block == 8);
+
+    const block_mxfp4   * src = (const block_mxfp4   *)data;
+          block_mxfp4x8 * dst = (      block_mxfp4x8 *)t->data;
+
+    block_mxfp4 dst_tmp[8];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK_MXFP4;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
+
+    if (t->ne[1] % nrows_interleaved != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_mxfp4x8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 namespace ggml::cpu::repack {
 // repack
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
@@ -1936,6 +3886,22 @@ template <> int repack(struct ggml_tensor * t, const void * da
     return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_q5_Kx8, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q5_Kx8, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_q6_Kx8, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q6_Kx8, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
+}
+
 template <> int repack<block_iq4_nlx4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
 }
@@ -1949,6 +3915,14 @@ template <> int repack(struct ggml_tensor * t, const void *
     return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_mxfp4x4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_mxfp4x8, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size);
+}
+
 template <> int repack<block_q8_0x4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
 }
@@ -1957,6 +3931,28 @@ template <> int repack(struct ggml_tensor * t, const void * da
     return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
 }
 
+#if defined __riscv_zvfh
+template <> int repack<block_q4_0x16, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size);
+}
+
+template <> int repack<block_q4_Kx16, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size);
+}
+
+template <> int repack<block_iq4_nlx16, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size);
+}
+
+template <> int repack<block_q8_0x16, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size);
+}
+
+template <> int repack<block_q2_Kx16, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q2_K_to_q2_K_16_bl(t, 1, data, data_size);
+}
+#endif
+
 // gemv
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemv(int, float *, size_t, const void *, const void *, int, int);
@@ -1973,6 +3969,17 @@ template <> void gemv(int n, float * s, size_t
     ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <>
+void gemv<block_q2_Kx8, 8, 8, GGML_TYPE_Q8_K>(int          n,
+                                              float *      s,
+                                              size_t       bs,
+                                              const void * vx,
+                                              const void * vy,
+                                              int          nr,
+                                              int          nc) {
+    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q4_Kx8, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
 }
@@ -1981,8 +3988,20 @@ template <> void gemv(int n, float * s, size_t
     ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemv<block_q2_Kx8, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+template <> void gemv<block_q5_Kx8, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q5_Kx8, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q6_Kx8, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q6_Kx8, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
 template <> void gemv<block_iq4_nlx4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
@@ -1993,6 +4012,14 @@ template <> void gemv(int n, float * s, size
     ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_mxfp4x4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_mxfp4x8, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q8_0x4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -2001,6 +4028,28 @@ template <> void gemv(int n, float * s, size_t
     ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+#if defined __riscv_zvfh
+template <> void gemv<block_q4_0x16, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_Kx16, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_iq4_nlx16, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q8_0x16, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q2_Kx16, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+#endif
+
 // gemm
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 void gemm(int, float *, size_t, const void *, const void *, int, int);
@@ -2013,20 +4062,43 @@ template <> void gemm(int n, float * s, size_t
     ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_Kx8, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+template <>
+void gemm<block_q4_0x8, 8, 8, GGML_TYPE_Q8_0>(int          n,
+                                              float *      s,
+                                              size_t       bs,
+                                              const void * vx,
+                                              const void * vy,
+                                              int          nr,
+                                              int          nc) {
+    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q4_0x8, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+template <> void gemm<block_q2_Kx8, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_Kx8, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
 template <> void gemm<block_q4_Kx8, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
-template <> void gemm<block_q2_Kx8, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+template <> void gemm<block_q5_Kx8, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q5_Kx8, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q6_Kx8, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q6_Kx8, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
 template <> void gemm<block_iq4_nlx4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
@@ -2037,6 +4109,14 @@ template <> void gemm(int n, float * s, size
     ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_mxfp4x4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_mxfp4x8, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm<block_q8_0x4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -2045,6 +4125,28 @@ template <> void gemm(int n, float * s, size_t
     ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+#if defined __riscv_zvfh
+template <> void gemm<block_q4_0x16, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_Kx16, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_iq4_nlx16, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q8_0x16, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q2_Kx16, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+#endif
+
 class tensor_traits_base : public ggml::cpu::tensor_traits {
   public:
     virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@@ -2063,7 +4165,7 @@ template src[1]));
-                    size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
+                    size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
 
                     const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
                     const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
@@ -2328,7 +4430,7 @@ template wdata;
         auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
 
-        // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
+        // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
         auto * matrix_row_counts = (int64_t *) (wdata_src1_end);                                        // [n_as]
         struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
 
@@ -2393,20 +4495,19 @@ template (ne00,
-                        (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
-                        src0_cur + src0_cur_start * nb01,
-                        src1_col, 1, src0_cur_end - src0_cur_start);
+                gemv(
+                    ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
+                    src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
             }
         }
 #undef MMID_MATRIX_ROW
@@ -2422,7 +4523,6 @@ template  q4_0_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0;
@@ -2432,6 +4532,14 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K;
     static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K;
 
+    // instance for Q5_K
+    static const ggml::cpu::repack::tensor_traits q5_K_8x4_q8_K;
+    static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K;
+
+    // instance for Q6_K
+    static const ggml::cpu::repack::tensor_traits q6_K_8x4_q8_K;
+    static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K;
+
     // instance for Q2
     static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K;
 
@@ -2439,13 +4547,28 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0;
 
+    // instance for MXFP4
+    static const ggml::cpu::repack::tensor_traits mxfp4_4x4_q8_0;
+    static const ggml::cpu::repack::tensor_traits mxfp4_8x8_q8_0;
+
     // instance for Q8_0
     static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0;
 
+    // instances for RISC-V
+    //
+    // These implement outer-product style matrix multiplication kernels with
+    // an interleave of 1.
+#if defined __riscv_zvfh
+    static const ggml::cpu::repack::tensor_traits q4_0_16x1_q8_0;
+    static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K;
+    static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0;
+    static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0;
+    static const ggml::cpu::repack::tensor_traits q2_K_16x1_q8_K;
+#endif
+
     if (cur->type == GGML_TYPE_Q4_0) {
-        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
-            || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {
+        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
             if (cur->ne[1] % 8 == 0) {
                 return &q4_0_8x8_q8_0;
             }
@@ -2460,6 +4583,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q4_0_4x4_q8_0;
             }
         }
+        if (ggml_cpu_has_riscv_v()) {
+            #if defined __riscv_zvfh
+            switch (__riscv_vlenb() * 8) {
+                case 128:  { break; } // TODO
+                case 256:  { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; }
+                case 512:  { break; } // TODO
+                case 1024: { break; } // TODO
+                default:   { return nullptr; }
+            }
+            #endif
+        }
     } else if (cur->type == GGML_TYPE_Q4_K) {
         if (ggml_cpu_has_avx2()) {
             if (cur->ne[1] % 8 == 0) {
@@ -2476,12 +4610,56 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q4_K_8x4_q8_K;
             }
         }
+        if (ggml_cpu_has_riscv_v()) {
+            #if defined __riscv_zvfh
+            switch (__riscv_vlenb() * 8) {
+                case 128:  { break; } // TODO
+                case 256:  { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; }
+                case 512:  { break; } // TODO
+                case 1024: { break; } // TODO
+                default:   { return nullptr; }
+            }
+            #endif
+        }
     } else if (cur->type == GGML_TYPE_Q2_K) {
         if (ggml_cpu_has_avx512()) {
             if (cur->ne[1] % 8 == 0) {
                 return &q2_K_8x8_q8_K;
             }
         }
+        if (ggml_cpu_has_riscv_v()) {
+            #if defined __riscv_zvfh
+            switch (__riscv_vlenb() * 8) {
+                case 128:  { break; } // TODO
+                case 256:  { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; }
+                case 512:  { break; } // TODO
+                case 1024: { break; } // TODO
+                default:   { return nullptr; }
+            }
+            #endif
+        }
+    } else if (cur->type == GGML_TYPE_Q5_K) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q5_K_8x8_q8_K;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q5_K_8x4_q8_K;
+            }
+        }
+    } else if (cur->type == GGML_TYPE_Q6_K) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q6_K_8x8_q8_K;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q6_K_8x4_q8_K;
+            }
+        }
     } else if (cur->type == GGML_TYPE_IQ4_NL) {
         if (ggml_cpu_has_avx2()) {
             if (cur->ne[1] % 8 == 0) {
@@ -2493,6 +4671,28 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &iq4_nl_4x4_q8_0;
             }
         }
+        if (ggml_cpu_has_riscv_v()) {
+            #if defined __riscv_zvfh
+            switch (__riscv_vlenb() * 8) {
+                case 128:  { break; } // TODO
+                case 256:  { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; }
+                case 512:  { break; } // TODO
+                case 1024: { break; } // TODO
+                default:   { return nullptr; }
+            }
+            #endif
+        }
+    } else if (cur->type == GGML_TYPE_MXFP4) {
+        if (ggml_cpu_has_avx2()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &mxfp4_8x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &mxfp4_4x4_q8_0;
+            }
+        }
     } else if (cur->type == GGML_TYPE_Q8_0) {
         if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
             if (cur->ne[1] % 4 == 0) {
@@ -2504,6 +4704,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q8_0_4x4_q8_0;
             }
         }
+        if (ggml_cpu_has_riscv_v()) {
+            #if defined __riscv_zvfh
+            switch (__riscv_vlenb() * 8) {
+                case 128:  { break; } // TODO
+                case 256:  { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; }
+                case 512:  { break; } // TODO
+                case 1024: { break; } // TODO
+                default:   { return nullptr; }
+            }
+            #endif
+        }
     }
 
     return nullptr;
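
For reference, the VLEN dispatch added above boils down to: probe the vector register width with __riscv_vlenb() (bytes, multiplied by 8 for bits) and select the new 16x1 repacked kernels only for VLEN=256, with the other widths left as TODO. A standalone sketch of that selection, assuming a 256-bit fallback when not building for RISC-V; pick_interleave_cols is an illustrative helper, not part of the patch:

#include <cstdio>
#if defined(__riscv_v_intrinsic)
#include <riscv_vector.h>
#endif

static int pick_interleave_cols(int vlen_bits) {
    switch (vlen_bits) {
        case 256:  return 16;  // the 16x1 repacked kernels are selected only for VLEN=256
        case 128:
        case 512:
        case 1024: return 0;   // still TODO in the patch: no repacking
        default:   return -1;  // unsupported vector length
    }
}

int main() {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    const int vlen_bits = __riscv_vlenb() * 8;  // vector register width in bits
#else
    const int vlen_bits = 256;                  // assumed value when not building for RISC-V
#endif
    std::printf("VLEN = %d bits -> %d-column interleave\n", vlen_bits, pick_interleave_cols(vlen_bits));
    return 0;
}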
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h
index af98e703..cb21edf6 100644
--- a/ggml/src/ggml-cpu/repack.h
+++ b/ggml/src/ggml-cpu/repack.h
@@ -28,13 +28,17 @@ template <int K, int N> struct block {
 // control size
 static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
 static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding");
 static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
 static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding");
 
 using block_q4_0x4 = block<4, 4>;
 using block_q4_0x8 = block<4, 8>;
+using block_q4_0x16 = block<4, 16>;
 using block_q8_0x4 = block<8, 4>;
 using block_q8_0x8 = block<8, 8>;
+using block_q8_0x16 = block<8, 16>;
 
 struct block_q4_Kx8 {
     ggml_half d[8];      // super-block scale for quantized scales
@@ -44,6 +48,14 @@ struct block_q4_Kx8 {
 };
 
 static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+struct block_q4_Kx16 {
+    ggml_half d[16];      // super-block scale for quantized scales
+    ggml_half dmin[16];   // super-block scale for quantized mins
+    uint8_t scales[192];  // scales and mins, quantized with 6 bits
+    uint8_t qs[2048];    // 4-bit quants
+};
+
+static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding");
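
As a quick sanity check of the assertion above, the byte counts balance as follows (a standalone sketch that assumes the usual ggml values K_SCALE_SIZE = 12 and QK_K = 256 rather than including the headers):

#include <cstdio>

int main() {
    const int sizeof_half = 2, K_SCALE_SIZE = 12, QK_K = 256;
    const int lhs = sizeof_half * 16   // d[16]
                  + sizeof_half * 16   // dmin[16]
                  + 192                // scales[192]
                  + 2048;              // qs[2048], 4-bit quants for 16 interleaved columns
    const int rhs = sizeof_half * 32 + K_SCALE_SIZE * 16 + QK_K * 8;
    std::printf("%d == %d\n", lhs, rhs);  // 2304 == 2304
    return 0;
}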
 struct block_q2_Kx8 {
     ggml_half d[8];      // super-block scale for quantized scales
     ggml_half dmin[8];   // super-block scale for quantized mins
@@ -52,6 +64,35 @@ struct block_q2_Kx8 {
 };
 
 static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
+struct block_q2_Kx16 {
+    ggml_half d[16];       // Super-block scale for quantized scales
+    ggml_half dmin[16];    // Super-block scale for quantized mins
+    uint8_t   scales[256]; // Sub-block scales (16 cols * 16 sub-blocks)
+    uint8_t   qs[1024];    // Data (16 cols * 64 bytes per block)
+};
+static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding");
+
+struct block_q5_Kx8 {
+    ggml_half d[8];              // super-block scale for quantized scales
+    ggml_half dmin[8];           // super-block scale for quantized mins
+    uint8_t   scales[96];        // scales and mins, quantized with 6 bits
+    uint8_t   qh[QK_K * 8 / 8];  // high bits of 5-bit quants
+    uint8_t   qs[QK_K * 8 / 2];  // low bits of 5-bit quants (in groups of 4)
+};
+
+static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
+              "wrong q5_K block size/padding");
+
+struct block_q6_Kx8 {
+    ggml_half d[8];
+    int8_t    scales[QK_K / 16 * 8];
+    uint8_t   ql[QK_K / 2 * 8];  // low bits of 6-bit quants (groups of 2)
+    uint8_t   qh[QK_K / 4 * 8];  // high bits of 6-bit quants (groups of 4)
+};
+
+static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8,
+              "wrong q6_K block size/padding");
+
 struct block_q8_Kx4 {
     float d[4];              // delta
     int8_t qs[QK_K * 4];     // quants
@@ -74,6 +115,24 @@ struct block_iq4_nlx8 {
 
 static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
 
+struct block_iq4_nlx16 {
+    ggml_half d[16];            // deltas for 16 iq4_nl blocks
+    uint8_t   qs[QK4_NL * 8];  // nibbles / quants for 16 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding");
+struct block_mxfp4x4 {
+    uint8_t e[4];
+    uint8_t qs[QK_MXFP4 * 2];
+};
+static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding");
+
+struct block_mxfp4x8 {
+    uint8_t e[8];
+    uint8_t qs[QK_MXFP4 * 4];
+};
+static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
+
 #if defined(__cplusplus)
 extern "C" {
 #endif
@@ -85,23 +144,49 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+#if defined __riscv_zvfh
+void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+#endif
 
 // Native implementations
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@@ -111,23 +196,49 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+#if defined __riscv_zvfh
+void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+#endif
 
 #if defined(__cplusplus)
 } // extern "C"
diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h
new file mode 100644
index 00000000..78d663e5
--- /dev/null
+++ b/ggml/src/ggml-cpu/simd-gemm.h
@@ -0,0 +1,136 @@
+#pragma once
+
+// Computes C[M x N] += A[M x K] * B[K x N]
+
+#include "simd-mappings.h"
+
+// TODO: add support for sizeless vector types
+#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
+
+// TODO: untested on avx512
+// These are in units of GGML_F32_EPR
+#if defined(__AVX512F__) || defined (__ARM_NEON__)
+    static constexpr int GEMM_RM = 4;
+    static constexpr int GEMM_RN = 4; // 16+4+1 = 21/32
+#elif defined(__AVX2__) || defined(__AVX__)
+    static constexpr int GEMM_RM = 6;
+    static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
+#else
+    static constexpr int GEMM_RM = 2;
+    static constexpr int GEMM_RN = 2;
+#endif
+
+template <int RM, int RN>
+static inline void simd_gemm_ukernel(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int K, int N)
+{
+    static constexpr int KN = GGML_F32_EPR;
+
+    GGML_F32_VEC acc[RM][RN];
+    for (int64_t i = 0; i < RM; i++) {
+        for (int r = 0; r < RN; r++) {
+            acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
+        }
+    }
+
+    for (int64_t kk = 0; kk < K; kk++) {
+        GGML_F32_VEC Bv[RN];
+        for (int r = 0; r < RN; r++) {
+            Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
+        }
+        for (int64_t i = 0; i < RM; i++) {
+            GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
+            for (int r = 0; r < RN; r++) {
+                acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
+            }
+        }
+    }
+
+    for (int64_t i = 0; i < RM; i++) {
+        for (int r = 0; r < RN; r++) {
+            GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
+        }
+    }
+}
+
+// C[M x N] += A[M x K] * B[K x N]
+static void simd_gemm(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int M, int K, int N)
+{
+    static constexpr int KN = GGML_F32_EPR;
+
+    int64_t ii = 0;
+    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
+        int64_t jj = 0;
+        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
+            simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj + KN <= N; jj += KN) {
+            simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj < N; jj++) {
+            for (int64_t i = 0; i < GEMM_RM; i++) {
+                float a = C[i * N + jj];
+                for (int64_t kk = 0; kk < K; kk++) {
+                    a += A[i * K + kk] * B[kk * N + jj];
+                }
+                C[i * N + jj] = a;
+            }
+        }
+
+        A += GEMM_RM * K;
+        C += GEMM_RM * N;
+    }
+
+    // Tail rows: one at a time
+    for (; ii < M; ii++) {
+        int64_t jj = 0;
+        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
+            simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj + KN <= N; jj += KN) {
+            simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj < N; jj++) {
+            float a = C[jj];
+            for (int64_t kk = 0; kk < K; kk++) {
+                a += A[kk] * B[kk * N + jj];
+            }
+            C[jj] = a;
+        }
+
+        A += K;
+        C += N;
+    }
+}
+
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic pop
+#endif
+
+#else // scalar path
+
+static void simd_gemm(
+    float       * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int M, int K, int N)
+{
+    for (int64_t i = 0; i < M; i++) {
+        for (int64_t j = 0; j < N; j++) {
+            float sum = C[i * N + j];
+            for (int64_t kk = 0; kk < K; kk++) {
+                sum += A[i * K + kk] * B[kk * N + j];
+            }
+            C[i * N + j] = sum;
+        }
+    }
+}
+
+#endif // GGML_SIMD
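
The new header implements a register-blocked C += A*B with GEMM_RM x GEMM_RN vector accumulator tiles plus scalar tails. A minimal standalone reference with the same contract; the simd_gemm call itself is left commented out since it requires the header and a GGML_SIMD build, and the scalar loop below is the behaviour the kernel must reproduce:

#include <cstdio>
#include <vector>

int main() {
    const int M = 3, K = 5, N = 7;
    std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f);

    // simd_gemm(C.data(), A.data(), B.data(), M, K, N);  // what simd-gemm.h provides

    // scalar reference: C[M x N] += A[M x K] * B[K x N]
    for (int i = 0; i < M; i++)
        for (int j = 0; j < N; j++)
            for (int k = 0; k < K; k++)
                C[i * N + j] += A[i * K + k] * B[k * N + j];

    std::printf("C[0][0] = %.1f\n", C[0]);  // 1 * 2 * K = 10
    return 0;
}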
diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h
index a7a82722..0deda930 100644
--- a/ggml/src/ggml-cpu/simd-mappings.h
+++ b/ggml/src/ggml-cpu/simd-mappings.h
@@ -116,6 +116,17 @@ extern "C" {
 // defined in ggml-cpu.c, initialized in ggml_cpu_init()
 extern float ggml_table_f32_f16[1 << 16];
 
+// precomputed f32 table for e8m0 half (1 KB)
+// defined in ggml-cpu.c, initialized in ggml_cpu_init()
+extern float ggml_table_f32_e8m0_half[1 << 8];
+
+// Use lookup table for E8M0 on x86 (faster than bit manipulation)
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
+#else
+#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
+#endif
+
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
 // This is also true for POWER9.
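
The ggml_table_f32_e8m0_half lookup above trades 1 KB of static data for a single indexed load per scale on x86. A hypothetical sketch of how such a 256-entry table could be filled at init time; the decode used here (biased exponent e mapped to 2^(e-127), then halved) is an assumption for illustration, not the ggml definition of GGML_E8M0_TO_FP32_HALF:

#include <cmath>
#include <cstdio>

static float e8m0_half_table[1 << 8];

static void init_e8m0_half_table() {
    for (int e = 0; e < (1 << 8); ++e) {
        // assumed decode: 2^(e - 127) / 2
        e8m0_half_table[e] = std::ldexp(1.0f, e - 127) * 0.5f;
    }
}

int main() {
    init_e8m0_half_table();
    std::printf("%g %g %g\n", e8m0_half_table[126], e8m0_half_table[127], e8m0_half_table[128]);  // 0.25 0.5 1
    return 0;
}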
@@ -468,13 +479,51 @@ do {                                                                  \
 
 // F16 AVX512
 
-// F16 AVX
+#if defined(__AVX512FP16__)
+
+#define GGML_F16_STEP 128
+#define GGML_F16_EPR  32
+
+#define GGML_F16x32              __m512h
+#define GGML_F16x32_ZERO         _mm512_setzero_ph()
+#define GGML_F16x32_SET1(x)      _mm512_set1_ph(__extension__(_Float16)(x))
+#define GGML_F16x32_LOAD(x)      _mm512_loadu_ph(x)
+#define GGML_F16x32_STORE(x, y)  _mm512_storeu_ph(x, y)
+#define GGML_F16x32_FMA(a, b, c) _mm512_fmadd_ph(b, c, a)
+#define GGML_F16x32_ADD          _mm512_add_ph
+#define GGML_F16x32_MUL          _mm512_mul_ph
+#define GGML_F16x32_REDUCE(res, x)                                     \
+do {                                                                   \
+    int offset = GGML_F16_ARR >> 1;                                    \
+    for (int i = 0; i < offset; ++i) {                                 \
+        x[i] = _mm512_add_ph(x[i], x[offset+i]);                       \
+    }                                                                  \
+    offset >>= 1;                                                      \
+    for (int i = 0; i < offset; ++i) {                                 \
+        x[i] = _mm512_add_ph(x[i], x[offset+i]);                       \
+    }                                                                  \
+    offset >>= 1;                                                      \
+    for (int i = 0; i < offset; ++i) {                                 \
+        x[i] = _mm512_add_ph(x[i], x[offset+i]);                       \
+    }                                                                  \
+    res = (ggml_float) _mm512_reduce_add_ph(x[0]);                     \
+} while (0)
+
+#define GGML_F16_VEC                GGML_F16x32
+#define GGML_F16_VEC_ZERO           GGML_F16x32_ZERO
+#define GGML_F16_VEC_SET1           GGML_F16x32_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x32_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x32_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F16x32_FMA
+#define GGML_F16_VEC_ADD            GGML_F16x32_ADD
+#define GGML_F16_VEC_MUL            GGML_F16x32_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F16x32_REDUCE
+
+#else // Fallback FP16 <-> FP32
 
 #define GGML_F16_STEP 64
 #define GGML_F16_EPR  16
 
-// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead
-
 #define GGML_F32Cx16             __m512
 #define GGML_F32Cx16_ZERO        _mm512_setzero_ps()
 #define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)
@@ -514,6 +563,8 @@ do {                                                              \
 #define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL
 
 #define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE
+
+#endif // __AVX512FP16__
 #elif defined(__AVX__)
 
 #define GGML_SIMD
@@ -654,6 +705,14 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
           vec_extract(x[0], 2) +               \
           vec_extract(x[0], 3);                \
 }
+#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3)        \
+{                                                       \
+    vector float v = vec_add(vec_add(s0, s1),           \
+                             vec_add(s2, s3));          \
+    v = vec_add(v, vec_sld(v, v, 8));                   \
+    v = vec_add(v, vec_sld(v, v, 4));                   \
+    res += (ggml_float) vec_extract(v, 0);              \
+}
 
 #define GGML_F32_VEC        GGML_F32x4
 #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
@@ -690,6 +749,29 @@ static inline unsigned char ggml_endian_byte(int i) {
                                    r[i - GGML_ENDIAN_BYTE(0)]), \
             0, p - GGML_F16_EPR)
 
+//BF16 POWER9
+#define GGML_BF16_STEP 16
+#define GGML_BF16_EPR  8
+
+#define GGML_BF16x8         vector unsigned short
+#define GGML_BF16x8_ZERO    vec_splats((unsigned short)0)
+#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
+
+#define GGML_BF16_VEC          GGML_BF16x8
+#define GGML_BF16_VEC_ZERO     GGML_BF16x8_ZERO
+#define GGML_BF16_VEC_LOAD     GGML_BF16x8_LOAD
+#if defined(__LITTLE_ENDIAN__)
+#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel(GGML_BF16_VEC_ZERO, (v)))
+#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh(GGML_BF16_VEC_ZERO, (v)))
+#else
+#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel((v), GGML_BF16_VEC_ZERO))
+#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh((v), GGML_BF16_VEC_ZERO))
+#endif
+#define GGML_BF16_FMA_LO(acc, x, y) \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
+#define GGML_BF16_FMA_HI(acc, x, y) \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
+
 #elif defined(__wasm_simd128__)
 
 #define GGML_SIMD
@@ -1118,6 +1200,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
     float32x4_t tmp = x[0] + vec_reve(x[0]);        \
     res = tmp[0] + tmp[1];                          \
 }
+#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
+{                                                \
+    float32x4_t v = vec_add(vec_add(s0, s1),     \
+                            vec_add(s2, s3));    \
+    v = vec_add(v, vec_sld(v, v, 8));            \
+    v = vec_add(v, vec_sld(v, v, 4));            \
+    res += (ggml_float)vec_extract(v, 0);        \
+}
 
 #define GGML_F32_VEC        GGML_F32x4
 #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
@@ -1167,6 +1257,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
 #define GGML_F16_VEC_MUL            GGML_F32x4_MUL
 #define GGML_F16_VEC_REDUCE         GGML_F32x4_REDUCE
 
+// BF16 s390x
+#define GGML_BF16_STEP 16
+#define GGML_BF16_EPR  8
+
+#define GGML_BF16x8         __vector unsigned short
+#define GGML_BF16x8_ZERO    vec_splats((unsigned short)0)
+#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
+
+#define GGML_BF16_VEC      GGML_BF16x8
+#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
+#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
+#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
+#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
+#define GGML_BF16_FMA_LO(acc, x, y) \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
+#define GGML_BF16_FMA_HI(acc, x, y) \
+    (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
+
 #elif defined(__riscv_v_intrinsic)
 
 // compatible with vlen >= 128
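
The new POWER9/s390x BF16 macros widen bf16 lanes to f32 by merging each 16-bit value with a zero half-word: a bfloat16 is exactly the top 16 bits of an IEEE-754 binary32, so placing it in the high half of a 32-bit word reconstructs the float. A scalar equivalent of what GGML_BF16_TO_F32_LO/HI do per lane (standalone sketch, not ggml API):

#include <cstdint>
#include <cstdio>
#include <cstring>

static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t) h << 16;  // zero-filled low half, as the merge does per lane
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    // 0x3F80 is the bf16 pattern for 1.0f, 0xC040 is -3.0f
    std::printf("%g %g\n", bf16_to_f32(0x3F80), bf16_to_f32(0xC040));
    return 0;
}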
diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp
index 1d9873ad..1d834443 100644
--- a/ggml/src/ggml-cpu/unary-ops.cpp
+++ b/ggml/src/ggml-cpu/unary-ops.cpp
@@ -111,7 +111,7 @@ template 
 static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index 427e6324..d0e40013 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -236,7 +236,24 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
     vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
     sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
 
+#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
+    const int np = (n & ~(GGML_BF16_STEP - 1));
+    if (np > 0) {
+        GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
+        for (; i < np; i += GGML_BF16_STEP) {
+            GGML_BF16_VEC vx0 = GGML_BF16_VEC_LOAD(x + i);
+            GGML_BF16_VEC vx1 = GGML_BF16_VEC_LOAD(x + i + 8);
+            GGML_BF16_VEC vy0 = GGML_BF16_VEC_LOAD(y + i);
+            GGML_BF16_VEC vy1 = GGML_BF16_VEC_LOAD(y + i + 8);
+            GGML_BF16_FMA_LO(sum[0], vx0, vy0);
+            GGML_BF16_FMA_HI(sum[1], vx0, vy0);
+            GGML_BF16_FMA_LO(sum[2], vx1, vy1);
+            GGML_BF16_FMA_HI(sum[3], vx1, vy1);
+        }
+        GGML_F32x4_REDUCE_4(sumf, sum[0], sum[1], sum[2], sum[3]);
+    }
 #endif
+
     for (; i < n; ++i) {
         sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
                              GGML_BF16_TO_FP32(y[i]));
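
The vec_dot_bf16 path above keeps four independent accumulators fed by the FMA_LO/HI pairs, folds them once with GGML_F32x4_REDUCE_4, and finishes the remainder with the scalar tail loop. A scalar sketch of the same structure; dot_unrolled4 is an illustrative name, not part of the ggml API:

#include <cstdio>

static double dot_unrolled4(const float * x, const float * y, int n) {
    double s0 = 0, s1 = 0, s2 = 0, s3 = 0;
    int i = 0;
    for (; i + 4 <= n; i += 4) {              // four independent partial sums hide FMA latency
        s0 += (double) x[i + 0] * y[i + 0];
        s1 += (double) x[i + 1] * y[i + 1];
        s2 += (double) x[i + 2] * y[i + 2];
        s3 += (double) x[i + 3] * y[i + 3];
    }
    double sum = (s0 + s1) + (s2 + s3);       // the GGML_F32x4_REDUCE_4 step
    for (; i < n; ++i) {
        sum += (double) x[i] * y[i];          // scalar tail, as in the patch
    }
    return sum;
}

int main() {
    const float x[6] = {1, 2, 3, 4, 5, 6}, y[6] = {6, 5, 4, 3, 2, 1};
    std::printf("%f\n", dot_unrolled4(x, y, 6));  // 56.000000
    return 0;
}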
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index d313c1ac..262f8820 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -64,7 +64,7 @@ if (CUDAToolkit_FOUND)
         FetchContent_Declare(
             CCCL
             GIT_REPOSITORY https://github.com/nvidia/cccl.git
-            GIT_TAG        v3.2.0-rc2
+            GIT_TAG        v3.2.0
             GIT_SHALLOW    TRUE
         )
 
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 57c8a99a..4896669c 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -2,6 +2,9 @@
 
 #ifdef GGML_CUDA_USE_CUB
 #    include <cub/cub.cuh>
+#    if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
+#        define STRIDED_ITERATOR_AVAILABLE
+#    endif
 using namespace cub;
 #endif  // GGML_CUDA_USE_CUB
 
@@ -14,12 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
     }
 }
 
+#ifndef STRIDED_ITERATOR_AVAILABLE
 static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
     if (idx <= nrows) {
         offsets[idx] = idx * ncols;
     }
 }
+#endif  // STRIDED_ITERATOR_AVAILABLE
 
 #ifdef GGML_CUDA_USE_CUB
 void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
@@ -31,19 +36,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                               cudaStream_t     stream) {
     ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
     ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
-    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
 
     int *   temp_indices = temp_indices_alloc.get();
     float * temp_keys    = temp_keys_alloc.get();
-    int *   d_offsets    = offsets_alloc.get();
 
     static const int block_size = 256;
     const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
     init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
 
-    const dim3 offset_grid((nrows + block_size - 1) / block_size);
-    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
-
+#ifdef STRIDED_ITERATOR_AVAILABLE
+    auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
+#else
+    ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
+    int *                     offset_iterator = offsets_alloc.get();
+    const dim3                offset_grid((nrows + block_size - 1) / block_size);
+    init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
+#endif
     CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
 
     size_t temp_storage_bytes = 0;
@@ -57,7 +65,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
             DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
                                            temp_indices, dst,                                  // values (indices)
                                            ncols * nrows, nrows,  // num items, num segments
-                                           d_offsets, d_offsets + 1, stream);
+                                           offset_iterator, offset_iterator + 1, stream);
         }
     } else {
         if (nrows == 1) {
@@ -66,7 +74,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                                  ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
-                                                     dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+                                                     dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
+                                                     stream);
         }
     }
 
@@ -80,7 +89,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                        ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
-                                           ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
+                                           ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
         }
     } else {
         if (nrows == 1) {
@@ -89,8 +98,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                                  ncols, 0, sizeof(float) * 8, stream);
         } else {
             DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
-                                                     temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
-                                                     stream);
+                                                     temp_indices, dst, ncols * nrows, nrows, offset_iterator,
+                                                     offset_iterator + 1, stream);
         }
     }
 }
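
With CCCL >= 3.1 the explicit d_offsets buffer and init_offsets kernel are replaced by a strided counting iterator: segment i of the segmented sort begins at i * ncols. A host-side sketch of the offsets that iterator stands in for (assuming it behaves like a counting iterator with stride ncols; this is not a CUB/CCCL sample):

#include <cstdio>

int main() {
    const int ncols = 4, nrows = 3;
    // the nrows + 1 segment offsets the old init_offsets kernel wrote to device memory
    for (int i = 0; i <= nrows; ++i) {
        std::printf("offset[%d] = %d\n", i, i * ncols);  // 0, 4, 8, 12
    }
    return 0;
}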
diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu
index 0e6d777b..7339fe0c 100644
--- a/ggml/src/ggml-cuda/binbcast.cu
+++ b/ggml/src/ggml-cuda/binbcast.cu
@@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t *         src0,
                                    const uint3            ne11,
                                    const uint3            ne12,
                                    const uint3            ne13,
-                                   /*int s0, */ const int s1,
+                                 /*const int              s0,*/
+                                   const int              s1,
                                    const int              s2,
                                    const int              s3,
-                                   /*int s00,*/ const int s01,
+                                   const int              s00,
+                                   const int              s01,
                                    const int              s02,
                                    const int              s03,
-                                   /*int s10,*/ const int s11,
+                                   const int              s10,
+                                   const int              s11,
                                    const int              s12,
                                    const int              s13,
                                    src1_ptrs... src1s) {
@@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t *         src0,
     for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
         const uint32_t i10 = fastmodulo(i0, ne10);
 
-        float result = src0_row ? (float) src0_row[i0] : 0.0f;
+        float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
         if constexpr (sizeof...(src1_ptrs) > 0) {
-            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
         } else {
-            result = bin_op(result, (float)src1[i_src1 + i10]);
+            result = bin_op(result, (float)src1[i_src1 + i10*s10]);
         }
 
         dst_row[i0] = (dst_t) result;
@@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t *         src0,
                                            const uint3            ne11,
                                            const uint3            ne12,
                                            const uint3            ne13,
-                                           /*int s0, */ const int s1,
+                                         /*const int              s0,*/
+                                           const int              s1,
                                            const int              s2,
                                            const int              s3,
-                                           /*int s00,*/ const int s01,
+                                           const int              s00,
+                                           const int              s01,
                                            const int              s02,
                                            const int              s03,
-                                           /*int s10,*/ const int s11,
+                                           const int              s10,
+                                           const int              s11,
                                            const int              s12,
                                            const int              s13,
                                            src1_ptrs... src1s) {
@@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t *         src0,
 
     const int i10 = fastmodulo(i0, ne10);
 
-    float result = src0_row ? (float) src0_row[i0] : 0.0f;
+    float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
     if constexpr (sizeof...(src1_ptrs) > 0) {
-        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
     } else {
-        result = bin_op(result, (float)src1[i_src1 + i10]);
+        result = bin_op(result, (float)src1[i_src1 + i10*s10]);
     }
 
     dst_row[i0] = (dst_t) result;
@@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         cnb[3] *= cne[3];
     };
 
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
         for (int i = 0; i < 4; i++) {
             if (nr[i] != 1) {
                 break;
@@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         size_t nb12 = cnb1[2];
         size_t nb13 = cnb1[3];
 
-        size_t s0 = nb0 / sizeof(dst_t);
+      //size_t s0 = nb0 / sizeof(dst_t);
         size_t s1 = nb1 / sizeof(dst_t);
         size_t s2 = nb2 / sizeof(dst_t);
         size_t s3 = nb3 / sizeof(dst_t);
@@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
         GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
 
-        GGML_ASSERT(s0 == 1);
-        GGML_ASSERT(s00 == 1);
-        GGML_ASSERT(s10 == 1);
-
         const int block_size = 128;
 
         int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
                 k_bin_bcast_unravel<<>>(
                     src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
                     ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+                  /*s0,*/ s1,  s2,  s3,
+                    s00, s01, s02, s03,
+                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
             } else {
                 k_bin_bcast_unravel
                     <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
                                                            ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
-                                                           /* s0, */ s1, s2, s3,
-                                                           /* s00,*/ s01, s02, s03,
-                                                           /* s10,*/ s11, s12, s13);
+                                                         /*s0,*/ s1,  s2,  s3,
+                                                           s00, s01, s02, s03,
+                                                           s10, s11, s12, s13);
             }
         } else {
             const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
             if constexpr (sizeof...(I) > 0) {
                 k_bin_bcast<<>>(
                     src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+                  /*s0,*/ s1, s2,  s3,
+                    s00, s01, s02, s03,
+                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
             } else {
                 k_bin_bcast<<>>(
                     src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12, s13);
+                  /*s0,*/ s1,  s2,  s3,
+                    s00, s01, s02, s03,
+                    s10, s11, s12, s13);
             }
         }
     }
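
The binbcast change threads the innermost strides s00/s10 into the kernels so that permuted sources are indexed as src0_row[i0 * s00] instead of assuming a unit stride. A host-side sketch of why that matters, using a transposed view as the example (illustrative only, not the CUDA kernel):

#include <cstdio>

int main() {
    // 2 x 3 row-major storage; a transposed (permuted) view of it has an
    // innermost element stride of 3 instead of 1.
    const float data[6] = {0, 1, 2, 3, 4, 5};
    const size_t nb00 = 3 * sizeof(float);     // byte stride of dim 0 in the permuted view
    const size_t s00  = nb00 / sizeof(float);  // element stride, as computed at kernel launch
    for (int i0 = 0; i0 < 2; ++i0) {
        std::printf("%g ", data[i0 * s00]);    // first "row" of the view: 0 3
    }
    std::printf("\n");
    return 0;
}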
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 9516d8ec..36d8a3aa 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -53,6 +53,7 @@
 // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
 // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
 #define GGML_CUDA_CC_BLACKWELL       1200
+#define GGML_CUDA_CC_DGX_SPARK       1210
 #define GGML_CUDA_CC_RUBIN           1300
 #define GGML_CUDA_CC_OFFSET_AMD      0x1000000
 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
@@ -262,6 +263,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define FLASH_ATTN_AVAILABLE
 #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
 
+#if defined(TURING_MMA_AVAILABLE)
+#define LDMATRIX_TRANS_AVAILABLE
+#endif // defined(TURING_MMA_AVAILABLE)
+
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
         (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
@@ -526,6 +531,86 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+enum class block_reduce_method {
+    MAX,
+    SUM,
+};
+
+template
+struct block_reduce_policy;
+
+template <typename T, typename... Ts>
+inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
+
+template <typename T>
+inline constexpr bool ggml_cuda_dependent_false_v = false;
+
+template  struct block_reduce_policy {
+    static __device__ T reduce(T val) {
+        if constexpr(is_any<T, float, float2, half2, int>) {
+            return warp_reduce_sum(val);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+        }
+    }
+
+    static __device__ T sentinel() {
+        if constexpr (std::is_same_v<T, float>) {
+            return 0.0f;
+        } else if constexpr (std::is_same_v<T, float2>) {
+            return make_float2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v<T, half2>) {
+            return make_half2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v<T, int>) {
+            return 0;
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+        }
+    }
+};
+
+template  struct block_reduce_policy {
+    static __device__ T reduce(T val) {
+        if constexpr (is_any<T, float, half2>) {
+            return warp_reduce_max(val);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+        }
+    }
+
+    static __device__ T sentinel() {
+        if constexpr (std::is_same_v<T, float>) {
+            return -INFINITY;
+        } else if constexpr (std::is_same_v<T, half2>) {
+            return make_half2(-INFINITY, -INFINITY);
+        } else {
+            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+        }
+    }
+};
+
+template <block_reduce_method method, int block_size_template, typename T>
+static __device__ T block_reduce(T val, T * shared_vals) {
+    val                           = block_reduce_policy<method, T>::reduce(val);
+    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+    if (block_size > WARP_SIZE) {
+        assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            shared_vals[warp_id] = val;
+        }
+        __syncthreads();
+        val = block_reduce_policy<method, T>::sentinel();
+        if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
+            val = shared_vals[lane_id];
+        }
+        return block_reduce_policy<method, T>::reduce(val);
+    }
+
+    return val;
+}
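// Hedged usage sketch (not part of the patch): how a kernel might call the new
// block_reduce helper. The template-parameter order follows the reconstruction
// above, and the launch shape (<= 1024 threads, one shared slot per warp) is an
// assumption for illustration only.
__global__ void block_rowsum_example(const float * x, float * out, const int n) {
    __shared__ float shared_vals[32];               // one partial per warp (1024/WARP_SIZE)
    float acc = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        acc += x[blockIdx.x*n + i];                 // each block reduces one row
    }
    // block_size_template == 0: block size taken from blockDim.x at run time;
    // every thread receives the block-wide sum.
    acc = block_reduce<block_reduce_method::SUM, 0>(acc, shared_vals);
    if (threadIdx.x == 0) {
        out[blockIdx.x] = acc;
    }
}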
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 
@@ -1037,14 +1122,18 @@ struct ggml_tensor_extra_gpu {
 #endif
 
 struct ggml_cuda_graph_node_properties {
-    void * node_address;
+    void * node_data;
     ggml_op node_op;
+    enum ggml_type node_type;
+    int32_t flags;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
-    void * src_address[GGML_MAX_SRC];
+    void * src_data[GGML_MAX_SRC];
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 
+static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
+
 struct ggml_cuda_graph {
 #ifdef USE_CUDA_GRAPH
     ~ggml_cuda_graph() {
@@ -1060,25 +1149,18 @@ struct ggml_cuda_graph {
     size_t num_nodes = 0;
     std::vector<cudaGraphNode_t> nodes;
     bool disable_due_to_gpu_arch = false;
-    bool disable_due_to_too_many_updates = false;
-    int number_consecutive_updates = 0;
+    bool warmup_complete = false;
     std::vector<ggml_cuda_graph_node_properties> props;
 
-    void record_update(bool use_graph, bool update_required) {
-        if (use_graph && update_required) {
-            number_consecutive_updates++;
-        } else {
-            number_consecutive_updates = 0;
-        }
-        if (number_consecutive_updates >= 4) {
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
-            disable_due_to_too_many_updates = true;
-        }
-    }
+    // these are extra tensors (inputs) that participate in the ggml graph but are not nodes
+    // their properties also have to match in order to be able to safely reuse a CUDA graph
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18583
+    // ref: https://github.com/ggml-org/llama.cpp/pull/19165
+    std::vector<ggml_cuda_graph_node_properties> extra;
 
     bool is_enabled() const {
         static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
-        return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
+        return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env);
     }
 #endif
 };
@@ -1242,10 +1324,44 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-    std::unique_ptr<ggml_cuda_graph> cuda_graph;
-
     int curr_stream_no = 0;
 
+#ifdef USE_CUDA_GRAPH
+    // Map from first_node_ptr to cuda_graph - allows multiple graphs per context
+    // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
+    std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
+
+    ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
+        auto it = cuda_graphs.find(first_node_ptr);
+        if (it == cuda_graphs.end()) {
+            cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
+            return cuda_graphs[first_node_ptr].get();
+        }
+        return it->second.get();
+    }
+
+    // Check if any CUDA graph is enabled for this context (used by kernels that need to know
+    // if graphs are in use without having access to the specific graph key)
+    bool any_cuda_graph_enabled() const {
+        for (const auto & [key, graph] : cuda_graphs) {
+            if (graph && graph->is_enabled()) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Check if any CUDA graph has an instance for this context
+    bool any_cuda_graph_has_instance() const {
+        for (const auto & [key, graph] : cuda_graphs) {
+            if (graph && graph->instance != nullptr) {
+                return true;
+            }
+        }
+        return false;
+    }
+#endif // USE_CUDA_GRAPH
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
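// Hedged illustration (the caller name and the choice of key are assumptions,
// not part of the patch): with the per-first-node map above, each captured
// segment of a CPU/GPU-split computation looks up its own ggml_cuda_graph
// instead of sharing a single per-context graph, so segments can be captured
// and replayed independently.
#ifdef USE_CUDA_GRAPH
static void example_graph_lookup(ggml_backend_cuda_context & ctx, const ggml_cgraph * cgraph) {
    const void * key = cgraph->nodes[0];            // first node of the segment identifies it
    ggml_cuda_graph * graph = ctx.cuda_graph(key);  // lazily created on first use
    if (graph->is_enabled()) {
        // capture, instantiate or launch the CUDA graph for this segment here
    }
}
#endif // USE_CUDA_GRAPH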
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index ba3d4eeb..b70492c7 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -7,7 +7,8 @@
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
-        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t ne00, const int64_t ne01,
+        const int64_t ne0203, const uint3 ne02,
         const int64_t s01, const int64_t s02, const int64_t s03) {
     const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
 
@@ -15,24 +16,28 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
         return;
     }
 
-    const int64_t i01 = blockIdx.y;
-    const int64_t i02 = blockIdx.z % ne02;
-    const int64_t i03 = blockIdx.z / ne02;
+    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
+        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
+            const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
+            const int64_t i02 = dm.y;
+            const int64_t i03 = dm.x;
 
-    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+            const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
 
-    const int64_t ib = ibx0 + i00/qk; // block index
-    const int64_t iqs = (i00%qk)/qr; // quant index
-    const int64_t iybs = i00 - i00%qk; // y block start index
-    const int64_t y_offset = qr == 1 ? 1 : qk/2;
+            const int64_t ib = ibx0 + i00/qk; // block index
+            const int64_t iqs = (i00%qk)/qr; // quant index
+            const int64_t iybs = i00 - i00%qk; // y block start index
+            const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
-    // dequantize
-    float2 v;
-    dequantize_kernel(vx, ib, iqs, v);
+            // dequantize
+            float2 v;
+            dequantize_kernel(vx, ib, iqs, v);
 
-    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
-    y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
-    y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+            const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
+            y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
+            y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+        }
+    }
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -485,9 +490,11 @@ template 
 static void dequantize_block_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
-    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+    const int64_t ne0203 = ne02*ne03;
+    const uint3 ne02_fdv = init_fastdiv_values(ne02);
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
-        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+        (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
 }
 
 template 
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
 
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
-        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
+        const int64_t ne0203, const uint3 ne02,
         const int64_t s01, const int64_t s02, const int64_t s03) {
     const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -620,24 +628,30 @@ static __global__ void convert_unary(
         return;
     }
 
-    const int64_t i01 = blockIdx.y;
-    const int64_t i02 = blockIdx.z % ne02;
-    const int64_t i03 = blockIdx.z / ne02;
-
     const src_t * x = (const src_t *) vx;
 
-    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
-    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
-    y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
+    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
+        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
+            const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
+            const int64_t i02 = dm.y;
+            const int64_t i03 = dm.x;
+
+            const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
+            const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
+            y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
+        }
+    }
 }
 
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
-    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
+    const int64_t ne0203 = ne02*ne03;
+    const uint3 ne02_fdv = init_fastdiv_values(ne02);
+    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
-        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+        (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
 }
 
 template 
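// Hedged sketch of the indexing scheme used by the reworked dequantize_block and
// convert_unary kernels above, with plain integer div/mod standing in for the
// repo's init_fastdiv_values/fast_div_modulo helpers: the last two tensor dims
// are flattened into one grid dimension and walked with grid-stride loops, so
// gridDim.y/gridDim.z can be clamped to the hardware limit of 65535.
__global__ void copy_4d_example(const float * x, float * y,
                                const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne0203,
                                const int64_t s01, const int64_t s02, const int64_t s03) {
    const int64_t i00 = (int64_t) blockDim.x*blockIdx.x + threadIdx.x;
    if (i00 >= ne00) {
        return;
    }
    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
            const int64_t i02 = i0203 % ne02; // fast_div_modulo(...).y in the patch
            const int64_t i03 = i0203 / ne02; // fast_div_modulo(...).x in the patch
            y[(i0203*ne01 + i01)*ne00 + i00] = x[i03*s03 + i02*s02 + i01*s01 + i00];
        }
    }
}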
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index ee84303e..d208acf2 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -56,7 +56,8 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
     const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x;  // transpose block offset
     const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
 
-    __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
+    __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
+    int cur_tile_buf = 0;
 
 #pragma unroll
     for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
@@ -70,7 +71,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
             if(x < ne01 && y + j < ne00){
                 const int row = threadIdx.y+j;
                 const int col = threadIdx.x * sizeof(float)/sizeof(T);
-                T *tile2 = reinterpret_cast<T *>(tile[row]);
+                T *tile2 = reinterpret_cast<T *>(tile[cur_tile_buf][row]);
                 tile2[col] = src[imat*n + (y+j)*ne01 + x];
             }
         }
@@ -81,10 +82,12 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
         for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
             if (ty + j < ne01 && tx < ne00) {
                 const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
-                const T *tile2 = reinterpret_cast<const T *>(tile[threadIdx.x]);
+                const T *tile2 = reinterpret_cast<const T *>(tile[cur_tile_buf][threadIdx.x]);
                 dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
             }
         }
+
+        cur_tile_buf = (cur_tile_buf + 1) % 2;
     }
 
     GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
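// Hedged, stand-alone sketch of the double-buffering introduced in
// cpy_scalar_transpose above (tile size and launch shape are assumptions):
// alternating between two shared-memory tiles lets the next iteration start
// filling one buffer while the current iteration's data may still be read from
// the other, so a single __syncthreads() per tile is enough and no extra
// barrier is needed before the next write.
__global__ void transpose_tiles_example(const float * src, float * dst, const int n_tiles) {
    __shared__ float tile[2][32][33]; // padded column avoids shared-memory bank conflicts
    int buf = 0;
    for (int t = 0; t < n_tiles; ++t) {
        const float * s = src + t*32*32;
        float       * d = dst + t*32*32;
        tile[buf][threadIdx.y][threadIdx.x] = s[threadIdx.y*32 + threadIdx.x];
        __syncthreads();
        d[threadIdx.y*32 + threadIdx.x] = tile[buf][threadIdx.x][threadIdx.y];
        buf ^= 1; // the next tile is staged in the other buffer
    }
}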
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index 31446787..e9abdf28 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -59,7 +59,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
 
 #pragma unroll
     for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
-        half2 tmp[cpy_ne];
+        __align__(16) half2 tmp[cpy_ne];
         ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
@@ -309,7 +309,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_
         ggml_cuda_memcpy_1(dst, (const half *) vx + i0);
     } else if constexpr (std::is_same_v<T, float>) {
         static_assert(ne % 2 == 0, "bad ne");
-        half2 tmp[ne/2];
+        __align__(16) half2 tmp[ne/2];
         ggml_cuda_memcpy_1(tmp, (const half *) vx + i0);
         float2 * dst_f2 = (float2 *) dst;
 #pragma unroll
@@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
-        const int nbatch_fa) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
+        const int ne11, const int ne12, const int nbatch_fa) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -641,11 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(
 
     const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
 
-    const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01 + (ncols1    - 1)) / ncols1;
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
+
+    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -654,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
 
-    if (jt*ncols1 + j >= ne01) {
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
         return;
     }
 
-    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -681,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -778,13 +785,11 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
-    GGML_ASSERT(V || is_mla);
+    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
 
     const ggml_tensor * mask  = dst->src[3];
     const ggml_tensor * sinks = dst->src[4];
@@ -794,9 +799,9 @@ void launch_fattn(
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
-    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
-    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(V->nb[0] == ggml_element_size(V));
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
 
@@ -817,10 +822,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = V ? (const char *) V->data : nullptr;
-    size_t nb21 = V ? V->nb[1] : nb11;
-    size_t nb22 = V ? V->nb[2] : nb12;
-    size_t nb23 = V ? V->nb[3] : nb13;
+    const char * V_data = (const char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         const size_t bs = ggml_blck_size(K->type);
@@ -849,36 +854,45 @@ void launch_fattn(
         K_data = (char *) K_f16.ptr;
     }
 
-    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        const size_t bs = ggml_blck_size(V->type);
-        const size_t ts = ggml_type_size(V->type);
-
-        V_f16.alloc(ggml_nelements(V));
-        if (ggml_is_contiguously_allocated(V)) {
-            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
-            V_data = (char *) V_f16.ptr;
-
-            nb21 = nb21*bs*sizeof(half)/ts;
-            nb22 = nb22*bs*sizeof(half)/ts;
-            nb23 = nb23*bs*sizeof(half)/ts;
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        if (V_is_K_view) {
+            V_data = K_data;
+            nb21   = nb11;
+            nb22   = nb12;
+            nb23   = nb13;
         } else {
-            GGML_ASSERT(V->nb[0] == ts);
-            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
-            const int64_t s01 = nb21 / ts;
-            const int64_t s02 = nb22 / ts;
-            const int64_t s03 = nb23 / ts;
-            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+            const size_t bs = ggml_blck_size(V->type);
+            const size_t ts = ggml_type_size(V->type);
 
-            nb21 = V->ne[0] * sizeof(half);
-            nb22 = V->ne[1] * nb21;
-            nb23 = V->ne[2] * nb22;
+            V_f16.alloc(ggml_nelements(V));
+            if (ggml_is_contiguously_allocated(V)) {
+                to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+                to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+                V_data = (char *) V_f16.ptr;
+
+                nb21 = nb21*bs*sizeof(half)/ts;
+                nb22 = nb22*bs*sizeof(half)/ts;
+                nb23 = nb23*bs*sizeof(half)/ts;
+            } else {
+                GGML_ASSERT(V->nb[0] == ts);
+                to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+                const int64_t s01 = nb21 / ts;
+                const int64_t s02 = nb22 / ts;
+                const int64_t s03 = nb23 / ts;
+                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+                nb21 = V->ne[0] * sizeof(half);
+                nb22 = V->ne[1] * nb21;
+                nb23 = V->ne[2] * nb22;
+            }
+            V_data = (char *) V_f16.ptr;
         }
-        V_data = (char *) V_f16.ptr;
     }
 
-    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
-    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
+    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int gqa_ratio    = Q->ne[2] / K->ne[2];
+    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
+    const int ntiles_dst   = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
 
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -905,37 +919,37 @@ void launch_fattn(
     GGML_ASSERT(max_blocks_per_sm > 0);
     int parallel_blocks = max_blocks_per_sm;
 
+    const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length.
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
         const int max_blocks = max_blocks_per_sm*nsm;
-        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
-        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
+        const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = max_blocks;
+        const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
 
-        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
 
-        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
         }
     } else {
-        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
-
         // parallel_blocks must not be larger than what the tensor size allows:
-        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+        parallel_blocks = std::min(parallel_blocks, ntiles_KV);
 
         // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
         // Test whether parallel_blocks can be set to a higher value for better efficiency.
         const int blocks_per_wave = nsm * max_blocks_per_sm;
         int nwaves_best = 0;
         int efficiency_percent_best = 0;
-        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
-            const int nblocks_total = ntiles_total * parallel_blocks_test;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_dst * parallel_blocks_test;
             const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
             const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
 
@@ -953,7 +967,7 @@ void launch_fattn(
 
         blocks_num.x = ntiles_x;
         blocks_num.y = parallel_blocks;
-        blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
+        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -1001,13 +1015,13 @@ void launch_fattn(
     CUDA_CHECK(cudaGetLastError());
 
     if (stream_k) {
-        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 856291dc..fff70c8e 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -98,6 +98,57 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }
 
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2,  64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
+
+    // TODO tune specifically for RDNA
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
+    // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64,  8, 128, 2, 128,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 16, 128, 2,  64,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 32, 128, 2,  64,  32,  32,  32, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  32,  32,  32, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80,  8, 128, 2, 128,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 16, 128, 2,  64,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 32, 128, 2,  64,  40,  40,  40, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80,  80, 64, 256, 2,  64,  40,  40,  40, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96,  8, 128, 2, 128,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 16, 128, 2,  64,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 32, 128, 2,  64,  48,  48,  48, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96,  96, 64, 256, 2,  64,  48,  48,  48, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112,  8, 128, 2, 128,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2,  64,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2,  64,  56,  56,  56, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2,  64,  56,  56,  56, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128,  8, 128, 2, 128,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2,  64,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2,  64,  64,  64,  64, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2,  64,  64,  64,  64, 1, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256,  8,  64, 4,  64, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16,  64, 4,  32, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 1, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2,  32, 128, 128, 128, 1, true);
+
+    // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
+    // compile-time static_asserts even though the kernel guard prevents runtime execution.
+    // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
+    return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
+}
+
 static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
     if (ampere_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -105,6 +156,12 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
     if (turing_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
     }
+    if (amd_mfma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
+    }
+    if (amd_wmma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
+    }
     GGML_ASSERT(volta_mma_available(cc));
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 }
@@ -114,8 +171,12 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 #elif defined(TURING_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(AMD_MFMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
 #elif defined(VOLTA_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#elif defined(AMD_WMMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
 #else
     GGML_UNUSED_VARS(DKQ, DV, ncols);
     return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
@@ -186,6 +247,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
     return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
 }
 
+static constexpr __device__ int get_cols_per_thread() {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    return 1; // AMD has a single column per thread.
+#else
+    return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+}
+
+static __host__ int get_cols_per_warp(const int cc) {
+    if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
+        return 16;
+    } else {
+        // Volta
+        return 32;
+    }
+}
+
 // ------------------------------------------------------------------------------------------------------------------
 
 static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
@@ -206,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
 template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
     // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
     if constexpr (use_cp_async) {
@@ -217,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
         auto load = [&] __device__ (auto n) {
-            const int stride_k = WARP_SIZE >> n;
-            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int stride_k = warp_size >> n;
+            const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
             const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+            const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
                 return;
@@ -228,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -236,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
                 }
@@ -252,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
     } else {
         // TODO use ggml_cuda_memcpy_1
         auto load = [&] __device__ (const int n) {
-            const int stride_k = WARP_SIZE >> n;
-            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+            const int stride_k = warp_size >> n;
+            const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
             const int k0_stop  =                             D2 - D2 % (1*stride_k);
-            const int stride_i = WARP_SIZE / stride_k;
+            const int stride_i = warp_size / stride_k;
 
             if (k0_start == k0_stop) {
                 return;
@@ -263,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
             for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
-                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                 if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
                     break;
@@ -271,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
                 }
@@ -289,18 +368,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     if constexpr (use_cp_async) {
         constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
-        constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
+        constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 
         const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
 
 #pragma unroll
         for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
-            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
             const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
             if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
@@ -322,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
 
                 tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
             }
         }
-    } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
-        constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+    } else if constexpr (nbatch_fa < 2*warp_size) {
+        constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
         constexpr int stride_j = nwarps * cols_per_warp;
 #pragma unroll
         for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
-            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+            const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
             const int j_vram = fastmodulo(j0 + j_sram, ne01);
 
             if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
                 break;
             }
 
-            const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
+            const int i = threadIdx.x % (warp_size/cols_per_warp);
 
             ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
         }
@@ -355,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
+            for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
                 const int i = i0 + 2*threadIdx.x;
 
                 ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
@@ -365,7 +445,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
 }
 
 template
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
@@ -393,11 +473,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int jt,
         const int kb0,
         const int k_VKQ_sup) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
+    constexpr int  warp_size       = ggml_cuda_get_physical_warp_size();
     constexpr int  ncols           = ncols1 * ncols2;
     constexpr int  cols_per_warp   = T_B_KQ::I;
-    constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
-    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int  cols_per_thread = get_cols_per_thread();
+    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
     constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -407,19 +488,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     constexpr int stride_tile_Q = DKQ/2     + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;
 
-    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
-    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
 
     const int k_VKQ_0 = kb0 * nbatch_fa;
 #if defined(TURING_MMA_AVAILABLE)
     T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #else // Volta
     T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #endif // defined(TURING_MMA_AVAILABLE)
 
     if constexpr (nstages > 1) {
         static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
-        static_assert(!mla, "multi-stage loading not implemented for MLA");
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
         static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
@@ -434,8 +516,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
     }
 
+    // For MLA K and V have the same data.
+    // Therefore, iterate over K in reverse and later re-use the data if possible.
 #pragma unroll
-    for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+    for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
         const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
         const int k0_diff = k0_stop - k0_start;
 
@@ -461,13 +545,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     if constexpr (cols_per_warp == 8) {
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
                     } else {
-                        // Wide version of KQ_C is column-major => swap A and B.
+                        // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                        // AMD matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
+#else
+                        // swap A and B for CUDA.
                         mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                     }
                 }
             }
         } else {
-            static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
 #pragma unroll
             for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
                 load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +568,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
-                    // Wide version of KQ_C is column-major => swap A and B.
-                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+                    if constexpr (cols_per_warp == 8) {
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+                    } else {
+                        // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                        // AMD matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+                        // swap A and B for CUDA.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    }
                 }
             }
         }
@@ -532,7 +631,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -542,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
             for (int offset = 16; offset >= 4; offset >>= 1) {
-                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
             }
         }
 
@@ -552,8 +657,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
-                    KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
                 } else {
                     KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
                 }
@@ -584,8 +695,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
                     // Turing + Volta:
-                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -596,14 +712,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             // Values per KQ column are spread across 4 threads:
             constexpr int offset_first = 2;
             constexpr int offset_last  = 1;
-#else
+#elif defined(AMD_MFMA_AVAILABLE)
+            // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
+            constexpr int offset_first = 32;
+            constexpr int offset_last  = 16;
+#elif defined(AMD_WMMA_AVAILABLE)
+            // Values per KQ column are spread across 2 threads:
+            constexpr int offset_first = 16;
+            constexpr int offset_last  = 16;
+#else // Volta
             // Values per KQ column are spread across 2 threads:
             constexpr int offset_first = 2;
             constexpr int offset_last  = 2;
 #endif // defined(TURING_MMA_AVAILABLE)
 #pragma unroll
             for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
             }
         }
 
@@ -612,10 +736,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
-                // Turing + Volta:
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
-                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
-                    KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
                 } else {
                     KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
                 }
@@ -639,7 +768,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
 #if defined(TURING_MMA_AVAILABLE)
         if constexpr (cols_per_warp == 8) {
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
             for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
@@ -660,6 +789,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 }
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        const half2 KQ_max_scale_h2 = make_half2(
+            KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
 #else // Volta
         const half2 KQ_max_scale_h2 = make_half2(
             KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
@@ -688,6 +827,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
     if constexpr (nstages > 1) {
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
         // Preload K tile for next iteration:
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
@@ -703,19 +843,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
 
-    // For MLA K and V have the same data.
-    // Therefore, iterate over V in reverse and re-use the data if possible.
-    static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
+    T_A_VKQ A_identity;
+    make_identity_mat(A_identity);
+#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
 
     // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
-    for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
-        const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
-        const int i0_diff  = i0_stop - i0_start;
+    for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
+        static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
+        const int i0_stop = i0_start + 2*nbatch_V2;
+        const int i0_diff = i0_stop - i0_start;
 
         if constexpr (nstages <= 1) {
-            if (i0_start < reusable_cutoff) {
+            if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
                 constexpr bool use_cp_async = nstages == 1;
                 flash_attn_ext_f16_load_tile
                     (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
@@ -725,9 +866,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 __syncthreads();
             }
         }
-        const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
+        const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
 
-#if defined(TURING_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
         constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
 #pragma unroll
         for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
@@ -737,12 +878,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
 
                 T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
+#if defined(LDMATRIX_TRANS_AVAILABLE)
                 load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+#elif defined(AMD_MFMA_AVAILABLE)
+                // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
+                // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
+                // Load with transposed addressing: 4 strided half loads.
+                {
+                    const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
+                    const half * xs0_h = (const half *) xs0;
+                    const int stride_h = stride_tile_V * 2; // stride in half units
+                    half * A_h = (half *) A.x;
+#pragma unroll
+                    for (int l = 0; l < 4; ++l) {
+                        A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
+                    }
+                }
+#else
+                // TODO: Try to transpose tile_V when loading gmem to smem.
+                // Use mma to transpose T_A_VKQ for RDNA.
+                T_A_VKQ A_trans;
+                load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                mma(A, A_trans, A_identity);
+#endif // defined(LDMATRIX_TRANS_AVAILABLE)
                 if constexpr (T_B_KQ::I == 8) {
                     mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
                 } else {
-                    // Wide version of VKQ_C is column-major => swap A and B.
+                    // Wide version of VKQ_C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+                    // AMD matrix C is column-major.
+                    mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+#else
+                    // swap A and B for CUDA.
                     mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
                 }
             }
         }
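// Hedged illustration of the MFMA operand mapping described in the comment in
// the hunk above: for a 16x16 fp16 A tile on a 64-lane wavefront, lane l holds
// the four elements A[l % 16][4*(l/16) + r] for r = 0..3, which is why the
// transposed load above reads four halves strided by the tile's row pitch.
#include <cstdio>
int main() {
    const int warp_size = 64;
    for (int lane = 0; lane < warp_size; lane += 17) { // print a few sample lanes
        const int i  = lane % 16;       // row owned by this lane
        const int k0 = 4*(lane / 16);   // first of the four columns it owns
        printf("lane %2d -> A[%2d][%2d..%2d]\n", lane, i, k0, k0 + 3);
    }
    return 0;
}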
@@ -761,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
             }
         }
-#endif // defined(TURING_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 
         if constexpr (nstages <= 1) {
             __syncthreads(); // Only needed if tile_K == tile_V.
@@ -774,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         tile_Q, tile_K, tile_V, tile_mask,
         Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
 }
 
 #if defined(TURING_MMA_AVAILABLE)
@@ -794,6 +963,15 @@ template<> struct mma_tile_sizes<8> {
     using T_B_VKQ = tile< 8,  8, half2>; // column-major
     using T_C_VKQ = tile<16,  4, half2>; // row-major
 };
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+template struct mma_tile_sizes {
+    using T_A_KQ  = tile<16,  8, half2>; // row-major
+    using T_B_KQ  = tile<16,  8, half2>; // column-major
+    using T_C_KQ  = tile<16, 16, float>; // column-major
+    using T_A_VKQ = tile<16,  8, half2>; // row-major
+    using T_B_VKQ = tile<16,  8, half2>; // column-major
+    using T_C_VKQ = tile<16,  8, half2>; // column-major
+};
 #else // Volta
 template struct mma_tile_sizes {
     using T_A_KQ  = tile< 8,  4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
@@ -805,7 +983,7 @@ template struct mma_tile_sizes {
 };
 #endif // defined(TURING_MMA_AVAILABLE)
 
-template
+template
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -819,6 +997,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float logit_softcap,
         const uint3 ne01,
         const int ne02,
+        const int gqa_ratio,
         const int ne11,
         const int stride_Q1,
         const int stride_Q2,
@@ -826,11 +1005,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int stride_V,
         const int stride_mask,
         const int jt,
+        const int zt_gqa,
         const int kb0_start,
         const int kb0_stop) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int ncols = ncols1 * ncols2;
     using     T_A_KQ    = typename mma_tile_sizes::T_A_KQ;
     using     T_B_KQ    = typename mma_tile_sizes::T_B_KQ;
@@ -840,8 +1021,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     using     T_C_VKQ   = typename mma_tile_sizes::T_C_VKQ;
 
     constexpr int  cols_per_warp   = T_B_KQ::I;
-    constexpr int  cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
-    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int  cols_per_thread = get_cols_per_thread();
+    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
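+    // e.g. nwarps == 4, cols_per_warp == 16: ncols == 64 -> np == 1 (each warp owns its own columns),
+    //      ncols == 16 -> np == 4 (all warps cooperate on the same columns), ncols == 8 -> np == nwarps.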
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);
     constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols);
@@ -859,8 +1040,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     constexpr int stride_tile_Q = DKQ/2     + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;
 
-    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
-    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
     constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
 
     extern __shared__ half2 tile_Q[];
@@ -871,6 +1051,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     T_B_KQ    Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
 #if defined(TURING_MMA_AVAILABLE)
     T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
 #else // Volta
     T_C_VKQ VKQ_C[                                     DV/(2*T_C_VKQ::J)];
 #endif // defined(TURING_MMA_AVAILABLE)
@@ -887,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     // The loading is done with decreasing granularity for D for better memory bandwidth.
     const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
-    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-        const int k0_start  = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
+    for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+        const int k0_start  = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
         const int k0_stop   =                             DKQ/2 - (DKQ/2) % (1*stride_k);
-        const int stride_jc = WARP_SIZE / stride_k;
+        const int stride_jc = warp_size / stride_k;
 
         if (k0_start == k0_stop) {
             continue;
@@ -898,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #pragma unroll
         for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
-            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
             if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
                 break;
@@ -907,10 +1089,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;
 
-            if (jt*ncols1 + j < int(ne01.z)) {
+            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
                     tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
@@ -918,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             } else {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                    const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                     tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
                 }
@@ -962,7 +1144,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             constexpr bool last_iter = false;
             constexpr int  k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
-                
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -971,7 +1153,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         const     int  k_VKQ_sup = ne11 - kb0*nbatch_fa;
         flash_attn_ext_f16_iter
-            
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -982,7 +1164,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             constexpr bool last_iter = false;
             constexpr int  k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
-                
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -991,7 +1173,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         constexpr int  k_VKQ_sup = nbatch_fa;
         flash_attn_ext_f16_iter
-            
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1010,6 +1192,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // The partial sums are spread across 8/4 threads.
         constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
         constexpr int offset_last  = cols_per_warp == 8 ?  4 : 1;
+#elif defined(AMD_MFMA_AVAILABLE)
+        // The partial sums are spread across 4 threads (64-lane wavefront, 16 columns).
+        constexpr int offset_first = 32;
+        constexpr int offset_last  = 16;
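+        // i.e. two butterfly shuffles (offsets 32 and 16) sum the 4 partials held by lanes spaced 16 apart.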
+#elif defined(AMD_WMMA_AVAILABLE)
+        // The partial sums are spread across 2 threads.
+        constexpr int offset_first = 16;
+        constexpr int offset_last  = 16;
 #else // Volta
         // The partial sums are spread across 2 threads.
         constexpr int offset_first = 2;
@@ -1019,13 +1209,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         for (int col = 0; col < cols_per_thread; ++col) {
 #pragma unroll
             for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
-                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
             }
         }
     }
 
     // If attention sinks are used, potentially re-scale if KQ_max is small.
-    // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
+    // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
     //     so it's being done unconditionally for every thread.
     if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
         float KQ_max_scale[cols_per_thread];
@@ -1047,7 +1237,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #if defined(TURING_MMA_AVAILABLE)
         if constexpr (cols_per_warp == 8) {
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
             for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
@@ -1068,6 +1258,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 }
             }
         }
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+        for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+            for (int l = 0; l < T_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
 #else // Volta
         const int col = (threadIdx.x / 2) % 2;
         const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
@@ -1119,6 +1318,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
         const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
         const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
+        const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
+        const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
 #else // Volta
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
         const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
@@ -1149,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Warps with threadIdx.y % np != 0 must NOT return early.
         // All threads must return simultaneously to avoid race conditions with work on the next tile.
 
-        constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+        constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
 
-        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
         float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
         float2 meta[nmeta];
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
+            meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
         }
 
         float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
@@ -1166,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset < WARP_SIZE) {
-                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+            if (offset < warp_size) {
+                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
             }
         }
 
@@ -1184,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset < WARP_SIZE) {
-                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+            if (offset < warp_size) {
+                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
             }
         }
 
@@ -1194,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // Write back combined meta data:
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+            if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
                 // Combined KQ max scale + rowsum.
-                meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
+                meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
             }
         }
 
         // Combined KQ max + rowsum.
-        static_assert(cols_per_warp <= WARP_SIZE);
-        if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+        static_assert(cols_per_warp <= warp_size);
+        if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
-        if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+        if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
             float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
             dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
         }
@@ -1254,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
 
 #pragma unroll
-            for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-                const int k0_start  = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
+            for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+                const int k0_start  = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
                 const int k0_stop   =                             nbatch_combine - nbatch_combine % (1*stride_k);
-                const int stride_jc = WARP_SIZE / stride_k;
+                const int stride_jc = warp_size / stride_k;
 
                 if (k0_start == k0_stop) {
                     continue;
@@ -1265,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #pragma unroll
                 for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
-                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
 
                     if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
                         break;
@@ -1276,14 +1479,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     const int j_dst = jc_dst / ncols2;
                     const int c_dst = jc_dst % ncols2;
 
-                    if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
+                    if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
                         continue;
                     }
 
                     const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
 #pragma unroll
                     for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                        const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+                        const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
 
                         float2 dstk_val = make_float2(0.0f, 0.0f);
 #pragma unroll
@@ -1315,14 +1518,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 #else
     GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
-        scale, slope, logit_softcap, ne01, ne02,
+        scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
 }
 
-template
+template
 __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
@@ -1346,13 +1549,20 @@ static __global__ void flash_attn_ext_f16(
                             const int32_t nb21, const int32_t nb22, const int64_t nb23,
                             const int32_t ne31, const int32_t ne32, const int32_t ne33,
                             const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
         NO_DEVICE_CODE;
         return;
     }
+#ifdef VOLTA_MMA_AVAILABLE
+    if (ncols1*ncols2 < 32) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // VOLTA_MMA_AVAILABLE
+
 #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
     if (ncols1*ncols2 > 32) {
         NO_DEVICE_CODE;
@@ -1360,12 +1570,25 @@ static __global__ void flash_attn_ext_f16(
     }
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
 
-    static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
+#if defined(AMD_WMMA_AVAILABLE)
+    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_WMMA_AVAILABLE)
 
+#if defined(AMD_MFMA_AVAILABLE)
+    if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_MFMA_AVAILABLE)
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr int ncols     = ncols1 * ncols2;
     constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int nthreads  = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
-    constexpr int nwarps    = nthreads / WARP_SIZE;
+    constexpr int nwarps    = nthreads / warp_size;
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
@@ -1374,14 +1597,15 @@ static __global__ void flash_attn_ext_f16(
     const int stride_K    = nb11 / sizeof(half2);
     const int stride_mask = nb31 / sizeof(half);
 
-    const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
+    const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
 
-    const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
-    const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01.z    + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
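+    // iter_z_gqa == number of ncols2-wide groups of Q heads per K/V head.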
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
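+    // The total work is iter_k*iter_j*iter_z_gqa*ne12*ne03 units:
+    //   (KV tiles) x (Q token tiles) x (Q head groups per K/V head) x (K/V heads) x (sequences),
+    // split as evenly as possible across the gridDim.x blocks.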
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1392,22 +1616,24 @@ static __global__ void flash_attn_ext_f16(
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+        // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
+        const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+        const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+        const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+        const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
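+        // The indices are peeled off from most to least significant: sequence, then K/V head, then Q head group,
+        // then token tile; the remaining offset within iter_k is carried by kb0_start/kb0_stop.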
 
-        const int head0 = zt * ncols2;
+        const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
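+        // e.g. gqa_ratio == 20, ncols2 == 4, z_KV == 1, zt_gqa == 3 -> zt_Q == 32, i.e. Q heads 32..35.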
 
-        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
-        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+        const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+        const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
         const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
             (const half *) (mask + nb33*(sequence % ne33));
-        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+        float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
-        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+        const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+        const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
 
         if (KV_max) {
             kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1415,14 +1641,14 @@ static __global__ void flash_attn_ext_f16(
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
-            flash_attn_ext_f16_process_tile
+            flash_attn_ext_f16_process_tile
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
-            flash_attn_ext_f16_process_tile
+            flash_attn_ext_f16_process_tile
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+                 ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
         }
 
         kbc += iter_k;
@@ -1436,22 +1662,24 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
+    const int sequence =  kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+    const int jt       = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
 
-    const int head0 = zt * ncols2;
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
 
-    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02* head0);
-    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+    const float2 * Q_f2   = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+    const half2  * K_h2   = (const half2  *) (K + nb13*sequence + nb12*z_KV);
     const half   * mask_h = ncols2 == 1 && !mask ? nullptr :
         (const half *) (mask + nb33*(sequence % ne33));
-    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+    float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
-    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+    const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+    const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
 
     if (KV_max) {
         kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1459,9 +1687,9 @@ static __global__ void flash_attn_ext_f16(
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
-    flash_attn_ext_f16_process_tile
+    flash_attn_ext_f16_process_tile
         (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+         ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
         max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1473,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
               ne31, ne32, ne33,
               nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
 }
 
 template 
@@ -1492,10 +1720,11 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     const bool Q_in_reg       = ggml_cuda_fattn_mma_get_Q_in_reg      (DKQ, DV, ncols, cc);
     const int  nstages        = ggml_cuda_fattn_mma_get_nstages       (DKQ, DV, ncols1, ncols2, cc);
 
-    const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
-    const int nwarps        = nthreads / WARP_SIZE;
+    const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
+    const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
+    const int nwarps         = nthreads / warp_size_host;
 
-    constexpr bool mla = DKQ == 576;
+    constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
 
     const size_t nbytes_shared_KV_1stage = nbatch_fa            * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);
     const size_t nbytes_shared_KV_2stage = nbatch_fa            *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
@@ -1512,33 +1741,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
+#if defined(GGML_USE_HIP)
+    using fattn_kernel_ptr_t = const void*;
+#else
+    using fattn_kernel_ptr_t = fattn_kernel_t;
+#endif // defined(GGML_USE_HIP)
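+    // On HIP the (hipified) cudaFuncSetAttribute below presumably expects a const void * entry point,
+    // hence the different cast target per backend.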
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16;
+        fattn_kernel = flash_attn_ext_f16;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16;
+        fattn_kernel = flash_attn_ext_f16;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
             shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
     }
 
     launch_fattn
-        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
 }
 
 
@@ -1585,3 +1819,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  64)
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  4,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  8,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  1, 32);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  2, 32);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 7c4d6fe6..f3fa80ab 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
 
     return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)
 
     return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)
 
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)
 
@@ -343,7 +351,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
                 for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                     const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
 
-                    const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+                    const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
                     ggml_cuda_memcpy_1(
                         tile_KV + i*(J/2 + J_padding) + j,
                         !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
@@ -394,11 +402,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
                     const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
 
                     const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
-                    half2 tmp_h2[cpy_ne/2];
+                    __align__(16) half2 tmp_h2[cpy_ne/2];
                     ggml_cuda_memcpy_1(
                         tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
 
-                    float2 tmp_f2[cpy_ne/2];
+                    __align__(16) float2 tmp_f2[cpy_ne/2];
 #pragma unroll
                     for (int l = 0; l < cpy_ne/2; ++l) {
                         tmp_f2[l] = __half22float2(tmp_h2[l]);
@@ -445,14 +453,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
     static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
     for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
-        half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-        half2 Q_k[cpw][cpy_ne];
+        __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __align__(16) half2 Q_k[cpw][cpy_ne];
 #else
     static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
     for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
-        float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-        float Q_k[cpw][cpy_ne];
+        __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __align__(16) float Q_k[cpw][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -602,9 +610,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
 #ifdef FAST_FP16_AVAILABLE
-        half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+        __align__(16) half  tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #else
-        float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+        __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -664,8 +672,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
         for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-            half2 V_k[(DVp/2)/warp_size];
-            half2 KQ_k[cpw];
+            __align__(16) half2 V_k[(DVp/2)/warp_size];
+            __align__(16) half2 KQ_k[cpw];
 
             constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
@@ -676,7 +684,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
             for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
                 const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
 
-                half tmp[KQ_cs];
+                __align__(16) half tmp[KQ_cs];
                 ggml_cuda_memcpy_1(
                     &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
 #pragma unroll
@@ -696,8 +704,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #else
 #pragma unroll
         for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-            float2 V_k[(DVp/2)/warp_size];
-            float  KQ_k[cpw];
+            __align__(16) float2 V_k[(DVp/2)/warp_size];
+            __align__(16) float  KQ_k[cpw];
 
             constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
@@ -821,12 +829,12 @@ static __global__ void flash_attn_tile(
     __shared__ half2 Q_tmp[ncols * DKQ/2];
     __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
     __shared__ half  KQ[ncols * nbatch_fa];
-    half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #else
     __shared__ float Q_tmp[ncols * DKQ];
     __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
     __shared__ float KQ[ncols * nbatch_fa];
-    float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #endif // FAST_FP16_AVAILABLE
 
     float KQ_max[cpw];
@@ -849,7 +857,7 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
             if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
-                float tmp_f[cpy_ne_D] = {0.0f};
+                __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
                 ggml_cuda_memcpy_1
                     (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
                                  + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
@@ -860,7 +868,7 @@ static __global__ void flash_attn_tile(
                 }
 
 #ifdef FAST_FP16_AVAILABLE
-                half2 tmp_h2[cpy_ne_D/2];
+                __align__(16) half2 tmp_h2[cpy_ne_D/2];
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                     tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
@@ -959,7 +967,7 @@ static __global__ void flash_attn_tile(
             constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
 #pragma unroll
             for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-                half2 tmp[cpy_ne_D];
+                __align__(16) half2 tmp[cpy_ne_D];
                 ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -970,7 +978,7 @@ static __global__ void flash_attn_tile(
             constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
             for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
-                float tmp[cpy_ne_D];
+                __align__(16) float tmp[cpy_ne_D];
                 ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -1033,7 +1041,7 @@ static __global__ void flash_attn_tile(
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
         for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-            float2 tmp[cpy_ne_D];
+            __align__(16) float2 tmp[cpy_ne_D];
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
@@ -1178,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
+    // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
+    // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
     const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
-    const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
+    const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
     const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
     if constexpr (DV == 512) {
@@ -1187,6 +1197,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
             launch_fattn_tile_switch_ncols1(ctx, dst);
             return;
         }
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            launch_fattn_tile_switch_ncols1(ctx, dst);
+            return;
+        }
     }
 
     if constexpr (DV <= 256) {
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index 4d167b95..7cbe3263 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -10,7 +10,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
     return 128;
 }
 
-// Currenlty llvm with the amdgcn target dose not support unrolling loops
+// Currently llvm with the amdgcn target does not support unrolling loops
 // that contain a break that can not be resolved at compile time.
 #ifdef __clang__
 #pragma clang diagnostic push
@@ -132,7 +132,7 @@ static __global__ void flash_attn_ext_vec(
 #ifdef V_DOT2_F32_F16_AVAILABLE
     half2  Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
 #else
-    float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+    __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
 #endif // V_DOT2_F32_F16_AVAILABLE
     int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2  Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
@@ -200,7 +200,7 @@ static __global__ void flash_attn_ext_vec(
             for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
                 const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
 
-                float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
+                __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
                 if (ncols == 1 || ic0 + j < int(ne01.z)) {
                     ggml_cuda_memcpy_1(tmp,            &Q_j[i]);
                     ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 8694fd06..f19defbf 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
+    typedef wmma::fragment frag_a_K;
+    typedef wmma::fragment frag_a_V;
+    typedef wmma::fragment frag_b;
+    typedef wmma::fragment                      frag_c_KQ;
+    typedef wmma::fragment                          frag_c_VKQ;
+#else
     typedef wmma::fragment frag_a_K;
     typedef wmma::fragment frag_a_V;
     typedef wmma::fragment frag_b;
     typedef wmma::fragment                      frag_c_KQ;
     typedef wmma::fragment                          frag_c_VKQ;
+#endif
 
     constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
 
     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
+
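+    // With HIP >= 6.5 the rocWMMA fragments on this path use _Float16 elements, so the half pointers
+    // are reinterpreted to the matching element type for load/store_matrix_sync.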
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
+    const _Float16 * K_h_f16  = reinterpret_cast<const _Float16 *>(K_h);
+    const _Float16 * V_h_f16  = reinterpret_cast<const _Float16 *>(V_h);
+    _Float16       * KQ_f16   = reinterpret_cast<_Float16 *>(KQ);
+    _Float16       * VKQ_f16  = reinterpret_cast<_Float16 *>(VKQ);
+#else
+    const half * K_h_f16  = K_h;
+    const half * V_h_f16  = V_h;
+    half       * KQ_f16   = KQ;
+    half       * VKQ_f16  = VKQ;
+#endif
+
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*(kqar*kqs_padded) + k,
+                    KQ_f16 + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
             }
         }
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                 wmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
                     D_padded, wmma::mem_col_major);
             }
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index cd3bfd40..aaf711a6 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -18,7 +18,7 @@
 #if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
 #define GGML_USE_WMMA_FATTN
 #elif defined(RDNA4)
-#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
+#warning "rocwmma fattn is not supported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
 #endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
 #endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 01554066..85c177f4 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -18,12 +18,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
         }
     }
 
-    if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
-        return;
+    if constexpr (ncols2 <= 16) {
+        if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
+            ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+            return;
+        }
     }
 
-    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
+    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
         ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
         return;
     }
@@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
 
 template 
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
@@ -46,7 +49,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     //     are put into the template specialization without GQA optimizations.
     bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
     for (const ggml_tensor * t : {Q, K, V, mask}) {
-        if (t == nullptr) {
+        if (t == nullptr || ggml_is_quantized(t->type)) {
             continue;
         }
         for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
@@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
-    if (use_gqa_opt && gqa_ratio % 8 == 0) {
+    // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
+    if (cc == GGML_CUDA_CC_VOLTA) {
+        if (use_gqa_opt && gqa_ratio % 8 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 2 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
+        return;
+    }
+
+    if (use_gqa_opt && gqa_ratio > 4) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio % 4 == 0) {
+    if (use_gqa_opt && gqa_ratio > 2) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio % 2 == 0) {
+    if (use_gqa_opt && gqa_ratio > 1) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst);
         return;
     }
@@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
 }
 
 static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
@@ -121,8 +146,50 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
 
             GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
             const int gqa_ratio = Q->ne[2] / K->ne[2];
-            GGML_ASSERT(gqa_ratio % 16 == 0);
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+            if (gqa_ratio == 20) { // GLM 4.7 Flash
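+                // 20 is not a multiple of 8 or 16, so ncols2 is picked per architecture
+                // (presumably from benchmarks) instead of falling through to the generic % 16 path below.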
+                if (cc >= GGML_CUDA_CC_DGX_SPARK) {
+                    if (Q->ne[1] <= 8) {
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                if (cc >= GGML_CUDA_CC_BLACKWELL) {
+                    if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    if (Q->ne[1] <= 4) {
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    if (Q->ne[1] <= 4) {
+                        if (K->ne[1] <= 16384) {
+                            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                            break;
+                        }
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                // Volta:
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+            } else if (gqa_ratio % 16 == 0) {
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512,  4>(ctx, dst);
+            }
         } break;
         default:
             GGML_ABORT("fatal error");
@@ -230,7 +297,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // The effective batch size for the kernel can be increased by gqa_ratio.
     // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
-    const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
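+    // the GQA-optimized kernels assume 16 byte alignment for the strides of all non-quantized inputs: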
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr || ggml_is_quantized(t->type)) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                gqa_opt_applies = false;
+                break;
+            }
+        }
+    }
 
     const int cc = ggml_cuda_info().devices[device].cc;
 
@@ -251,7 +329,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != 512) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
+            if (!gqa_opt_applies) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
@@ -337,6 +415,43 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_WMMA_F16;
     }
 
+    if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+        if (can_use_vector_kernel) {
+            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+                if (Q->ne[1] == 1) {
+                    if (!gqa_opt_applies) {
+                        return BEST_FATTN_KERNEL_VEC;
+                    }
+                }
+            } else {
+                if (Q->ne[1] <= 2) {
+                    return BEST_FATTN_KERNEL_VEC;
+                }
+            }
+        }
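+        // effective GQA ratio: the largest power of 2 that divides gqa_ratio, capped at the maximum ncols2 for this head size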
+        int gqa_ratio_eff = 1;
+        const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+        while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+            gqa_ratio_eff *= 2;
+        }
+        if (Q->ne[1] * gqa_ratio_eff <= 8) {
+            return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
+        }
+        return BEST_FATTN_KERNEL_MMA_F16;
+    }
+
+    // Use MFMA flash attention for CDNA (MI100+):
+    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
+        const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
+        // MMA vs tile crossover benchmarked on MI300X @ d32768:
+        //   hsk=64  (gqa=4): MMA wins at eff >= 128 (+11%)
+        //   hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%)
+        if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) {
+            return BEST_FATTN_KERNEL_MMA_F16;
+        }
+        // Fall through to tile kernel for small effective batch sizes.
+    }
+
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
         if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
new file mode 100644
index 00000000..1ce6d5f3
--- /dev/null
+++ b/ggml/src/ggml-cuda/gated_delta_net.cu
@@ -0,0 +1,263 @@
+#include "gated_delta_net.cuh"
+
+template <int S_v, bool KDA>
+__global__ void gated_delta_net_cuda(const float * q,
+                                     const float * k,
+                                     const float * v,
+                                     const float * g,
+                                     const float * beta,
+                                     const float * curr_state,
+                                     float *       dst,
+                                     int64_t       H,
+                                     int64_t       n_tokens,
+                                     int64_t       n_seqs,
+                                     int64_t       sq1,
+                                     int64_t       sq2,
+                                     int64_t       sq3,
+                                     int64_t       sv1,
+                                     int64_t       sv2,
+                                     int64_t       sv3,
+                                     int64_t       sb1,
+                                     int64_t       sb2,
+                                     int64_t       sb3,
+                                     const uint3   neqk1_magic,
+                                     const uint3   rq3_magic,
+                                     float         scale) {
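+    // Gated delta rule recurrence; grid = (head, sequence, column group), each warp owns one
+    // state column of S (S_v values distributed across its lanes) and iterates over the tokens:
+    //   non-KDA: kv = S^T k,             delta = (v - exp(g)*kv) * beta, S = exp(g)*S    + k * delta^T
+    //   KDA:     kv = (exp(g[i])*S)^T k, delta = (v - kv) * beta,        S = exp(g[i])*S + k * delta^T
+    //   out = (S^T q) * scale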
+    const uint32_t h_idx    = blockIdx.x;
+    const uint32_t sequence = blockIdx.y;
+    // each warp owns one column, using warp-level primitives to reduce across rows
+    const int      lane     = threadIdx.x;
+    const int      col      = blockIdx.z * blockDim.y + threadIdx.y;
+
+    const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
+    const uint32_t iq3 = fastdiv(sequence, rq3_magic);
+
+    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+    float *       attn_data        = dst;
+    float *       state            = dst + attn_score_elems;
+
+    const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
+    state += state_offset;
+    curr_state += state_offset;
+    attn_data += (sequence * n_tokens * H + h_idx) * S_v;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
+    static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
+    constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
+    float         s_shard[rows_per_lane];
+    // state is stored transposed: M[col][i] = S[i][col], row col is contiguous
+#pragma unroll
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i = r * warp_size + lane;
+        s_shard[r]  = curr_state[col * S_v + i];
+    }
+
+    for (int t = 0; t < n_tokens; t++) {
+        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
+
+        const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;
+        const float * beta_t = beta + gb_offset;
+        const float * g_t    = g    + gb_offset * (KDA ? S_v : 1);
+
+        const float beta_val = *beta_t;
+
+        if constexpr (!KDA) {
+            const float g_val = expf(*g_t);
+
+            // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
+            float kv_shard = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += s_shard[r] * k_t[i];
+            }
+            float kv_col = warp_reduce_sum(kv_shard);
+
+            // delta[col] = (v[col] - g * kv[col]) * beta
+            float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
+
+            // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
+            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+            float attn_partial = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = g_val * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
+            }
+
+            float attn_col = warp_reduce_sum(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
+        } else {
+            // kv[col] = sum_i g[i] * S[i][col] * k[i]
+            float kv_shard = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
+            }
+
+            float kv_col = warp_reduce_sum(kv_shard);
+
+            // delta[col] = (v[col] - kv[col]) * beta
+            float delta_col = (v_t[col] - kv_col) * beta_val;
+
+            // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
+            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+            float attn_partial = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
+            }
+
+            float attn_col = warp_reduce_sum(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
+        }
+
+        attn_data += S_v * H;
+    }
+
+    // Write state back to global memory (transposed layout)
+#pragma unroll
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i          = r * warp_size + lane;
+        state[col * S_v + i] = s_shard[r];
+    }
+}
+
+template <bool KDA>
+static void launch_gated_delta_net(
+        const float * q_d, const float * k_d, const float * v_d,
+        const float * g_d, const float * b_d, const float * s_d,
+        float * dst_d,
+        int64_t S_v,   int64_t H, int64_t n_tokens, int64_t n_seqs,
+        int64_t sq1,   int64_t sq2, int64_t sq3,
+        int64_t sv1,   int64_t sv2, int64_t sv3,
+        int64_t sb1,   int64_t sb2, int64_t sb3,
+        int64_t neqk1, int64_t rq3,
+        float scale, cudaStream_t stream) {
+    //TODO: Add chunked kernel for even faster pre-fill
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+    const int num_warps = 4;
+    dim3      grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
+    dim3      block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
+
+    const uint3 neqk1_magic = init_fastdiv_values(neqk1);
+    const uint3 rq3_magic   = init_fastdiv_values(rq3);
+
+    int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    switch (S_v) {
+        case 16:
+            gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+            break;
+        case 32:
+            gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+            break;
+        case 64: {
+            gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+            break;
+        }
+        case 128: {
+            gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+            break;
+        }
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src_q     = dst->src[0];
+    ggml_tensor * src_k     = dst->src[1];
+    ggml_tensor * src_v     = dst->src[2];
+    ggml_tensor * src_g     = dst->src[3];
+    ggml_tensor * src_beta  = dst->src[4];
+    ggml_tensor * src_state = dst->src[5];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);
+
+    const int64_t S_v      = nev0;
+    const int64_t H        = nev1;
+    const int64_t n_tokens = nev2;
+    const int64_t n_seqs   = nev3;
+
+    const bool kda = (src_g->ne[0] == S_v);
+
+    GGML_ASSERT(neq1 == nek1);
+    const int64_t neqk1 = neq1;
+
+    const int64_t rq3 = nev3 / neq3;
+
+    const float * q_d = (const float *) src_q->data;
+    const float * k_d = (const float *) src_k->data;
+    const float * v_d = (const float *) src_v->data;
+    const float * g_d = (const float *) src_g->data;
+    const float * b_d = (const float *) src_beta->data;
+
+    const float * s_d   = (const float *) src_state->data;
+    float *       dst_d = (float *) dst->data;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
+    GGML_ASSERT(src_g->ne[0] == 1 || kda);
+    GGML_ASSERT(ggml_is_contiguous(src_g));
+    GGML_ASSERT(ggml_is_contiguous(src_beta));
+    GGML_ASSERT(ggml_is_contiguous(src_state));
+
+    // strides in floats (beta strides used for both g and beta offset computation)
+    const int64_t sq1 = nbq1 / sizeof(float);
+    const int64_t sq2 = nbq2 / sizeof(float);
+    const int64_t sq3 = nbq3 / sizeof(float);
+    const int64_t sv1 = nbv1 / sizeof(float);
+    const int64_t sv2 = nbv2 / sizeof(float);
+    const int64_t sv3 = nbv3 / sizeof(float);
+    const int64_t sb1 = nbb1 / sizeof(float);
+    const int64_t sb2 = nbb2 / sizeof(float);
+    const int64_t sb3 = nbb3 / sizeof(float);
+
+    const float scale = 1.0f / sqrtf((float) S_v);
+
+    cudaStream_t stream = ctx.stream();
+
+    if (kda) {
+        launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
+    } else {
+        launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
+    }
+}
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh
new file mode 100644
index 00000000..7375e81c
--- /dev/null
+++ b/ggml/src/ggml-cuda/gated_delta_net.cuh
@@ -0,0 +1,4 @@
+#include "common.cuh"
+#include "ggml.h"
+
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c3ee2ea0..5a0be4a4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -53,6 +53,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/gated_delta_net.cuh"
 #include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
@@ -70,17 +71,18 @@
 #include 
 #include 
 #include 
-#include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
-#include 
-#include 
+#include 
+#include 
+#include 
 #include 
 #include 
+#include <unordered_set>
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
@@ -122,7 +124,10 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
         err = cudaMallocManaged(ptr, size);
 #if defined(GGML_USE_HIP)
         if (err == hipSuccess) {
-            CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+            // hipMemAdviseSetCoarseGrain is an optional performance hint;
+            // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs).
+            cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
+            (void)hipGetLastError(); // clear any error
         }
 
         // fall back to cudaMalloc if not supported (e.g. on Windows)
@@ -203,7 +208,14 @@ static ggml_cuda_device_info ggml_cuda_init() {
     GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
 
     int64_t total_vram = 0;
-    GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+    for (int id = 0; id < info.device_count; ++id) {
+        cudaDeviceProp prop;
+        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+        total_vram += prop.totalGlobalMem;
+    }
+    GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n",
+                  __func__, info.device_count, (size_t)(total_vram / (1024 * 1024)));
+    total_vram = 0;
 
     std::vector<std::pair<int, std::string>> turing_devices_without_mma;
     for (int id = 0; id < info.device_count; ++id) {
@@ -241,6 +253,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
 #else
         info.devices[id].supports_cooperative_launch = false;
 #endif // !(GGML_USE_MUSA)
+
 #if defined(GGML_USE_HIP)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
 
@@ -255,22 +268,25 @@ static ggml_cuda_device_info ggml_cuda_init() {
                 info.devices[id].cc += prop.minor * 0x10;
             }
         }
-        GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
+        GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB\n",
                       id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
-                      device_vmm ? "yes" : "no", prop.warpSize);
+                      device_vmm ? "yes" : "no", prop.warpSize,
+                      (size_t)(prop.totalGlobalMem / (1024 * 1024)));
 #elif defined(GGML_USE_MUSA)
         // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
         info.devices[id].warp_size = 32;
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
         info.devices[id].cc += prop.minor * 0x10;
-        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
-                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
+                      id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
+                      (size_t)(prop.totalGlobalMem / (1024 * 1024)));
 #else
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
-        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
-                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
+                      id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
+                      (size_t)(prop.totalGlobalMem / (1024 * 1024)));
         std::string device_name(prop.name);
         if (device_name == "NVIDIA GeForce MX450") {
             turing_devices_without_mma.push_back({ id, device_name });
@@ -285,6 +301,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         // TODO: Check for future drivers the default scheduling strategy and
         // remove this call again when cudaDeviceScheduleSpin is default.
         if (prop.major == 12 && prop.minor == 1) {
+            CUDA_CHECK(cudaSetDevice(id));
             CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
         }
 
@@ -1224,6 +1241,34 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     }
 }
 
+struct cublas_force_compute_type {
+    bool fp32 = false;
+    bool fp16 = false;
+};
+
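+// Reads the optional cuBLAS compute type overrides once from the environment:
+//   GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F=1 -> always accumulate in fp32
+//   GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F=1 -> keep fp16 accumulation even on devices that default to fp32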
+static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() {
+    static const cublas_force_compute_type compute_type = [] {
+        cublas_force_compute_type result;
+
+        const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr;
+        const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr;
+
+        GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false);
+
+        if (ggml_cuda_force_cublas_compute_32f_env) {
+            GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n");
+            result.fp32 = true;
+        } else if (ggml_cuda_force_cublas_compute_16f_env) {
+            GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n");
+            result.fp16 = true;
+        }
+
+        return result;
+    }();
+
+    return compute_type;
+}
+
 static void ggml_cuda_op_mul_mat_cublas(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -1306,7 +1351,13 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-        if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+        const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
+
+        if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
+                                        || GGML_CUDA_CC_IS_RDNA4(cc)
+                                        || cc == GGML_CUDA_CC_VOLTA
+                                        || force_compute_type.fp32))
+        {
             const float alpha = 1.0f;
             const float beta = 0.0f;
             CUBLAS_CHECK(
@@ -1905,10 +1956,23 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     cudaDataType_t cu_data_type_b = traits::data_type;
     const void * alpha = traits::get_alpha();
     const void * beta = traits::get_beta();
-    const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+    const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
+
+    int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16;
+
+    // bf16 and fp32 src0 types already use fp32 compute (enforced by the static_assert below),
+    // so forced fp32 only needs to be considered for the fp16 src0_type.
+    static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F);
+
+    const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
+                                                                                  || GGML_CUDA_CC_IS_RDNA4(cc)
+                                                                                  || cc == GGML_CUDA_CC_VOLTA
+                                                                                  || force_compute_type.fp32);
+
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) {
         if constexpr (src0_type == GGML_TYPE_F32) {
             dst_t = (char *) dst_ddf;  // Direct F32 output
         } else {
@@ -1918,18 +1982,10 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         }
     } else {
         dst_t = (char *) dst_ddf;
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-        cu_data_type = CUDA_R_32F;
-        alpha = &alpha_f32;
-        beta = &beta_f32;
-    }
-
-    int id = ggml_cuda_get_device();
-    const int cc = ggml_cuda_info().devices[id].cc;
-    if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-        alpha = &alpha_f32;
-        beta = &beta_f32;
+        cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type;
+        cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type;
+        alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha();
+        beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta();
     }
 
     GGML_ASSERT(ne12 % ne02 == 0);
@@ -2277,14 +2333,21 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
+    // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
     if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        if (ne2 == 1) {
+        static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
+        if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
             if (ggml_is_quantized(src0->type)) {
-                ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+                if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
+                    ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+                    return;
+                }
             } else {
-                ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
+                if (GGML_CUDA_CC_IS_AMD(cc)) {
+                    ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
+                    return;
+                }
             }
-            return;
         }
 
         if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
@@ -2298,6 +2361,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         }
     }
 
+    // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization
+    // TODO: add asserts to verify this. should work with CUDA, HIP, etc.
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(nb12 % nb11 == 0);
@@ -2723,6 +2788,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_GATED_DELTA_NET:
+            ggml_cuda_op_gated_delta_net(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV7:
             ggml_cuda_op_rwkv_wkv7(ctx, dst);
             break;
@@ -2858,14 +2926,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
     bool use_cuda_graph = true;
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
 
-    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
-    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
-    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
-    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
-    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
-    const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
-    const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
-
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2880,30 +2940,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 #endif
         }
 
-        if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
-            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
-#endif
-        }
-
-        if (node->op == GGML_OP_ADD &&
-            node->src[1] && node->src[1]->ne[1] > 1 &&
-            (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
-            (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
-            strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
-            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
-            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
-            strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
-            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
-            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
-            // by means of matching node names. See
-            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
-            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
-            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
+        // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
+        if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
+            // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
+            // TODO: figure out a way to enable for larger batch sizes, without hurting performance
+            // ref: https://github.com/ggml-org/llama.cpp/pull/18958
             use_cuda_graph = false;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
 #endif
         }
 
@@ -2916,21 +2960,27 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 }
 
 static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
-    props->node_address = node->data;
+    memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
+    props->node_data = node->data;
     props->node_op = node->op;
+    props->node_type = node->type;
+    props->flags = node->flags;
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         props->ne[i] = node->ne[i];
         props->nb[i] = node->nb[i];
     }
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+        if (!node->src[i]) {
+            continue;
+        }
+
+        props->src_data[i] = node->src[i]->data;
     }
     memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
 }
 
 static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
-    if (node->data != props->node_address &&
-          node->op != GGML_OP_VIEW) {
+    if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
         return false;
     }
 
@@ -2938,6 +2988,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
         return false;
     }
 
+    if (node->type != props->node_type) {
+        return false;
+    }
+
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (node->ne[i] != props->ne[i]) {
             return false;
@@ -2947,73 +3001,104 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
         }
     }
 
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != props->src_address[i] &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
+    if (node->op != GGML_OP_VIEW) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (!node->src[i]) {
+                if (props->src_data[i] != nullptr) {
+                    return false;
+                }
+                continue;
+            }
+
+            if (node->src[i]->data != props->src_data[i]) {
+                return false;
+            }
         }
     }
 
-    if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
-        memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+    if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+        return false;
+    }
+
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
         return false;
     }
 
     return true;
 }
 
-static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
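+// The first node of the ggml graph is used as a key, so each distinct ggml graph gets its own cached CUDA graph.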
+static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
+    return cgraph->nodes[0];
+}
 
+static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
     bool res = false;
 
-    if (cuda_ctx->cuda_graph->instance == nullptr) {
-        res = true;
-    }
+    const void * graph_key = ggml_cuda_graph_get_key(cgraph);
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
 
     // Check if the graph size has changed
-    if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
+    if (graph->props.size() != (size_t)cgraph->n_nodes) {
         res = true;
-        cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
+        graph->props.resize(cgraph->n_nodes);
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
     // and store properties to allow this comparison for the next token
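+    // also collect srcs that are not themselves nodes of the graph (i.e. graph inputs), so that changes to them trigger an update as well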
+    std::unordered_set<ggml_tensor *> seen_node;
+    std::vector<ggml_tensor *> srcs_extra;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         bool props_match = true;
+
+        seen_node.insert(cgraph->nodes[i]);
+
         if (!res) {
-            props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
+            props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
         }
         if (!props_match) {
             res = true;
         }
-        ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
+        ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
+
+        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+            ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
+            if (src && seen_node.find(src) == seen_node.end()) {
+                srcs_extra.push_back(src);
+            }
+        }
     }
 
-    for (int i = 0; i < cgraph->n_leafs; i++) {
-        bool props_match= true;
+    if (graph->extra.size() != (size_t) srcs_extra.size()) {
+        res = true;
+        graph->extra.resize(srcs_extra.size());
+    }
+
+    for (size_t i = 0; i < srcs_extra.size(); ++i) {
+        bool props_match = true;
+
         if (!res) {
-            props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
+            props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
         }
+
         if (!props_match) {
             res = true;
         }
-        ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
+        ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
     }
 
     return res;
 }
 
-static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
+static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
 
 #if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
 #else
     cudaGraphNode_t errorNode;
     cudaGraphExecUpdateResult result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
 #endif // CUDART_VERSION >= 12000
 
     if (stat == cudaErrorGraphExecUpdateFailure) {
@@ -3024,14 +3109,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
         // The pre-existing graph exec cannot be updated due to violated constraints
         // so instead clear error and re-instantiate
         (void)cudaGetLastError();
-        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-        cuda_ctx->cuda_graph->instance = nullptr;
-        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
+        graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
     } else {
         GGML_ASSERT(stat == cudaSuccess);
     }
 }
-#endif
+#endif // USE_CUDA_GRAPH
 
 static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
                                                 const ggml_tensor * view,
@@ -3067,63 +3152,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
     return true;
 }
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
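+    // Detects the top-k MoE routing pattern starting at node_idx and fills `args` accordingly:
+    //   (SOFT_MAX | SIGMOID) -> RESHAPE [-> ADD bias] -> ARGSORT -> VIEW -> GET_ROWS
+    //                           [-> RESHAPE -> SUM_ROWS -> CLAMP -> DIV -> RESHAPE] [-> SCALE]
+    // or, for delayed softmax (e.g. gpt-oss):
+    //   ARGSORT -> VIEW -> GET_ROWS -> RESHAPE -> SOFT_MAX -> RESHAPE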
+    args.sigmoid         = false;
+    args.softmax         = false;
+    args.delayed_softmax = false;
+    args.prob_bias       = false;
+    args.norm            = false;
+    args.scale           = false;
+
+    const int      n_nodes = cgraph->n_nodes;
+    ggml_tensor ** nodes   = cgraph->nodes;
+
+    if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
+        args.softmax = true;
+    }
+
+    if (nodes[node_idx]->op == GGML_OP_UNARY) {
+        if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
+            return false;
+        }
+        args.sigmoid = true;
+    }
+
+    if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
+        args.delayed_softmax = true;
+    }
+
+    node_idx++;
+
+    if (args.sigmoid || args.softmax) {
+        // SOFTMAX -> RESHAPE
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
+                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        ggml_tensor * probs_reshaped = nodes[node_idx];
+        node_idx++;
+
+        if (node_idx >= n_nodes) {
+            return false;
+        }
+
+        // src of bias add is the unreshaped probs (-2 instead of -1)
+        if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
+            args.prob_bias = true;
+            node_idx++;
+        }
+        // RESHAPE/ADD -> ARGSORT
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
+            return false;
+        }
+
+        if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
+            return false;
+        }
+
+        node_idx++;
+
+        // ARGSORT-> VIEW
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
+            return false;
+        }
+
+        // GET_ROWS
+        if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+    } else if (args.delayed_softmax) {
+        if (node_idx - 2 < 0) {
+            return false;
+        }
+        ggml_tensor * probs_reshaped = nodes[node_idx - 2];
+
+        // VIEW->ARGSORT
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            return false;
+        }
+        node_idx++;
+
+        // GET_ROWS
+        if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+                nodes[node_idx]->src[0] != probs_reshaped) {
+            return false;
+        }
+        node_idx++;
+
+        static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+        for (const ggml_op op : remaining_ops) {
+            if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+                return false;
+            }
+            node_idx++;
+        }
+    }
+
+    // At this point the base routing pattern has matched; optionally match the trailing norm and scale nodes.
+    if (node_idx >= n_nodes) {
+        return true;
+    }
+
+    if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
+        // check RESHAPE -> SUM_ROWS -> CLAMP -> DIV -> RESHAPE
+        static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
+
+        args.norm = true;
+        for (const ggml_op op : norm_ops) {
+            if (node_idx < n_nodes && nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+                node_idx++;
+            } else {
+                args.norm = false;
+                return true;
+            }
+        }
+
+        // DIV <- CLAMP, RESHAPE
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+            nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
+            args.norm = false;
+            return true;
+        }
+        node_idx++;
+
+        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+            args.norm = false;
+            return true;
+        }
+
+        node_idx++;
+    }
+
+    if (node_idx < n_nodes && nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+        args.scale = true;
+    }
+
+    return true;
+}
+
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph *                 cgraph,
+                               int                                        node_idx,
+                               std::initializer_list<enum ggml_op>        ops,
+                               std::initializer_list<enum ggml_unary_op>  unary_ops) {
 #ifndef NDEBUG
     const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
     GGML_ASSERT(unary_ops.size() == num_unary);
 #endif
 
-    //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops =
-        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
-
     const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
                              const std::initializer_list<enum ggml_op> & list2) {
         return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
     };
 
-    if (is_equal(topk_moe_ops_with_norm, ops) &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 9];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
-        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
-
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
-            return true;
-        }
-    }
-
-    if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
-        ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
-        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
-
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
-            return true;
-        }
-    }
-
-    if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
-        ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
-        ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
-        int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
-
-        if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
-            return true;
-        }
-    }
-
     std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };
     std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
 
@@ -3200,7 +3388,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
             return false;
         }
 
-        //rms_norm kernel assumes contigous rows
+        //rms_norm kernel assumes contiguous rows
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
@@ -3212,6 +3400,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return true;
     }
 
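+    // fuse SSM_CONV -> SILU (both F32) into a single kernel: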
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY
+     && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
+        const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
+        const ggml_tensor * silu     = cgraph->nodes[node_idx+1];
+
+        if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        return true;
+    }
+
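+    // fuse UNARY(SILU/SIGMOID/SOFTPLUS) -> MUL when both operands are contiguous, of the same shape and the same F32/F16 type: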
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
+     && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
+        const ggml_tensor * unary = cgraph->nodes[node_idx];
+        const ggml_tensor * mul   = cgraph->nodes[node_idx+1];
+
+        if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) {
+            return false;
+        }
+
+        if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) {
+            return false;
+        }
+
+        if (unary->type != mul->type) {
+            return false;
+        }
+
+        const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0];
+        if (other->type != unary->type) {
+            return false;
+        }
+        if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) {
+            return false;
+        }
+
+        return true;
+    }
+
     if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
      && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
         const ggml_tensor *scale  = cgraph->nodes[node_idx];
@@ -3236,7 +3464,70 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     return false;
 }
 
-static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
+// returns true if the fusion is safe memory-wise, i.e. no output (write) node overlaps a source (read) node other than an elided intermediate of the fused range
+static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph,
+                                                 int           node_idx,
+                                                 int           node_count,
+                                                 int *         out_nodes,
+                                                 int           out_count) {
+    auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) {
+        const int64_t a_start = (int64_t) a->data;
+        const int64_t a_end   = a_start + ggml_nbytes(a);
+
+        const int64_t b_start = (int64_t) b->data;
+        const int64_t b_end   = b_start + ggml_nbytes(b);
+
+        if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) {
+            return true;
+        }
+
+        return false;
+    };
+
+    bool is_ok = true;
+    // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok
+    if (ggml_nrows(cgraph->nodes[node_idx]) == 1) {
+        return true;
+    }
+
+    for (int i = 0; i < out_count; ++i) {
+        const ggml_tensor * dst = cgraph->nodes[out_nodes[i]];
+
+        for (int j = node_idx; j < node_idx + node_count; ++j) {
+            // Loop over all srcs of all nodes in the fusion. If the src overlaps
+            // the destination and the src is not an intermediate node that's being
+            // elided, then disable fusion.
+
+            for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+                const ggml_tensor * src = cgraph->nodes[j]->src[src_idx];
+
+                if (!src || src->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+                if (nodes_overlap(dst, src)) {
+                    bool found = false;
+
+                    for (int k = node_idx; k < j; ++k) {
+                        if (cgraph->nodes[k] == src) {
+                            found = true;
+                            break;
+                        }
+                    }
+
+                    if (!found) {
+                        is_ok = false;
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    return is_ok;
+}
+
+static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
     bool graph_evaluated_or_captured = false;
 
     // flag used to determine whether it is an integrated_gpu
@@ -3378,39 +3669,84 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                     continue;
                 }
 
+                if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+                    continue;
+                }
 
                 // start of fusion operations
                 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                 if (!disable_fusion) {
+                    ggml_cuda_topk_moe_args args;
 
-                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
-                        ggml_tensor * weights          = cgraph->nodes[i + 9];
-                        ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                        ggml_tensor * clamp            = cgraph->nodes[i + 7];
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                              /*delayed softmax*/ false, clamp);
-                        i += 9;
-                        continue;
-                    }
+                    if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
+                        cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
+                        const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
 
-                    if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-                        ggml_tensor * weights          = cgraph->nodes[i + 4];
-                        ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
-                                              /*delayed softmax*/ false);
-                        i += 4;
-                        continue;
-                    }
+                        std::vector<ggml_op> ops;
 
-                    if (ggml_cuda_can_fuse(cgraph, i,
-                                           ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
-                        ggml_tensor * weights = cgraph->nodes[i + 5];
-                        ggml_tensor * ids     = cgraph->nodes[i + 1];
+                        if (can_fuse) {
+                            const ggml_tensor * logits  = node->src[0];
+                            ggml_tensor *       weights = nullptr;
+                            ggml_tensor *       ids     = nullptr;
+                            const ggml_tensor * bias    = nullptr;
+                            const ggml_tensor * clamp   = nullptr;
+                            const ggml_tensor * scale   = nullptr;
 
-                        ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
-                                              /*delayed_softmax*/ true);
-                        i += 5;
-                        continue;
+                            if (!args.delayed_softmax) {
+                                ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
+                                int     out_nodes[2];  // nodes which can't be elided
+
+                                if (args.prob_bias) {
+                                    bias = cgraph->nodes[i + 2]->src[1];
+                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
+                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS });
+                                    out_nodes[0] = i + 4;
+                                    ids          = cgraph->nodes[i + 4];
+                                } else {
+                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                            GGML_OP_GET_ROWS });
+                                    out_nodes[0] = i + 3;
+                                    ids          = cgraph->nodes[i + 3];
+                                }
+
+                                if (args.norm) {
+                                    ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                            GGML_OP_DIV, GGML_OP_RESHAPE });
+                                    clamp = cgraph->nodes[i + ops.size() - 3];
+                                }
+                                if (args.scale) {
+                                    ops.insert(ops.end(), { GGML_OP_SCALE });
+                                    scale = cgraph->nodes[i + ops.size() - 1];
+                                }
+
+                                weights      = cgraph->nodes[i + ops.size() - 1];
+                                out_nodes[1] = i + ops.size() - 1;
+
+                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+                                    ggml_cuda_should_use_topk_moe(node, logits, weights, ids) &&
+                                    ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
+                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+                                    i += ops.size() - 1;
+                                    continue;
+                                }
+                            } else if (!args.norm && !args.prob_bias) {
+                                // special case: gpt-oss (delayed softmax), no norm, no bias.
+                                ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
+                                                        GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
+                                weights                     = cgraph->nodes[i + 5];
+                                ids                         = cgraph->nodes[i + 1];
+                                const ggml_tensor * softmax = cgraph->nodes[i + 4];
+
+                                int out_nodes[2] = { i + 1, i + 5 };
+                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+                                    ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) &&
+                                    ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) {
+                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+                                    i += ops.size() - 1;
+                                    continue;
+                                }
+                            }
+                        }
                     }
 
                     if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
@@ -3442,11 +3778,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                         n_fuse++;
 
                         if (n_fuse > 1) {
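+                            // work on a stack copy so the fusion does not modify the nodes of the original graph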
+                            ggml_tensor fused_add_node;
+                            memcpy(&fused_add_node, node, sizeof(ggml_tensor));
                             for (int j = 0; j < n_fuse - 1; ++j) {
-                                node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+                                fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
                             }
-                            cgraph->nodes[i + n_fuse - 1]->data = node->data;
-                            ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+                            fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
+                            ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
                             i += n_fuse - 1;
 
                             continue;
@@ -3655,6 +3993,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
                         continue;
                     }
 
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
+                        ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]);
+                        i++;
+                        continue;
+                    }
+
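+                    // fuse a unary activation (SiLU / sigmoid / softplus) with the element-wise MUL that consumes it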
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) ||
+                        ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) ||
+                        ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) {
+                        ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]);
+                        i++;
+                        continue;
+                    }
+
                     if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
                         i += 2;
                         ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
@@ -3687,13 +4039,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
         }
 
 #ifdef USE_CUDA_GRAPH
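+        // look up the CUDA graph state associated with this cgraph's key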
+        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
         if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
+            if (graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(graph->graph));
+                graph->graph = nullptr;
             }
 
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
 
             std::lock_guard lock(ggml_cuda_lock);
@@ -3706,41 +4059,38 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
     }
 
     if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+        if (graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
         }
         if (cuda_graph_update_required) { // Update graph executable
-            ggml_cuda_graph_update_executable(cuda_ctx);
+            ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
         }
         // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+        CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
 #else
+        GGML_UNUSED(graph_key);
         graph_evaluated_or_captured = true;
 #endif  // USE_CUDA_GRAPH
     }
 }
 
-static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
-
 #ifdef USE_CUDA_GRAPH
+static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
 
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
-
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
+    if (graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+            if (!graph->disable_due_to_gpu_arch) {
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+            }
+            graph->disable_due_to_gpu_arch = true;
         }
     }
 
-    return cuda_ctx->cuda_graph->is_enabled();
-#else
-    GGML_UNUSED(cuda_ctx);
-    return false;
-#endif // USE_CUDA_GRAPH
+    return graph->is_enabled();
 }
+#endif // USE_CUDA_GRAPH
 
 static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
@@ -3749,15 +4099,40 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool use_cuda_graph             = false;
     bool cuda_graph_update_required = false;
+    const void * graph_key = nullptr;
 
 #ifdef USE_CUDA_GRAPH
-    use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
+    graph_key = ggml_cuda_graph_get_key(cgraph);
 
-    if (cuda_ctx->cuda_graph->is_enabled()) {
-        cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
-        use_cuda_graph             = ggml_cuda_graph_check_compability(cgraph);
+    ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
 
-        cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+    if (graph->is_enabled()) {
+        const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph);
+        if (graph_compatible) {
+            const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
+
+            if (!graph->warmup_complete) {
+                // Warmup: need at least 2 calls with no property change on the 2nd call
+                if (!properties_changed) {
+                    graph->warmup_complete = true;
+                    GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__);
+                    use_cuda_graph = true;
+                    cuda_graph_update_required = true;
+                }
+                // else: properties changed or first call - execute directly (use_cuda_graph stays false)
+            } else {
+                // Post-warmup: normal CUDA graph operation
+                if (properties_changed) {
+                    // Properties changed - reset warmup, execute directly until stable again
+                    graph->warmup_complete = false;
+                    GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__);
+                } else {
+                    use_cuda_graph = true;
+                    cuda_graph_update_required = graph->instance == nullptr;
+                }
+            }
+        }
     }
 #endif // USE_CUDA_GRAPH
 
@@ -3771,7 +4146,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
-    ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
+    ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
 
     return GGML_STATUS_SUCCESS;
 }
@@ -3804,7 +4179,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
 static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
 
-    const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
+#ifdef USE_CUDA_GRAPH
+    const void * graph_key = ggml_cuda_graph_get_key(cgraph);
+    const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
+#else
+    const bool use_cuda_graph = false;
+    GGML_UNUSED(cuda_ctx);
+    GGML_UNUSED(cgraph);
+#endif
 
     static bool enable_graph_optimization = [] {
         const char * env     = getenv("GGML_CUDA_GRAPH_OPT");
@@ -4335,6 +4717,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_CEIL:
                 case GGML_UNARY_OP_ROUND:
                 case GGML_UNARY_OP_TRUNC:
+                    // TODO: should become:
+                    //return ggml_is_contiguous_rows(op->src[0]);
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -4551,7 +4935,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
-            return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
+            return ggml_is_contiguous(op->src[0]);
             break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -4613,8 +4997,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
-        case GGML_OP_ACC:
             return true;
+        case GGML_OP_ACC:
+            // TODO: extend support like so:
+            //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]);
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_TOP_K:
@@ -4627,8 +5014,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_PAD:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_PAD:
+            return true;
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_ARANGE:
@@ -4638,6 +5026,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
             return true;
+        case GGML_OP_GATED_DELTA_NET:
+            // TODO: enable once the MUSA compiler issue is resolved: https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327
+#ifdef GGML_USE_MUSA
+            return false;
+#else
+            return true;
+#endif // GGML_USE_MUSA
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
         case GGML_OP_CROSS_ENTROPY_LOSS:
diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu
index 60542fc1..49af5389 100644
--- a/ggml/src/ggml-cuda/mean.cu
+++ b/ggml/src/ggml-cuda/mean.cu
@@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 #endif // USE_CUDA_GRAPH
     if ((nrows == 1) &&
 #ifdef USE_CUDA_GRAPH
-            // CUDA_GRAPHS_DISABLED
-            ((ncols > 65536) &&
-             ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-              ctx.cuda_graph->is_enabled())) ||
-        // CUDA_GRAPHS ENABLED
-        ((ncols > 32768) &&
-         !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-            ctx.cuda_graph->is_enabled()))) {
+            // Determine if CUDA graphs are effectively disabled for this context
+            // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled)
+            (((ncols > 65536) &&
+              (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
+               ctx.any_cuda_graph_enabled())) ||
+            // CUDA graphs are enabled - use lower threshold
+             ((ncols > 32768) &&
+              !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
+                ctx.any_cuda_graph_enabled())))) {
 #else
         (ncols > 65536)) {
 #endif // USE_CUDA_GRAPH
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index df9eed71..5d1dadd3 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -206,10 +206,16 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 16) {
-                // matrix C
 #if defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
+                if constexpr (std::is_same_v<T, int> || std::is_same_v<T, float>) {
+                    // matrix C
+                    return 2 * l + (threadIdx.x / 16);
+                } else {
+                    // matrix A&B
+                    return l;
+                }
 #else
+                // matrix C is the transposed matrix A&B on RDNA4
                 return ne * (threadIdx.x / 16) + l;
 #endif // defined(RDNA3)
             } else if constexpr (I == 16 && J == 8) {
@@ -327,7 +333,33 @@ namespace ggml_cuda_mma {
 
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
-                return 4 * (threadIdx.x / 16) + l;
+                return ne * (threadIdx.x / 16) + l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+#elif defined(AMD_MFMA_AVAILABLE)
+        static constexpr int ne = I * J / 64;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 8) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 16 && J == 8) {
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -385,7 +417,22 @@ namespace ggml_cuda_mma {
         static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 
 #if defined(AMD_WMMA_AVAILABLE)
-        static constexpr int ne = I * J / 32;
+        static constexpr int ne = tile<I, J, half2>::ne;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static constexpr __device__ bool supported() {
+            return tile<I, J, half2>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I, J, half2>::get_i(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I, J, half2>::get_j(l);
+        }
+#elif defined(AMD_MFMA_AVAILABLE)
+        static constexpr int ne = tile<I, J, half2>::ne;
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
 
         static constexpr __device__ bool supported() {
@@ -621,6 +668,21 @@ namespace ggml_cuda_mma {
 
         return ret;
     }
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    template <int I, int J>
+    static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+        tile<I, J/2, half2> ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+        }
+        return ret;
+    }
+
+    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+        NO_DEVICE_CODE;
+        return tile<8, 8, half2>{};
+    }
 #else // Volta
     template 
     static __device__ __forceinline__ tile get_half2(const tile & tile_float) {
@@ -639,6 +701,19 @@ namespace ggml_cuda_mma {
     }
 #endif // defined(TURING_MMA_AVAILABLE)
 
+    static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
+#if defined(RDNA4)
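+        // builds a 16x16 identity in the RDNA4 C-tile layout: each lane holds 8 consecutive
+        // half values of one row and writes 1.0 at index (row % 8) when its 8-column block
+        // matches the row's 8-row block (left_right == up_down)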
+        const int row = t.get_i(0);
+        const int left_right = t.get_j(0) / 4;
+        const int up_down = row / 8;
+        const int idx = row % 8;
+        reinterpret_cast<half *>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
+#else
+        GGML_UNUSED_VARS(t);
+        NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+    }
+
     template 
     static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) {
 #if defined(AMD_MFMA_AVAILABLE)
@@ -878,6 +953,45 @@ namespace ggml_cuda_mma {
             : "+r"(Dxi[2]), "+r"(Dxi[3])
             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+        halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
+        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+        // MFMA: FP16 input, FP32 accumulate, convert back to half2.
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+
+        // Convert existing half2 accumulator to float for MFMA:
+        floatx4_t acc_f32;
+        {
+            const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]);
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                acc_f32[i] = (float)acc_h[i];
+            }
+        }
+
+        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
+        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
+        acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0);
+
+        // Convert back to half2:
+        {
+            halfx4_t result_h;
+#pragma unroll
+            for (int i = 0; i < 4; ++i) {
+                result_h[i] = (_Float16)acc_f32[i];
+            }
+            reinterpret_cast<halfx4_t&>(D.x[0]) = result_h;
+        }
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -900,6 +1014,32 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
+#ifdef AMD_MFMA_AVAILABLE
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
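+        // CDNA3 provides a 16x16x8 xf32 MFMA; older CDNA uses two 16x16x4 f32 MFMA steps instead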
+#if defined(CDNA3)
+        using floatx2_t = __attribute__((ext_vector_type(2))) float;
+        const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
+        const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 2; ++i) {
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> &     D,
                                                             const tile<16, 8, int> & A,
                                                             const tile<8, 8, int> &  B,
@@ -1009,6 +1149,13 @@ namespace ggml_cuda_mma {
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // RDNA4
+#elif defined(AMD_MFMA_AVAILABLE)
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
+        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -1036,11 +1183,31 @@ namespace ggml_cuda_mma {
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // RDNA4
+#endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#if defined(CDNA3) || defined(CDNA2)
+        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
+        const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
+        const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
+#elif defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 2; ++i) {
+            using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
+            const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
+            const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
+#endif // defined(CDNA3) || defined(CDNA2)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_WMMA_AVAILABLE)
     }
 
     template 
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index 6643f243..aad4c34a 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -2,6 +2,13 @@
 #include "mmf.cuh"
 #include "mmid.cuh"
 
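+// CDNA (MFMA) kernels process 64 output rows per block (MMF_ROWS_PER_BLOCK_CDNA), all other architectures 32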
+static __forceinline__ int mmf_get_rows_per_block(const int cc) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMF_ROWS_PER_BLOCK_CDNA;
+    } else {
+        return MMF_ROWS_PER_BLOCK;
+    }
+}
 
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT(        src1->type == GGML_TYPE_F32);
@@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         ids_info_ptr = &ids_info;
     }
 
+    const int device    = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[device].cc;
+    const int rows_per_block = mmf_get_rows_per_block(cc);
+
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             constexpr int vals_per_T = 1;
-            mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+            mul_mat_f_switch_rows_per_block(
+                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;
             constexpr int vals_per_T = 2;
-            mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+            mul_mat_f_switch_rows_per_block(
+                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
             constexpr int vals_per_T = 2;
-            mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+            mul_mat_f_switch_rows_per_block(
+                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                 ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
@@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     }
-    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
+    if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
+        return false;
+    }
+
+    if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
         return false;
     }
 
@@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
     } else {
         if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
             return false;
+        } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
+            // TODO: treat CDNA2 like CDNA1 for now; tune performance once CDNA2 hardware is available.
+            return false;
+        } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
+            return false;
         } else if (src1_ncols > 16) {
             return false;
         }
@@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
 
     switch (type) {
         case GGML_TYPE_F32:
-            return ampere_mma_available(cc);
+            return ampere_mma_available(cc) || amd_mfma_available(cc);
         case GGML_TYPE_F16:
-            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
+            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
         case GGML_TYPE_BF16:
-            return ampere_mma_available(cc) || amd_wmma_available(cc);
+            return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
         default:
             return false;
     }
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index e3673094..c2a8d54c 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -7,6 +7,31 @@
 using namespace ggml_cuda_mma;
 
 #define MMF_ROWS_PER_BLOCK 32
+#define MMF_ROWS_PER_BLOCK_CDNA 64
+
+static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return 512;
+    } else {
+        return 256;
+    }
+}
+
+static __forceinline__ int mmf_get_padding(int cc) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return 2;
+    } else {
+        return 4;
+    }
+}
+
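+// device-side counterpart of mmf_get_padding(cc) above; the two must agree since the host value is used to size the shared memory buffers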
+static constexpr __device__ int mmf_get_padding() {
+#if defined(AMD_MFMA_AVAILABLE)
+    return 2;
+#else
+    return 4;
+#endif // defined(AMD_MFMA_AVAILABLE)
+}
 
 struct mmf_ids_data {
     const int32_t * ids_src_compact = nullptr;
@@ -29,23 +54,25 @@ static __global__ void mul_mat_f(
         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
-#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
-    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr bool is_tf32 = std::is_same_v<T, float>;
-    constexpr int tile_B_I = is_tf32 ? 8 : 16;
-    constexpr int tile_C_J = is_tf32 ? 8 : 16;
-    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
-    typedef tile<16,       8,        T,     ab_layout>           tile_A;
-    typedef tile           tile_B;
-    typedef tile<16,       tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
+    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8,  T,     get_input_data_layout()> tile_A;
+    typedef tile<16, 8,  T,     get_input_data_layout()> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR>     tile_C;
+#elif defined(AMD_MFMA_AVAILABLE)
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_A;
+    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
     typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
     typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
 #else
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T>     tile_A;
     typedef tile<8,  8, T>     tile_B;
     typedef tile<16, 8, float> tile_C;
@@ -57,7 +84,7 @@ static __global__ void mul_mat_f(
     }
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int tile_k_padded = warp_size + mmf_get_padding();
     constexpr int ntA = rows_per_block / tile_A::I;
     constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
 
@@ -198,7 +225,7 @@ static __global__ void mul_mat_f(
     }
 
     float * buf_iw = (float *) compute_base;
-    constexpr int kiw = nwarps*rows_per_block + 4;
+    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
 
     if (nwarps > 1) {
         __syncthreads();
@@ -228,27 +255,34 @@ static __global__ void mul_mat_f(
             return;
         }
 
-        float sum = 0.0f;
-        static_assert(rows_per_block == warp_size, "need loop/check");
+        float sum[rows_per_block/warp_size] = {0.0f};
+        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
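+        // generalizes the previous rows_per_block == warp_size assumption: each thread now reduces rows_per_block / warp_size rows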
 #pragma unroll
         for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
-            const int i = i0 + threadIdx.x;
+#pragma unroll
+            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
+                const int i = i0 + i1*warp_size + threadIdx.x;
 
-            sum += buf_iw[j*kiw + i];
+                sum[i1] += buf_iw[j*kiw + i];
+            }
         }
 
         if constexpr (!has_ids) {
-            dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+            for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+                dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+            }
         } else {
             const int slot = (j < cols_per_block) ? slot_map[j] : -1;
             if (slot >= 0 && (col_base + j) < ncols_dst_total) {
-                dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+                    dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+                }
             }
         }
     }
-#ifdef VOLTA_MMA_AVAILABLE
     }
-#endif //VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@@ -256,7 +290,7 @@ static __global__ void mul_mat_f(
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     NO_DEVICE_CODE;
-#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 }
 
 //This kernel is for larger batch sizes of mul_mat_id
@@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids(
         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
         const uint3 sis1_fd, const uint3 nch_fd) {
 // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
-#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
-    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr bool is_tf32 = std::is_same_v<T, float>;
-    constexpr int tile_B_I = is_tf32 ? 8 : 16;
-    constexpr int tile_C_J = is_tf32 ? 8 : 16;
-    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
-    typedef tile<16,       8,        T,     ab_layout>           tile_A;
-    typedef tile           tile_B;
-    typedef tile<16,       tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
+    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8,  T,     get_input_data_layout()> tile_A;
+    typedef tile<16, 8,  T,     get_input_data_layout()> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR>     tile_C;
+#elif defined(AMD_MFMA_AVAILABLE)
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_A;
+    typedef tile<16, 8,  T,     DATA_LAYOUT_I_MAJOR> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<32, 4, T,     DATA_LAYOUT_I_MAJOR>          tile_A;
     typedef tile< 8, 4, T,     DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
     typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR>          tile_C;
 #else
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T>     tile_A;
     typedef tile<8,  8, T>     tile_B;
     typedef tile<16, 8, float> tile_C;
@@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids(
 
 
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int tile_k_padded = warp_size + mmf_get_padding();
     constexpr int ntA = rows_per_block / tile_A::I;
     constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
 
@@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids(
     }
 
     float * buf_iw = (float *) compute_base;
-    constexpr int kiw = nwarps*rows_per_block + 4;
+    constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
 
     if (nwarps > 1) {
         __syncthreads();
@@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids(
             return;
         }
 
-        float sum = 0.0f;
-        static_assert(rows_per_block == warp_size, "need loop/check");
+        float sum[rows_per_block/warp_size] = {0.0f};
+        static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
 #pragma unroll
         for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
-            const int i = i0 + threadIdx.x;
+#pragma unroll
+            for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
+                const int i = i0 + i1*warp_size + threadIdx.x;
 
-            sum += buf_iw[j*kiw + i];
+                sum[i1] += buf_iw[j * kiw + i];
+            }
         }
 
         const int global_j = col_base + j;
@@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids(
             const int token = (int) qrm.x;
             if (token < ncols_dst_total) {
                 const int slot = (int) qrm.y;
-                dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+                for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+                    dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+                }
             }
         }
     }
-#ifdef VOLTA_MMA_AVAILABLE
     }
-#endif // VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
     NO_DEVICE_CODE;
-#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 }
 
-template
+template
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
@@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
         const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
         const uint3 nch_fd  = init_fastdiv_values((uint32_t) nchannels_dst);
 
-        mul_mat_f_ids<<>>
+        mul_mat_f_ids<<>>
             (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
             ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
         dim3 block_nums_ids = block_nums;
         block_nums_ids.y *= col_tiles;
 
-        mul_mat_f<<>>
+        mul_mat_f<<>>
             (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
-        mul_mat_f<<>>
+        mul_mat_f<<>>
             (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     }
 }
 
-template <typename T, int cols_per_block>
+template <typename T, int rows_per_block, int cols_per_block>
 void mul_mat_f_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@@ -605,7 +645,7 @@ void mul_mat_f_cuda(
 
     int64_t nwarps_best     = 1;
     int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);
-    int64_t max_block_size  = 256;
+    int64_t max_block_size  = mmf_get_max_block_size(cc);
     for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
         const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
         if (niter < niter_best) {
@@ -614,10 +654,9 @@ void mul_mat_f_cuda(
         }
     }
 
-    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
-    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
-    const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
-    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
+    const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
     const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
@@ -628,56 +667,56 @@ void mul_mat_f_cuda(
 
     switch (nwarps_best) {
         case 1: {
-            mul_mat_f_switch_ids(
+            mul_mat_f_switch_ids(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 2: {
-            mul_mat_f_switch_ids(
+            mul_mat_f_switch_ids(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 3: {
-            mul_mat_f_switch_ids(
+            mul_mat_f_switch_ids(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 4: {
-            mul_mat_f_switch_ids(
+            mul_mat_f_switch_ids(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 5: {
-            mul_mat_f_switch_ids(
+            mul_mat_f_switch_ids(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 6: {
-            mul_mat_f_switch_ids(
+            mul_mat_f_switch_ids(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 7: {
-            mul_mat_f_switch_ids(
+            mul_mat_f_switch_ids(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
                 ids_data);
         } break;
         case 8: {
-            mul_mat_f_switch_ids(
+            mul_mat_f_switch_ids(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
@@ -691,7 +730,7 @@ void mul_mat_f_cuda(
     GGML_UNUSED_VARS(nchannels_y);
 }
 
-template <typename T>
+template <typename T, int rows_per_block>
 static void mul_mat_f_switch_cols_per_block(
         const T * x, const float * y, const int32_t * ids, float * dst,
         const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(
 
     switch (ncols_case) {
         case  1: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  2: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  3: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  4: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  5: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y,  stride_sample_dst, stream, ids_data);
         } break;
         case  6: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  7: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  8: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  9: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 10: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 11: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 12: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 13: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 14: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 15: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 16: {
-            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
@@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
     }
 }
 
-#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
-    template void mul_mat_f_cuda<T, ncols_dst>( \
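+// dispatches the runtime rows-per-block value to the matching compile-time instantiation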
+template <typename T>
+static void mul_mat_f_switch_rows_per_block(
+        const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t stride_col_id, const int stride_row_id,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream, const mmf_ids_data * ids_data) {
+    switch (rows_per_block) {
+        case MMF_ROWS_PER_BLOCK: {
+            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
+                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+        } break;
+        case MMF_ROWS_PER_BLOCK_CDNA: {
+            mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
+                x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+        } break;
+        default:
+            GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
+    }
+}
+
+#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
+    template void mul_mat_f_cuda( \
         const T * x, const float * y, const int32_t * ids, float * dst, \
         const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
         const int64_t stride_col_id, const int64_t stride_row_id, \
@@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
         cudaStream_t stream, const mmf_ids_data * ids_data);
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
 #define DECL_MMF_CASE_EXTERN(ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
-    extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
 
 #define DECL_MMF_CASE(ncols_dst) \
-    DECL_MMF_CASE_HELPER(float, ncols_dst) \
-    DECL_MMF_CASE_HELPER(half2, ncols_dst) \
-    DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
+    DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+    DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+    DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
 
 DECL_MMF_CASE_EXTERN(1);
 DECL_MMF_CASE_EXTERN(2);
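
The mmf changes above use a standard CUDA dispatch idiom: a runtime rows-per-block value is mapped onto a small set of compile-time template constants (one generic value plus a CDNA-sized one), so the kernel can size its tiles statically, and the extended DECL_MMF_CASE macros explicitly instantiate every (type, rows, cols) combination once while other translation units only see extern declarations. A minimal, self-contained sketch of that runtime-to-compile-time dispatch; the names and the 16/32 values below are illustrative, not the actual ggml constants:

    #include <cstdio>

    // Illustrative stand-in for a kernel launcher templated on rows per block.
    template <int rows_per_block>
    void launch_stub(int nrows) {
        std::printf("launching with %d rows per block for %d rows\n", rows_per_block, nrows);
    }

    // Runtime value -> compile-time constant, mirroring mul_mat_f_switch_rows_per_block.
    // 16 and 32 are placeholders, not the real MMF_ROWS_PER_BLOCK* values.
    void dispatch_rows_per_block(int rows_per_block, int nrows) {
        switch (rows_per_block) {
            case 16: launch_stub<16>(nrows); break; // generic path
            case 32: launch_stub<32>(nrows); break; // CDNA-sized path
            default: std::printf("unsupported rows_per_block: %d\n", rows_per_block);
        }
    }

    int main() {
        dispatch_rows_per_block(16, 4096);
        dispatch_rows_per_block(32, 4096);
        return 0;
    }
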
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index a382e6a6..255e59f6 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -2715,14 +2715,14 @@ template  static __device__ __forceinline__ void loa
 
 #pragma unroll
         for (int l = 0; l < QR2_XXS; ++l) {
-            const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
-            const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+            const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
+            const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
 
-            const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
-            const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
 
-            const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
-            const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
@@ -2733,12 +2733,12 @@ template  static __device__ __forceinline__ void loa
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
-        const int ls = aux32 >> 28;
+        const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
         const float d = bxi->d;
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = (ls*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
 #else
-        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
     }
 }
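
The rewritten IQ2_XXS scale above leans on two small identities: `aux32 >> 27 | 1` always equals `2*(aux32 >> 28) + 1` (the OR forces the low bit regardless of what bit 27 held), and with `ls = 2*scale + 1` the new `d*ls/8` is algebraically the same as the old `(d*scale + d/2)/4`, just with fewer operations. A quick host-side check of the bit identity (plain C++, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
        // Only bits 27..31 of aux32 matter for the identity, so 32 cases cover everything.
        for (uint32_t hi = 0; hi < 32; ++hi) {
            const uint32_t aux32 = hi << 27;
            const uint32_t scale = aux32 >> 28;      // what the old code used
            const uint32_t ls    = aux32 >> 27 | 1;  // what the new code uses
            assert(ls == 2*scale + 1);
        }
        // With ls = 2*scale + 1:  d*ls/8 = d*(2*scale + 1)/8 = (d*scale + d/2)/4,
        // i.e. the new expression computes exactly the old value.
        return 0;
    }
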
@@ -2776,11 +2776,14 @@ template  static __device__ __forceinline__ void loa
 
     #pragma unroll
         for (int l = 0; l < QR2_XS; ++l) {
-            const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
-            const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l] >> 9));
+            const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
+            const uint32_t signs = unpack_ksigns(q2[l] >> 9);
 
-            const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
-            const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
@@ -2904,11 +2907,13 @@ template  static __device__ __forceinline__ void loa
 #pragma unroll
         for (int l = 0; l < QR3_XXS; ++l) {
             const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+            const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
 
-            const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
 
-            const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
-            const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
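
The three loaders above replace the pair of ksigns64 lookups per group with one unpack_ksigns call plus byte-wise mask expansion. The constants 0x08040201 and 0x80402010 pick one distinct sign bit per byte lane, __vcmpne4 widens each selected bit into a full 0x00/0xFF byte mask, and __vsub4(grid ^ mask, mask) negates exactly the bytes whose mask is 0xFF, since x ^ 0xFF == ~x and ~x - 0xFF == ~x + 1 == -x in 8-bit two's complement. This reading assumes unpack_ksigns returns the group's 8-bit sign pattern replicated into all four bytes of a 32-bit word; that assumption, and the host stand-ins below for the byte-wise intrinsics, are illustrative only:

    #include <cassert>
    #include <cstdint>

    // Plain C++ stand-ins for the CUDA byte-wise intrinsics used above.
    static uint32_t vcmpne4(uint32_t a, uint32_t b) {   // per byte: (a != b) ? 0xFF : 0x00
        uint32_t r = 0;
        for (int i = 0; i < 4; ++i) {
            r |= uint32_t(uint8_t(a >> 8*i) != uint8_t(b >> 8*i) ? 0xFF : 0x00) << 8*i;
        }
        return r;
    }
    static uint32_t vsub4(uint32_t a, uint32_t b) {     // per byte: wrap-around a - b
        uint32_t r = 0;
        for (int i = 0; i < 4; ++i) {
            r |= uint32_t(uint8_t((a >> 8*i) - (b >> 8*i))) << 8*i;
        }
        return r;
    }

    int main() {
        // Sign pattern for the group: only value 0 is negative (bit 0 set),
        // broadcast into all four byte lanes as assumed for unpack_ksigns().
        const uint32_t signs = 0x01u * 0x01010101u;
        const uint32_t m     = vcmpne4(signs & 0x08040201u, 0);  // -> lane 0 selected
        assert(m == 0x000000FFu);

        for (int x = -127; x <= 127; ++x) {
            const uint32_t xb = uint8_t(x) * 0x01010101u;         // same byte in all four lanes
            const uint32_t r  = vsub4(xb ^ m, m);                 // conditional negation
            assert(int8_t(uint8_t(r >>  0)) == -x);               // lane 0: sign set -> negated
            assert(int8_t(uint8_t(r >>  8)) ==  x);               // lanes 1..3: unchanged
            assert(int8_t(uint8_t(r >> 16)) ==  x);
            assert(int8_t(uint8_t(r >> 24)) ==  x);
        }
        return 0;
    }
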
@@ -3697,13 +3702,20 @@ static __global__ void mul_mat_q(
          tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 }
 
-
 template 
-static __global__ void mul_mat_q_stream_k_fixup(
-        const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
-        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
-        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
-        const int ncols_max) {
+static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
+                                                const int32_t * expert_bounds,
+                                                float * __restrict__ dst,
+                                                const float * __restrict__ tmp_last_tile,
+                                                const int    ncols_x,
+                                                const int    nrows_x,
+                                                const int    ncols_dst,
+                                                const size_t stride_col_dst,
+                                                const int    nchannels_y,
+                                                const size_t stride_channel_dst,
+                                                const int    nsamples_y,
+                                                const size_t stride_sample_dst,
+                                                const int    ncols_max) {
     constexpr int     mmq_y           = get_mmq_y_device();
     constexpr int     qk              = ggml_cuda_type_traits::qk;
     constexpr int     ITER_K          = get_iter_k(type);
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index 32948e4d..d9147202 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -4,26 +4,48 @@
 #include "mmvf.cuh"
 #include "convert.cuh"
 
-template 
+template 
 static __global__ void mul_mat_vec_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
-        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
         const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const int ids_stride) {
     const int row         = blockIdx.x;
+    // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
     const int channel_dst = blockIdx.y;
-    const int channel_x   = ids ? ids[channel_dst]          : fastdiv((uint32_t) channel_dst, channel_ratio);
-    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
-    const int sample_dst  = blockIdx.z;
+    const int tid         = threadIdx.x;
+
+    int token_idx;
+    int channel_x;
+    int channel_y;
+    int sample_dst;
+
+    if constexpr (is_multi_token_id) {
+        // Multi-token MUL_MAT_ID path; folding these computations into the normal path causes a perf regression for the n_tokens=1 case
+        token_idx  = blockIdx.z;
+        channel_x  = ids[channel_dst + token_idx * ids_stride];
+        channel_y  = fastmodulo(channel_dst, nchannels_y);
+        sample_dst = 0;
+    } else {
+        token_idx  = ids ? blockIdx.z                                          : 0;
+        channel_x  = ids ? ids[blockIdx.y + token_idx * ids_stride]            : fastdiv((uint32_t) channel_dst, channel_ratio);
+        channel_y  = ids ? fastmodulo(blockIdx.y, nchannels_y)                 : channel_dst;
+        sample_dst = ids ? 0                                                   : blockIdx.z;
+    }
+
     const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
     const int sample_y    = sample_dst;
-    const int tid         = threadIdx.x;
 
     constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
 
     x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
     y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
     dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+    if constexpr (is_multi_token_id) {
+        y   += token_idx*stride_col_y2*2;
+        dst += token_idx*stride_col_dst;
+    }
 
     bool use_gate = false;
     bool use_bias = false;
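
For orientation, the multi-token MUL_MAT_ID mapping introduced above (blockIdx.y = expert slot, blockIdx.z = token) can be read as the CPU loop below. The [n_tokens x n_expert_used] row-major layout of ids with row stride ids_stride is inferred from the indexing in the kernel and from how ids_stride is computed further down in this file; treat the snippet as a sketch with illustrative names, not ggml API:

    #include <cstdint>

    // Sketch of the work decomposition used by the is_multi_token_id kernel path:
    // one block column per (expert slot, token) pair.
    void mul_mat_id_reference_indexing(
            const int32_t * ids, int n_expert_used, int n_tokens,
            int64_t ids_stride, int nchannels_y) {
        for (int token = 0; token < n_tokens; ++token) {             // blockIdx.z
            for (int slot = 0; slot < n_expert_used; ++slot) {       // blockIdx.y == channel_dst
                const int channel_x = ids[slot + token*ids_stride];   // expert whose weights are read
                const int channel_y = slot % nchannels_y;             // fastmodulo(...) in the kernel
                (void) channel_x; (void) channel_y;
                // The block then reads the token-th column of y channel `channel_y`,
                // multiplies against expert `channel_x`, and writes the token-th
                // column of dst channel `slot` (channel_dst).
            }
        }
    }
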
@@ -56,8 +78,10 @@ static __global__ void mul_mat_vec_f(
     if (use_gate) {
         gate_x += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
     }
+
+    const int channel_bias = ids ? channel_x : channel_dst;
+
     if constexpr (has_fusion) {
-        const int channel_bias = ids ? channel_x : channel_dst;
         if (use_bias) {
             x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
         }
@@ -349,36 +373,36 @@ static __global__ void mul_mat_vec_f(
     }
 }
 
-template
+template
 static void mul_mat_vec_f_switch_fusion(
         const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
-        const int64_t ncols, const int64_t nrows,
+        const int64_t ncols, const uint3 nchannels_y,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
+        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
     if constexpr (ncols_dst == 1) {
         if (has_fusion) {
-            mul_mat_vec_f<<>>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f<<>>
+                (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
             return;
        }
     }
 
     GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
 
-    mul_mat_vec_f<<>>
-        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+    mul_mat_vec_f<<>>
+        (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
 
 }
 
-template 
+template 
 void launch_mul_mat_vec_f_cuda(
         const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const int64_t ncols, const int64_t nrows,
@@ -386,12 +410,13 @@ void launch_mul_mat_vec_f_cuda(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
     GGML_ASSERT(ncols        % 2 == 0);
     GGML_ASSERT(stride_row   % 2 == 0);
     GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
     const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
     const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);
 
@@ -415,56 +440,56 @@ void launch_mul_mat_vec_f_cuda(
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
 
     const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
-    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
+    const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case   64: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case   96: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  128: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  160: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  192: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  224: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  256: {
-            mul_mat_vec_f_switch_fusion
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         default: {
             GGML_ABORT("fatal error");
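
nchannels_y is now handed to the kernel as the uint3 produced by init_fastdiv_values, so fastmodulo(blockIdx.y, nchannels_y) avoids a hardware integer division per block. The exact packing ggml stores in that uint3 is not shown in this patch; the snippet below only illustrates the underlying multiply-shift trick (division by a fixed divisor via one wide multiply and a shift) with hypothetical names and a deliberately small test range:

    #include <cassert>
    #include <cstdint>

    struct fastdiv_ref { uint64_t mult; uint32_t shift; uint32_t div; };

    // mult = ceil(2^(32+l)/d) with l = ceil(log2(d)); kept to d <= 64 here so the
    // shifted constant still fits comfortably in 64 bits.
    fastdiv_ref make_fastdiv_ref(uint32_t d) {
        uint32_t l = 0;
        while ((1u << l) < d) ++l;
        const uint64_t mult = ((uint64_t(1) << (32 + l)) + d - 1) / d;
        return { mult, 32 + l, d };
    }

    uint32_t fastdiv_ref_op(uint32_t n, const fastdiv_ref & v) {
        return uint32_t((uint64_t(n) * v.mult) >> v.shift);   // __umulhi + shift on the GPU
    }

    uint32_t fastmodulo_ref_op(uint32_t n, const fastdiv_ref & v) {
        return n - fastdiv_ref_op(n, v) * v.div;
    }

    int main() {
        for (uint32_t d = 1; d <= 64; ++d) {
            const fastdiv_ref v = make_fastdiv_ref(d);
            for (uint32_t n = 0; n < (1u << 16); ++n) {
                assert(fastdiv_ref_op(n, v)    == n / d);
                assert(fastmodulo_ref_op(n, v) == n % d);
            }
        }
        return 0;
    }
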
@@ -480,55 +505,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        const int64_t ids_stride, cudaStream_t stream) {
+
+    const bool has_ids = ids != nullptr;
+
+    if (has_ids && ncols_dst > 1) {
+        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+        constexpr int c_ncols_dst = 1;
+        launch_mul_mat_vec_f_cuda
+            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+             ncols_dst, ids_stride, stream);
+        return;
+    }
+
+    if (has_ids) {
+        // Single-token MUL_MAT_ID path
+        constexpr int c_ncols_dst = 1;
+        launch_mul_mat_vec_f_cuda
+            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+             ncols_dst, ids_stride, stream);
+        return;
+    }
+
     switch (ncols_dst) {
         case 1:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 2:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 3:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 4:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 5:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 6:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 7:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 8:
             launch_mul_mat_vec_f_cuda
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -544,21 +602,21 @@ static void mul_mat_vec_f_cuda(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        enum ggml_prec prec, cudaStream_t stream) {
+        const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {
 
     if constexpr(std::is_same_v) {
         if (prec == GGML_PREC_DEFAULT) {
             mul_mat_vec_f_cuda_switch_ncols_dst
                 (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             return;
         }
     }
     mul_mat_vec_f_cuda_switch_ncols_dst
         (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
 }
 
 void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
@@ -573,7 +631,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
     const size_t ts_src1 = ggml_type_size(src1->type);
     const size_t ts_dst  = ggml_type_size(dst->type);
 
-    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for  batch size 1.
+    GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
     GGML_ASSERT(ne13 == ne3);
 
     GGML_ASSERT(        nb00       == ts_src0);
@@ -626,29 +684,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
     const int64_t ncols_dst          = ids ? ne2  : ne1;
     const int64_t nchannels_y        = ids ? ne11 : ne12;
     const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_col_dst     = ids ? s2   : s1;
+    const int64_t stride_col_y       = ids ? s12  : s11;
     const int64_t stride_channel_dst = ids ? s1   : s2;
     const int64_t stride_channel_y   = ids ? s11  : s12;
 
-    GGML_ASSERT(!ids || ncols_dst == 1);
+    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -695,19 +755,19 @@ void ggml_cuda_op_mul_mat_vec_f(
             const float * src0_d = (const float *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh
index a09fbdc7..a50f7c02 100644
--- a/ggml/src/ggml-cuda/mmvf.cuh
+++ b/ggml/src/ggml-cuda/mmvf.cuh
@@ -1,5 +1,7 @@
 #include "common.cuh"
 
+#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels.
+
 void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
     const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
 
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index d671551c..632246e4 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -60,11 +60,17 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
 enum mmvq_parameter_table_id {
     MMVQ_PARAMETERS_GENERIC = 0,
     MMVQ_PARAMETERS_GCN,
-    MMVQ_PARAMETERS_RDNA2
+    MMVQ_PARAMETERS_RDNA2,
+    MMVQ_PARAMETERS_RDNA3_0,
+    MMVQ_PARAMETERS_RDNA4
 };
 
 static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
-#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
+#if defined(RDNA4)
+    return MMVQ_PARAMETERS_RDNA4;
+#elif defined(RDNA3_0)
+    return MMVQ_PARAMETERS_RDNA3_0;
+#elif defined(RDNA2) || defined(RDNA3_5)
     return MMVQ_PARAMETERS_RDNA2;
 #elif defined(GCN) || defined(CDNA)
     return MMVQ_PARAMETERS_GCN;
@@ -74,7 +80,13 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
 }
 
 static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
-    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+        return MMVQ_PARAMETERS_RDNA4;
+    }
+    if (GGML_CUDA_CC_IS_RDNA3_0(cc)) {
+        return MMVQ_PARAMETERS_RDNA3_0;
+    }
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) {
         return MMVQ_PARAMETERS_RDNA2;
     }
     if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
@@ -83,7 +95,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
     return MMVQ_PARAMETERS_GENERIC;
 }
 
-static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
+static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
     if (table_id == MMVQ_PARAMETERS_GENERIC) {
         switch (ncols_dst) {
             case 1:
@@ -114,6 +126,50 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
                 return 1;
         }
     }
+    if (table_id == MMVQ_PARAMETERS_RDNA4) {
+        // nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
+        // Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
+        // pressure and lookup table contention at higher thread counts.
+        if (ncols_dst == 1) {
+            switch (type) {
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q5_K:
+                case GGML_TYPE_Q6_K:
+                case GGML_TYPE_IQ4_NL:
+                case GGML_TYPE_IQ4_XS:
+                    return 8;
+                default:
+                    return 1;
+            }
+        }
+        return 1;
+    }
+    if (table_id == MMVQ_PARAMETERS_RDNA3_0) {
+        // RDNA3 (W7900): stricter whitelist than RDNA4.
+        // Q2_K / Q5_K / IQ4_XS regress in full quant sweeps.
+        if (ncols_dst == 1) {
+            switch (type) {
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q6_K:
+                case GGML_TYPE_IQ4_NL:
+                    return 8;
+                default:
+                    return 1;
+            }
+        }
+        return 1;
+    }
     return 1;
 }
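
The per-type nwarps whitelists only change the launch shape: calc_nwarps feeds both the __launch_bounds__ annotation and the block dimensions built in calc_launch_params, so a whitelisted type such as Q4_0 at ncols_dst = 1 now gets 8 warps per block on RDNA4 while a type with a complex vec_dot such as Q3_K stays at 1. Assuming the 32-wide waves ggml uses on RDNA (an assumption, not stated in this hunk), that is 256 vs 32 threads per block:

    // Illustration only: thread counts implied by the new table on RDNA4.
    // warp_size = 32 is an assumed RDNA wave size; the nwarps values are copied
    // from the whitelist above.
    constexpr int warp_size   = 32;
    constexpr int nwarps_q4_0 = 8;  // simple vec_dot, whitelisted
    constexpr int nwarps_q3_k = 1;  // complex vec_dot, kept at one warp
    static_assert(nwarps_q4_0 * warp_size == 256, "Q4_0, ncols_dst=1: 256 threads/block");
    static_assert(nwarps_q3_k * warp_size ==  32, "Q3_K, ncols_dst=1:  32 threads/block");
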
 
@@ -137,21 +193,21 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
     return 1;
 }
 
-// tell the compiler to use as many registers as it wants, see nwarps definition below
-template 
-__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
+template 
+__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
         const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
         const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
         const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
-        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
+        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
+        const uint32_t ids_stride) {
 
     constexpr int qk  = ggml_cuda_type_traits::qk;
     constexpr int qi  = ggml_cuda_type_traits::qi;
     constexpr int vdr = get_vdr_mmvq(type);
     constexpr mmvq_parameter_table_id table_id = get_device_table_id();
-    constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
+    constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
     constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
@@ -162,11 +218,25 @@ static __global__ void mul_mat_vec_q(
     const     int blocks_per_row_x = ncols_x / qk;
     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-    // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
     const uint32_t channel_dst = blockIdx.y;
-    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
-    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
-    const uint32_t sample_dst  = blockIdx.z;
+
+    uint32_t token_idx = 0;
+    uint32_t channel_x;
+    uint32_t channel_y;
+    uint32_t sample_dst;
+
+    if constexpr (is_multi_token_id) {
+        // Multi-token MUL_MAT_ID path; folding these computations into the normal path causes a perf regression for the n_tokens=1 case
+        token_idx  = blockIdx.z;
+        channel_x  = ids[channel_dst + token_idx * ids_stride];
+        channel_y  = fastmodulo(channel_dst, nchannels_y);
+        sample_dst = 0;
+    } else {
+        channel_x  = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
+        channel_y  = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+        sample_dst = blockIdx.z;
+    }
+
     const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
     const uint32_t sample_y    = sample_dst;
 
@@ -188,11 +258,11 @@ static __global__ void mul_mat_vec_q(
         active_glu    = fusion.glu_op;
     }
 
-    const uint32_t channel_bias = ids ? channel_x : channel_dst;
 
     float x_biases[ncols_dst]    = { 0.0f };
     float gate_biases[ncols_dst] = { 0.0f };
     if constexpr (has_fusion) {
+        const uint32_t channel_bias = ids ? channel_x : channel_dst;
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
             // 1. Hide latency by prefetching bias and gate here
@@ -222,6 +292,9 @@ static __global__ void mul_mat_vec_q(
     float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 
     const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+    if constexpr (is_multi_token_id) {
+        y += token_idx*stride_col_y;
+    }
     const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -275,6 +348,10 @@ static __global__ void mul_mat_vec_q(
 
     dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
 
+    if constexpr (is_multi_token_id) {
+        dst += token_idx*stride_col_dst;
+    }
+
     // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_dst; ++j) {
@@ -334,41 +411,43 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
+template
 static std::pair calc_launch_params(
-        const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
+        const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
         const int warp_size, const mmvq_parameter_table_id table_id) {
     const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
-    const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
-    const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
+    const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
+    const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
     return {block_nums, block_dims};
 }
 
-template
+template
 static void mul_mat_vec_q_switch_fusion(
         const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
         const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
         const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
         const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
-        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
+        const uint32_t ids_stride, cudaStream_t stream) {
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
     if constexpr (c_ncols_dst == 1) {
         if (has_fusion) {
-            mul_mat_vec_q<<>>
+            mul_mat_vec_q<<>>
                 (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
             return;
         }
     }
 
     GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
 
-    mul_mat_vec_q<<>>
+    mul_mat_vec_q<<>>
         (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
 }
 
 template 
@@ -379,7 +458,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         const int nchannels_x, const int nchannels_y, const int nchannels_dst,
         const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        cudaStream_t stream) {
+        const int ids_stride, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
@@ -393,72 +472,83 @@ static void mul_mat_vec_q_switch_ncols_dst(
     const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+    const bool has_ids = ids != nullptr;
+
+    if (has_ids && ncols_dst > 1) {
+        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+        constexpr int c_ncols_dst = 1;
+        std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
+        mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+             channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+             dims.first, dims.second, 0, ids_stride, stream);
+        return;
+    }
 
-    GGML_ASSERT(!ids || ncols_dst == 1);
     switch (ncols_dst) {
         case 1: {
             constexpr int c_ncols_dst = 1;
-            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 2: {
             constexpr int c_ncols_dst = 2;
-            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 3: {
             constexpr int c_ncols_dst = 3;
-            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 4: {
             constexpr int c_ncols_dst = 4;
-            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 5: {
             constexpr int c_ncols_dst = 5;
-            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 6: {
             constexpr int c_ncols_dst = 6;
-            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 7: {
             constexpr int c_ncols_dst = 7;
-            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 8: {
             constexpr int c_ncols_dst = 8;
-            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
             mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         default:
             GGML_ABORT("fatal error");
@@ -474,127 +564,127 @@ static void mul_mat_vec_q_switch_type(
         const int nchannels_x, const int nchannels_y, const int nchannels_dst,
         const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        cudaStream_t stream) {
+        const int ids_stride, cudaStream_t stream) {
     switch (type_x) {
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q_switch_ncols_dst
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q4_1:
             mul_mat_vec_q_switch_ncols_dst
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_1:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q8_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_MXFP4:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q3_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q4_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_XS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ1_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ1_M:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ4_XS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ3_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -622,7 +712,7 @@ void ggml_cuda_mul_mat_vec_q(
     GGML_ASSERT(        nb0        == ts_dst);
     GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
 
-    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
+    GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);
 
     const float   * src1_d =       (const float   *) src1->data;
     const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
@@ -693,11 +783,13 @@ void ggml_cuda_mul_mat_vec_q(
     const int64_t stride_channel_dst = ids ? s1   : s2;
     const int64_t stride_channel_y   = ids ? s11  : s12;
 
+    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
     mul_mat_vec_q_switch_type(
         src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
         ne01,              ncols_dst,     s01, stride_col_y,     stride_col_dst,
         ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-        ne03,              ne3,           s03, s13,              s3,               stream);
+        ne03,              ne3,           s03, s13,              s3,               ids_stride, stream);
 }
 
 void ggml_cuda_op_mul_mat_vec_q(
@@ -726,7 +818,7 @@ void ggml_cuda_op_mul_mat_vec_q(
     ggml_cuda_mm_fusion_args_device fusion_local{};
     mul_mat_vec_q_switch_type(
         src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
-        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);
 
     GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
 }
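
Reviewer note on the mmvq.cu changes above: the new ids_stride parameter carries the row stride of the expert-id tensor in elements, which is what makes the MUL_MAT_ID path correct for batch sizes above 1. A minimal sketch of the indexing this enables; the helper name is hypothetical and not part of the kernel:

// Host side (as in the diff): ids_stride = ids->nb[1] / ggml_type_size(ids->type).
// Device side, each destination column can then look up its expert id:
__device__ __forceinline__ int32_t get_expert_id_sketch(
        const int32_t * ids, const int col, const int slot, const int ids_stride) {
    return ids[col * ids_stride + slot];
}
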
diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh
index 4bb10cfa..8a154631 100644
--- a/ggml/src/ggml-cuda/mmvq.cuh
+++ b/ggml/src/ggml-cuda/mmvq.cuh
@@ -1,6 +1,7 @@
 #include "common.cuh"
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
 
 void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
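
The two limits above give the dispatcher separate caps for the plain and the MUL_MAT_ID paths. A hedged sketch of how such a gate could look; the real selection logic lives elsewhere in ggml-cuda and also weighs quantization type and hardware:

// Sketch only: prefer the MMVQ kernels for small batches.
static bool mmvq_applicable_sketch(const int64_t ne11, const bool has_ids) {
    const int64_t max_batch = has_ids ? MMVQ_MMID_MAX_BATCH_SIZE : MMVQ_MAX_BATCH_SIZE;
    return ne11 <= max_batch;
}
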
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 4f153c57..ef98f675 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -25,19 +25,8 @@ static __global__ void norm_f32(
     }
 
     // sum up partial sums
-    mean_var = warp_reduce_sum(mean_var);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float2 s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = mean_var;
-        }
-        __syncthreads();
-        mean_var = s_sum[lane_id];
-        mean_var = warp_reduce_sum(mean_var);
-    }
+    extern __shared__ float2 s_sum2[];
+    mean_var = block_reduce(mean_var, s_sum2);
 
     const float mean = mean_var.x / ncols;
     const float var = mean_var.y / ncols - mean * mean;
@@ -61,19 +50,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp += x[j];
     }
 
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce(tmp, s_sum);
 
     const float mean = tmp / group_size;
     tmp = 0.0f;
@@ -84,18 +62,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp += xi * xi;
     }
 
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    tmp = block_reduce(tmp, s_sum);
 
     const float variance = tmp / group_size;
     const float scale = rsqrtf(variance + eps);
@@ -163,22 +130,8 @@ static __global__ void rms_norm_f32(const float * x,
     }
 
     // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int        warp_id = tid / WARP_SIZE;
-        const int        lane_id = tid % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = 0.0f;
-        if (lane_id < (block_size / WARP_SIZE)) {
-            tmp = s_sum[lane_id];
-        }
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce(tmp, s_sum);
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
@@ -306,19 +259,8 @@ static __global__ void l2_norm_f32(
     }
 
     // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if constexpr (block_size > WARP_SIZE) {
-        static_assert(block_size == 1024, "unexpected block_size");
-        __shared__ float s_sum[32];
-        const int warp_id = threadIdx.x / WARP_SIZE;
-        const int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    extern __shared__ float s_sum[];
+    tmp = block_reduce(tmp, s_sum);
 
     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
     const float scale = rsqrtf(fmaxf(tmp, eps * eps));
@@ -337,7 +279,7 @@ static void norm_f32_cuda(
         norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -348,7 +290,7 @@ static void group_norm_f32_cuda(
         group_norm_f32<<>>(x, dst, group_size, ne_elements, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        group_norm_f32<1024><<>>(x, dst, group_size, ne_elements, eps);
+        group_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
     }
 }
 
@@ -358,10 +300,10 @@ static void rms_norm_f32_cuda(
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(256, 1, 1);
-        rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<256, false><<<blocks_num, block_dims, 256 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<1024, false><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
@@ -404,12 +346,12 @@ static void rms_norm_mul_f32_cuda(const float *  x,
         const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
         if (ncols < 1024) {
             const dim3 block_dims(256, 1, 1);
-            rms_norm_f32<256, true><<>>(
+            rms_norm_f32<256, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true><<>>(
+            rms_norm_f32<1024, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
         }
@@ -425,14 +367,14 @@ static void rms_norm_mul_f32_cuda(const float *  x,
         const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
         if (ncols < 1024) {
             const dim3 block_dims(256, 1, 1);
-            rms_norm_f32<256, true, true><<>>(
+            rms_norm_f32<256, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                 add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                 add_nchannels_packed, add_nsamples_packed);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true, true><<>>(
+            rms_norm_f32<1024, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                 x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                 mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                 add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
@@ -460,7 +402,7 @@ static void l2_norm_f32_cuda(
         l2_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        l2_norm_f32<1024><<<blocks_num, block_dims, 1024 > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
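
The hunks above replace the hand-rolled warp-reduce plus shared-memory staging in every norm kernel with a single block_reduce call, and the launch sites now pass the scratch space as dynamic shared memory (0 bytes when the block is a single warp). A minimal sketch of the pattern being factored out, assuming the warp_reduce_sum and WARP_SIZE helpers used elsewhere in these kernels; the real helper lives in common.cuh and may differ in detail:

// Two-stage block reduction: reduce within each warp, stage one partial
// per warp in shared memory, then reduce the staged partials.
template <typename T>
__device__ T block_reduce_sum_sketch(T val, T * scratch) {
    val = warp_reduce_sum(val);
    if (blockDim.x > WARP_SIZE) {
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            scratch[warp_id] = val;   // one partial sum per warp, at most 32 of them
        }
        __syncthreads();
        val = lane_id < int(blockDim.x / WARP_SIZE) ? scratch[lane_id] : T{};
        val = warp_reduce_sum(val);
    }
    return val;
}
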
 
diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
index 660c192e..31cd00f7 100644
--- a/ggml/src/ggml-cuda/pad.cu
+++ b/ggml/src/ggml-cuda/pad.cu
@@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
     return (coord + size) % size;
 }
 
-static __global__ void pad_f32(const float * src, float * dst,
+static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
                                const int lp0, const int rp0, const int lp1, const int rp1,
                                const int lp2, const int rp2, const int lp3, const int rp3,
                                const int ne0, const int ne1, const int ne2, const int ne3,
@@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst,
             const int64_t i01  = i1 - lp1;
             const int64_t i02  = i2 - lp2;
             const int64_t i03  = i3 - lp3;
-            const int64_t ne02 = ne2 - lp2 - rp2;
-            const int64_t ne01 = ne1 - lp1 - rp1;
-            const int64_t ne00 = ne0 - lp0 - rp0;
 
-            const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+            const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
 
             dst[dst_idx] = src[src_idx];
         } else {
@@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst,
         const int64_t i02 = wrap_around(i2 - lp2, ne02);
         const int64_t i03 = wrap_around(i3 - lp3, ne03);
 
-        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+        const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
 
         dst[dst_idx] = src[src_idx];
     }
 }
 
 
-static void pad_f32_cuda(const float * src, float * dst,
+static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
     const int ne0, const int ne1, const int ne2, const int ne3,
     const bool circular, cudaStream_t stream) {
     int  num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
     dim3 gridDim(num_blocks, ne1, ne2 * ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst,
                                                          lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
                                                          ne0, ne1, ne2, ne3, circular);
 }
@@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float *             dst_d  = (float *) dst->data;
     cudaStream_t        stream = ctx.stream();
 
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     const int32_t lp0      = ((const int32_t *) (dst->op_params))[0];
     const int32_t rp0      = ((const int32_t *) (dst->op_params))[1];
@@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int32_t rp3      = ((const int32_t *) (dst->op_params))[7];
     const int32_t circular = ((const int32_t *) (dst->op_params))[8];
 
-    pad_f32_cuda(src0_d, dst_d,
+    const size_t s00 = nb00 / ggml_type_size(src0->type);
+    const size_t s01 = nb01 / ggml_type_size(src0->type);
+    const size_t s02 = nb02 / ggml_type_size(src0->type);
+    const size_t s03 = nb03 / ggml_type_size(src0->type);
+
+    pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d,
                  lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
                  dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                  (bool) circular, stream);
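
The pad changes above swap the implicit contiguous layout for explicit per-dimension element strides, which is why the ggml_is_contiguous(src0) assertion can be dropped. A small sketch of the conversion, using a hypothetical helper name:

// Byte strides divided by the element size give strides in elements, so any
// f32 view can be addressed as i3*s3 + i2*s2 + i1*s1 + i0*s0, contiguous or not.
static int64_t elem_stride_sketch(const ggml_tensor * t, const int dim) {
    return (int64_t) (t->nb[dim] / ggml_type_size(t->type));
}
// For a contiguous tensor this reduces to the old ne00*ne01*ne02 arithmetic.
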
diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu
index a8c68e44..4300ffc1 100644
--- a/ggml/src/ggml-cuda/quantize.cu
+++ b/ggml/src/ggml-cuda/quantize.cu
@@ -235,7 +235,7 @@ static __global__ void quantize_mmq_q8_1(
     q.z = roundf(xi.z*d_inv);
     q.w = roundf(xi.w*d_inv);
 
-    // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
+    // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
     char4 * yqs4 = (char4 *) y[ib].qs;
     yqs4[iqs/4] = q;
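
For context on the comment fixed above: the kernel packs four int8 results into one char4 so the write-back is a single 32-bit transaction. A stand-alone sketch of the trick, separate from the actual quantize_mmq_q8_1 kernel:

// One 32-bit store instead of four 8-bit stores; assumes qs + iqs is 4-byte aligned.
__device__ __forceinline__ void store_q8x4_sketch(int8_t * qs, const int iqs, const char4 q) {
    char4 * qs4 = (char4 *) qs;
    qs4[iqs / 4] = q;
}
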
 
diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh
index 6bcae9e5..de240fd4 100644
--- a/ggml/src/ggml-cuda/reduce_rows.cuh
+++ b/ggml/src/ggml-cuda/reduce_rows.cuh
@@ -28,22 +28,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r
     }
 
     // sum up partial sums
-    sum = warp_reduce_sum(sum);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float s_sum[32];
-        const int        warp_id = threadIdx.x / WARP_SIZE;
-        const int        lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = sum;
-        }
-        __syncthreads();
-        sum = 0.0f;
-        if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
-            sum = s_sum[lane_id];
-        }
-        sum = warp_reduce_sum(sum);
-    }
+    __shared__ float shared_vals[32];
+    sum = block_reduce(sum, shared_vals);
 
     if (col != 0) {
         return;
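
reduce_rows keeps a statically sized 32-entry buffer rather than dynamic shared memory: a block has at most 1024 threads, i.e. at most 32 warps, so 32 partial sums always fit. A hedged sketch of a row-sum kernel built on the block_reduce_sum_sketch helper outlined after the norm.cu diff above:

__global__ void sum_rows_sketch(const float * x, float * out, const int ncols) {
    float sum = 0.0f;
    for (int i = threadIdx.x; i < ncols; i += blockDim.x) {
        sum += x[blockIdx.x * (int64_t) ncols + i];
    }
    __shared__ float scratch[32];                 // 1024 / WARP_SIZE partials at most
    sum = block_reduce_sum_sketch(sum, scratch);
    if (threadIdx.x == 0) {
        out[blockIdx.x] = sum;
    }
}
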
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index 88ed7911..45a49a5d 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -43,10 +43,15 @@ static __device__ void rope_yarn(
 template 
 static __global__ void rope_norm(const T *            x,
                                  D *                  dst,
-                                 const int            ne0,
-                                 const int            ne1,
+                                 const int            ne00,
+                                 const int            ne01,
+                                 const int            ne02,
+                                 const int            s01,
+                                 const int            s02,
+                                 const int            s03,
                                  const int            s1,
                                  const int            s2,
+                                 const int            s3,
                                  const int            n_dims,
                                  const int32_t *      pos,
                                  const float          freq_scale,
@@ -59,23 +64,23 @@ static __global__ void rope_norm(const T *            x,
                                  const int            set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int row_x     = row_dst % ne1;
-    const int channel_x = row_dst / ne1;
-
-    int       idst = row_dst * ne0 + i0;
-    const int ix   = channel_x*s2 + row_x*s1 + i0;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
+    int       idst = i0 + i1 * s1  + i2 * s2  + i3 * s3;
+    const int ix   = i0 + i1 * s01 + i2 * s02 + i3 * s03;
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
     if (set_rows_stride != 0) {
-        idst = row_x * ne0 + i0;
-        idst += row_indices[channel_x] * set_rows_stride;
+        idst = i1 * s1 + i0;
+        idst += row_indices[i2] * set_rows_stride;
     }
 
     const auto & store_coaelsced = [&](float x0, float x1) {
@@ -92,7 +97,7 @@ static __global__ void rope_norm(const T *            x,
         return;
     }
 
-    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
@@ -110,10 +115,15 @@ static __global__ void rope_norm(const T *            x,
 template 
 static __global__ void rope_neox(const T *            x,
                                  D *                  dst,
-                                 const int            ne0,
-                                 const int            ne1,
+                                 const int            ne00,
+                                 const int            ne01,
+                                 const int            ne02,
+                                 const int            s01,
+                                 const int            s02,
+                                 const int            s03,
                                  const int            s1,
                                  const int            s2,
+                                 const int            s3,
                                  const int            n_dims,
                                  const int32_t *      pos,
                                  const float          freq_scale,
@@ -126,23 +136,24 @@ static __global__ void rope_neox(const T *            x,
                                  const int            set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int row_x     = row_dst % ne1;
-    const int channel_x = row_dst / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    int       idst = row_dst * ne0 + i0 / 2;
-    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
+    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
     if (set_rows_stride != 0) {
-        idst = row_x * ne0 + i0 / 2;
-        idst += row_indices[channel_x] * set_rows_stride;
+        idst = i1 * s1 + i0 / 2;
+        idst += row_indices[i2] * set_rows_stride;
     }
 
     if (i0 >= n_dims) {
@@ -152,7 +163,7 @@ static __global__ void rope_neox(const T *            x,
         return;
     }
 
-    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
@@ -168,24 +179,42 @@ static __global__ void rope_neox(const T *            x,
     dst[idst + n_dims / 2] = ggml_cuda_cast(x0 * sin_theta + x1 * cos_theta);
 }
 
-template
-static __global__ void rope_multi(
-        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
-        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
-    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+template 
+static __global__ void rope_multi(const T *            x,
+                                  T *                  dst,
+                                  const int            ne00,
+                                  const int            ne01,
+                                  const int            ne02,
+                                  const int            s01,
+                                  const int            s02,
+                                  const int            s03,
+                                  const int            s1,
+                                  const int            s2,
+                                  const int            s3,
+                                  const int            n_dims,
+                                  const int32_t *      pos,
+                                  const float          freq_scale,
+                                  const float          ext_factor,
+                                  const float          attn_factor,
+                                  const rope_corr_dims corr_dims,
+                                  const float          theta_scale,
+                                  const float *        freq_factors,
+                                  const mrope_sections sections,
+                                  const bool           is_imrope) {
+    const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int row_x     = row_dst % ne1;
-    const int channel_x = row_dst / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    const int idst = row_dst*ne0 + i0/2;
-    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
+    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     if (i0 >= n_dims) {
         dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
@@ -200,27 +229,24 @@ static __global__ void rope_multi(
 
     float theta_base = 0.0;
     if (is_imrope) {
-        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
-            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
-        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
-            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
-        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
-            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
+        if (sector % 3 == 1 && sector < 3 * sections.v[1]) {         // h
+            theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {  // w
+            theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {  // t
+            theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
         } else {
-            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+            theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
         }
     } else {
         if (sector < sections.v[0]) {
-            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sections.v[0] && sector < sec_w) {
-            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w + sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
+            theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector >= sections.v[0] && sector < sec_w) {
+            theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+            theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
+        } else if (sector >= sec_w + sections.v[2]) {
+            theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
         }
     }
 
@@ -238,37 +264,53 @@ static __global__ void rope_multi(
     dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-template
-static __global__ void rope_vision(
-        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
-        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
+template 
+static __global__ void rope_vision(const T *            x,
+                                   T *                  dst,
+                                   const int            ne00,
+                                   const int            ne01,
+                                   const int            ne02,
+                                   const int            s01,
+                                   const int            s02,
+                                   const int            s03,
+                                   const int            s1,
+                                   const int            s2,
+                                   const int            s3,
+                                   const int            n_dims,
+                                   const int32_t *      pos,
+                                   const float          freq_scale,
+                                   const float          ext_factor,
+                                   const float          attn_factor,
+                                   const rope_corr_dims corr_dims,
+                                   const float          theta_scale,
+                                   const float *        freq_factors,
+                                   const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
     const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int row_x     = row_dst % ne1;
-    const int channel_x = row_dst / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    const int idst = row_dst*ne0 + i0/2;
-    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+    int       idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
+    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     const int sect_dims = sections.v[0] + sections.v[1];
-    const int sec_w = sections.v[1] + sections.v[0];
-    const int sector = (i0 / 2) % sect_dims;
+    const int sec_w     = sections.v[1] + sections.v[0];
+    const int sector    = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
     if (sector < sections.v[0]) {
         const int p = sector;
-        theta_base = pos[channel_x]*powf(theta_scale, p);
-    }
-    else if (sector >= sections.v[0] && sector < sec_w) {
+        theta_base  = pos[i2] * powf(theta_scale, p);
+    } else if (sector >= sections.v[0] && sector < sec_w) {
         const int p = sector - sections.v[0];
-        theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
+        theta_base  = pos[i2 + ne02] * powf(theta_scale, p);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -288,10 +330,15 @@ static __global__ void rope_vision(
 template 
 static void rope_norm_cuda(const T *            x,
                            D *                  dst,
-                           const int            ne0,
-                           const int            ne1,
+                           const int            ne00,
+                           const int            ne01,
+                           const int            ne02,
+                           const int            s01,
+                           const int            s02,
+                           const int            s03,
                            const int            s1,
                            const int            s2,
+                           const int            s3,
                            const int            n_dims,
                            const int            nr,
                            const int32_t *      pos,
@@ -304,31 +351,36 @@ static void rope_norm_cuda(const T *            x,
                            const int64_t *      row_indices,
                            const int            set_rows_stride,
                            cudaStream_t         stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
+    GGML_ASSERT(ne00 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nr, n_blocks_x, 1);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
     if (freq_factors == nullptr) {
         rope_norm<<>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
-            freq_factors, row_indices, set_rows_stride);
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
     } else {
         rope_norm<<>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
-            freq_factors, row_indices, set_rows_stride);
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
     }
 }
 
 template 
 static void rope_neox_cuda(const T *            x,
                            D *                  dst,
-                           const int            ne0,
-                           const int            ne1,
+                           const int            ne00,
+                           const int            ne01,
+                           const int            ne02,
+                           const int            s01,
+                           const int            s02,
+                           const int            s03,
                            const int            s1,
                            const int            s2,
+                           const int            s3,
                            const int            n_dims,
                            const int            nr,
                            const int32_t *      pos,
@@ -341,55 +393,92 @@ static void rope_neox_cuda(const T *            x,
                            const int64_t *      row_indices,
                            const int            set_rows_stride,
                            cudaStream_t         stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
+    GGML_ASSERT(ne00 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nr, n_blocks_x, 1);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
     if (freq_factors == nullptr) {
         rope_neox<<>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
-            freq_factors, row_indices, set_rows_stride);
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
     } else {
         rope_neox<<>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
-            freq_factors, row_indices, set_rows_stride);
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
     }
 }
 
-template
-static void rope_multi_cuda(
-        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
+template 
+static void rope_multi_cuda(const T *            x,
+                            T *                  dst,
+                            const int            ne00,
+                            const int            ne01,
+                            const int            ne02,
+                            const int            s01,
+                            const int            s02,
+                            const int            s03,
+                            const int            s1,
+                            const int            s2,
+                            const int            s3,
+                            const int            n_dims,
+                            const int            nr,
+                            const int32_t *      pos,
+                            const float          freq_scale,
+                            const float          freq_base,
+                            const float          ext_factor,
+                            const float          attn_factor,
+                            const rope_corr_dims corr_dims,
+                            const float *        freq_factors,
+                            const mrope_sections sections,
+                            const bool           is_imrope,
+                            cudaStream_t         stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nr, n_blocks_x, 1);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
     if (freq_factors == nullptr) {
         rope_multi<<>>(
-            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
             attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
     } else {
         rope_multi<<>>(
-            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
             attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
     }
 }
 
-template
-static void rope_vision_cuda(
-        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
+template 
+static void rope_vision_cuda(const T *            x,
+                             T *                  dst,
+                             const int            ne00,
+                             const int            ne01,
+                             const int            ne02,
+                             const int            s01,
+                             const int            s02,
+                             const int            s03,
+                             const int            s1,
+                             const int            s2,
+                             const int            s3,
+                             const int            n_dims,
+                             const int            nr,
+                             const int32_t *      pos,
+                             const float          freq_scale,
+                             const float          freq_base,
+                             const float          ext_factor,
+                             const float          attn_factor,
+                             const rope_corr_dims corr_dims,
+                             const float *        freq_factors,
+                             const mrope_sections sections,
+                             cudaStream_t         stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
-    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const int  n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nr, n_blocks_x, 1);
     // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
     // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
@@ -398,11 +487,11 @@ static void rope_vision_cuda(
 
     if (freq_factors == nullptr) {
         rope_vision<<>>(
-            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
             attn_factor, corr_dims, theta_scale, freq_factors, sections);
     } else {
         rope_vision<<>>(
-            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
             attn_factor, corr_dims, theta_scale, freq_factors, sections);
     }
 }
@@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
 
     const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
     const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
+    const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
+
+    const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
+    const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
+    const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
 
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
@@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
     // compute
     if (is_neox) {
         if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
-            rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
-                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                  freq_factors, row_indices, set_rows_stride, stream);
+            rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                  set_rows_stride, stream);
         } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
-            rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
-                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                 freq_factors, row_indices, set_rows_stride, stream);
+            rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                 set_rows_stride, stream);
         } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
-            rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
-                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                freq_factors, row_indices, set_rows_stride, stream);
+            rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_mrope && !is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_multi_cuda(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
+            rope_multi_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+                                     corr_dims, freq_factors, sections, is_imrope, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_multi_cuda(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
+            rope_multi_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+                                     corr_dims, freq_factors, sections, is_imrope, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_vision_cuda(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+            rope_vision_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+                                      corr_dims, freq_factors, sections, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_vision_cuda(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+            rope_vision_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+                                      corr_dims, freq_factors, sections, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else {
         if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
-            rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
-                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                  freq_factors, row_indices, set_rows_stride, stream);
+            rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                  set_rows_stride, stream);
         } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
-            rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
-                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                 freq_factors, row_indices, set_rows_stride, stream);
+            rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                 set_rows_stride, stream);
         } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
-            rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
-                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                                                freq_factors, row_indices, set_rows_stride, stream);
+            rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                                                set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index 1ae84ebf..285c0e95 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -46,7 +46,7 @@ struct soft_max_params {
 };
 
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
-// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
+// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
 #ifdef __clang__
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wpass-failed"
@@ -75,9 +75,6 @@ static __global__ void soft_max_f32(
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-
     const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
@@ -102,21 +99,7 @@ static __global__ void soft_max_f32(
     }
 
     // find the max value in the block
-    max_val = warp_reduce_max(max_val);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = -INFINITY;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = max_val;
-        }
-        __syncthreads();
-
-        max_val = buf_iw[lane_id];
-        max_val = warp_reduce_max(max_val);
-    }
+    max_val = block_reduce(max_val, buf_iw);
 
     float tmp = 0.0f; // partial sum
 
@@ -134,22 +117,7 @@ static __global__ void soft_max_f32(
     }
 
     // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __syncthreads();
-        if (warp_id == 0) {
-            buf_iw[lane_id] = 0.0f;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = tmp;
-        }
-        __syncthreads();
-
-        tmp = buf_iw[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
+    tmp = block_reduce(tmp, buf_iw);
 
     if (sinks) {
         tmp += expf(sinks[i02] - max_val);
@@ -169,50 +137,6 @@ static __global__ void soft_max_f32(
     }
 }
 
-
-// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
-static __device__ float two_stage_warp_reduce_max(float val) {
-    val = warp_reduce_max(val);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float local_vals[32];
-        const int        warp_id = threadIdx.x / WARP_SIZE;
-        const int        lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            local_vals[warp_id] = val;
-        }
-        __syncthreads();
-        val = -INFINITY;
-        if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) {
-            val = local_vals[lane_id];
-        }
-        return warp_reduce_max(val);
-    } else {
-        return val;
-    }
-}
-
-static __device__ float two_stage_warp_reduce_sum(float val) {
-    val = warp_reduce_sum(val);
-    if (blockDim.x > WARP_SIZE) {
-        assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
-        __shared__ float local_vals[32];
-        const int        warp_id = threadIdx.x / WARP_SIZE;
-        const int        lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            local_vals[warp_id] = val;
-        }
-        __syncthreads();
-        val = 0.0f;
-        if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) {
-            val = local_vals[lane_id];
-        }
-        return warp_reduce_sum(val);
-    } else {
-        return val;
-    }
-}
-
 // TODO: Template to allow keeping ncols in registers if they fit
 static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
                                                                 float * __restrict__ dst,
@@ -230,6 +154,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     float     local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
     float     local_max                     = -INFINITY;
     const int step_size                     = gridDim.x * blockDim.x;
+    __shared__ float shared_vals[32];
 
     // Compute thread-local max
     for (int col = col_start; col < p.ncols;) {
@@ -246,7 +171,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     }
 
     // Compute CTA-level max
-    local_max = two_stage_warp_reduce_max(local_max);
+    local_max = block_reduce(local_max, shared_vals);
 
     // Store CTA-level max to GMEM
     if (tid == 0) {
@@ -261,7 +186,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     } else {
         local_max = -INFINITY;
     }
-    local_max = two_stage_warp_reduce_max(local_max);
+    local_max = block_reduce(local_max, shared_vals);
 
     // Compute softmax dividends, accumulate divisor
     float tmp_expf = 0.0f;
@@ -284,7 +209,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     }
 
     // Reduce divisor within CTA
-    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+    tmp_expf = block_reduce(tmp_expf, shared_vals);
 
     // Store CTA-level sum to GMEM
     if (tid == 0) {
@@ -298,7 +223,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     } else {
         tmp_expf = 0.0f;
     }
-    tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
+    tmp_expf = block_reduce(tmp_expf, shared_vals);
 
     // Divide dividend by global sum + store data
     for (int col = col_start; col < p.ncols;) {
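
The softmax hunks above fold the removed two_stage_warp_reduce_* helpers (and the equivalent inline reductions in soft_max_f32) into a single block_reduce(val, buf) call. block_reduce itself is defined elsewhere (presumably common.cuh) and is not shown in this patch, so the following is only a sketch of the two-stage pattern it replaces, assuming the warp_reduce_max and WARP_SIZE primitives from common.cuh; the sum variant differs only in the identity value and warp primitive.

    // sketch of a block-wide max reduction over a 32-slot shared buffer (illustrative name)
    static __device__ float block_reduce_max_sketch(float val, float * buf) {
        val = warp_reduce_max(val);                       // stage 1: reduce within each warp
        if (blockDim.x > WARP_SIZE) {
            __syncthreads();                              // buf may be reused across calls
            const int warp_id = threadIdx.x / WARP_SIZE;
            const int lane_id = threadIdx.x % WARP_SIZE;
            if (lane_id == 0) {
                buf[warp_id] = val;                       // one partial result per warp
            }
            __syncthreads();
            val = lane_id < (int) (blockDim.x / WARP_SIZE) ? buf[lane_id] : -INFINITY;
            val = warp_reduce_max(val);                   // stage 2: reduce the per-warp partials
        }
        return val;
    }
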
diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu
index 177ffc26..07ca33f5 100644
--- a/ggml/src/ggml-cuda/solve_tri.cu
+++ b/ggml/src/ggml-cuda/solve_tri.cu
@@ -83,7 +83,7 @@ static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
 // ======================
 // When ncols_template == 0 the bounds for the loops in this function are not
 // known and can't be unrolled. As we want to keep pragma unroll for all other
-// cases we supress the clang transformation warning here.
+// cases we suppress the clang transformation warning here.
 #ifdef __clang__
 #    pragma clang diagnostic push
 #    pragma clang diagnostic ignored "-Wpass-failed"
diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu
index 6d5ea704..69985cd3 100644
--- a/ggml/src/ggml-cuda/ssm-conv.cu
+++ b/ggml/src/ggml-cuda/ssm-conv.cu
@@ -1,6 +1,7 @@
 #include "ssm-conv.cuh"
+#include "unary.cuh"
 
-template 
+template 
 static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                     const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
                                     float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
@@ -41,11 +42,11 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
         for (size_t j = 0; j < d_conv; j++) {
             sumf += x[(i + j) % d_conv] * w[j];
         }
-        y_block[i * stride_y + tid] = sumf;
+        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }
 
-template 
+template 
 static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
                                                const int src0_nb0, const int src0_nb1, const int src0_nb2,
                                                const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
@@ -65,36 +66,49 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
     const int stride_w = src1_nb1 / sizeof(float);
     const int stride_y = dst_nb1 / sizeof(float);
 
-    float x[d_conv] = { 0.0f };
-    float w[d_conv] = { 0.0f };
+    const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t);
+    const int     n_cols    = d_conv - 1 + split_n_t;
 
+    extern __shared__ float smem[];
+
+    constexpr int load_cols   = d_conv - 1 + split_n_t;
+    constexpr int total_elems = split_d_inner * load_cols;
+    int row = tid / load_cols;
+    int col = tid % load_cols;
+#pragma unroll
+    for (int idx = 0; idx < total_elems; idx += split_d_inner) {
+        if (row < (int)split_d_inner) {
+            smem[row * n_cols + col] = x_block[row * stride_x + col];
+        }
+
+        col += split_d_inner;
+        row += col / load_cols;
+        col  = col % load_cols;
+        if (idx >= total_elems - tid - split_d_inner) {
+            break;
+        }
+    }
+    __syncthreads();
+
+    // Load weights into registers (done once, small)
+    float w[d_conv] = { 0.0f };
 #pragma unroll
     for (size_t j = 0; j < d_conv; j++) {
         w[j] = w_block[tid * stride_w + j];
     }
 
+    // Compute from shared memory
+    for (int64_t i = 0; i < local_n_t; i++) {
+        float sumf = 0.0f;
 #pragma unroll
-    for (int64_t i = 0; i < split_n_t; i++) {
-        if (bidz * split_n_t + i < n_t) {
-            float sumf = 0.0f;
-
-            if (i == 0) {
-                for (size_t j = 0; j < d_conv; j++) {
-                    x[j] = x_block[tid * stride_x + j];
-                }
-            } else {
-                x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
-            }
-
-#pragma unroll
-            for (size_t j = 0; j < d_conv; j++) {
-                sumf += x[(i + j) % d_conv] * w[j];
-            }
-            y_block[i * stride_y + tid] = sumf;
+        for (size_t j = 0; j < d_conv; j++) {
+            sumf += smem[tid * n_cols + i + j] * w[j];
         }
+        y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
     }
 }
 
+template 
 static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
                               const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
                               const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
@@ -106,12 +120,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
         constexpr int kNC = decltype(NC)::value;
         if (n_t <= 32) {
             const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
-            ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+            ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
                                                                        dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         } else {
             const int64_t split_n_t = 32;
             dim3          blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
-            ssm_conv_long_token_f32<<>>(
+            const size_t  smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
+            ssm_conv_long_token_f32<<>>(
                 src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
         }
     };
@@ -124,27 +139,36 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
     }
 }
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
     const struct ggml_tensor * src0 = dst->src[0];  // conv_x
     const struct ggml_tensor * src1 = dst->src[1];  // conv1d.weight
+    const bool fuse_silu = silu_dst != nullptr;
+
+    // When fusing, write to silu_dst (the node downstream references).
+    const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;
 
     const int64_t nc  = src1->ne[0];                // d_conv
     const int64_t nr  = src0->ne[1];                // d_inner
-    const int64_t n_t = dst->ne[1];                 // tokens per sequence
-    const int64_t n_s = dst->ne[2];                 // number of sequences in the batch
+    const int64_t n_t = out->ne[1];                 // tokens per sequence
+    const int64_t n_s = out->ne[2];                 // number of sequences in the batch
 
-    GGML_ASSERT(dst->ne[0] == nr);
+    GGML_ASSERT(out->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
-    float *       dst_d  = (float *) dst->data;
+    float *       dst_d  = (float *) out->data;
     cudaStream_t  stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
-                      dst->nb[2], nc, nr, n_t, n_s, stream);
+    GGML_ASSERT(out->type == GGML_TYPE_F32);
+    if (fuse_silu) {
+        ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+                          out->nb[2], nc, nr, n_t, n_s, stream);
+    } else {
+        ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
+                          out->nb[2], nc, nr, n_t, n_s, stream);
+    }
 }
diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh
index 8e6c1f00..f96a1cd2 100644
--- a/ggml/src/ggml-cuda/ssm-conv.cuh
+++ b/ggml/src/ggml-cuda/ssm-conv.cuh
@@ -1,3 +1,3 @@
 #include "common.cuh"
 
-void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu
new file mode 100644
index 00000000..1f554d81
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
index 2074e954..517993cb 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu
new file mode 100644
index 00000000..264751d6
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
index 24c64cf0..97b19c67 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
index 1ada657f..989626df 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
index 86d4ffae..173de7aa 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index a5602da0..e382df1a 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -71,7 +71,7 @@ for type_k in TYPES_KV:
             f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))
 
 for ncols in [8, 16, 32, 64]:
-    for ncols2 in [1, 2, 4, 8, 16]:
+    for ncols2 in [1, 2, 4, 8, 16, 32]:
         if ncols2 > ncols:
             continue
         ncols1 = ncols // ncols2
@@ -83,9 +83,9 @@ for ncols in [8, 16, 32, 64]:
                     continue
                 if head_size_kq == 72:
                     continue
-                if head_size_kq != 576 and ncols2 == 16:
+                if head_size_kq != 576 and ncols2 in (16, 32):
                     continue
-                if head_size_kq == 576 and ncols2 != 16:
+                if head_size_kq == 576 and ncols2 not in (4, 16, 32):
                     continue
                 head_size_v = head_size_kq if head_size_kq != 576 else 512
                 f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu
index 318ac386..785a1838 100644
--- a/ggml/src/ggml-cuda/top-k.cu
+++ b/ggml/src/ggml-cuda/top-k.cu
@@ -4,7 +4,6 @@
 #ifdef GGML_CUDA_USE_CUB
 #    include 
 #    if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
-#        include 
 #        define CUB_TOP_K_AVAILABLE
 using namespace cub;
 #    endif  // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index 48e569ef..3020e5c7 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -5,6 +5,13 @@
 #include 
 #include 
 
+// Kernel config struct - passed by value to CUDA kernel
+struct topk_moe_config {
+    bool use_sigmoid;
+    bool with_norm;
+    bool delayed_softmax;
+};
+
 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
 template 
 __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
     }
 }
 
+template 
+__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        vals[i]           = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
+    }
+}
+
 /*
     This kernel does the following:
     1. optionally softmax over the logits per token [n_experts, n_tokens]
@@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template 
-__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
-                                                                  float *       weights,
-                                                                  int32_t *     ids,
-                                                                  const int     n_rows,
-                                                                  const int     n_expert_used,
-                                                                  const float   clamp_val) {
+template 
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *         logits,
+                                                                  float *               weights,
+                                                                  int32_t *             ids,
+                                                                  float *               bias,
+                                                                  const int             n_rows,
+                                                                  const int             n_expert_used,
+                                                                  const float           clamp_val,
+                                                                  const float           scale_val,
+                                                                  const topk_moe_config config) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -79,14 +99,53 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     float wt[experts_per_thread];
 
+    // Initialize all slots to -INFINITY
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = -INFINITY;
+    }
+
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
         const int expert  = i + threadIdx.x;
         wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    if constexpr (!delayed_softmax) {
-        softmax_warp_inplace(wt, n_experts, threadIdx.x);
+    if (!config.delayed_softmax) {
+        if (config.use_sigmoid) {
+           sigmoid_warp_inplace(wt, n_experts, threadIdx.x);
+        } else {
+           softmax_warp_inplace(wt, n_experts, threadIdx.x);
+        }
+    }
+
+    // Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs.
+    // NaN comparisons always return false, which would cause the same expert to be
+    // selected repeatedly. -FLT_MAX compares normally and is still excluded by the
+    // -INFINITY sentinel used after each selection round.
+    // More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        if (__isnanf(wt[i])) {
+            wt[i] = -FLT_MAX;
+        }
+    }
+
+    // selection_wt is only needed when bias is present (selection uses wt + bias)
+    // when no bias, we use wt directly for both selection and weight values
+    float selection_wt[has_bias ? experts_per_thread : 1];
+
+    if constexpr (has_bias) {
+#pragma unroll
+        for (int i = 0; i < experts_per_thread; i++) {
+            selection_wt[i] = -INFINITY;
+        }
+#pragma unroll
+        for (int i = 0; i < n_experts; i += WARP_SIZE) {
+            const int expert = i + threadIdx.x;
+            selection_wt[i / WARP_SIZE] =
+                (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
+        }
     }
 
     //at this point, each thread holds either a portion of the softmax distribution
@@ -106,22 +165,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
 
-#pragma unroll
-        for (int i = 1; i < experts_per_thread; i++) {
-            const int expert = threadIdx.x + i * WARP_SIZE;
-            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
-                max_val    = wt[i];
-                max_expert = expert;
-            }
-        }
+        if constexpr (has_bias) {
+            float max_val_s = selection_wt[0];
 
 #pragma unroll
-        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
-            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
-            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
-            if (val > max_val || (val == max_val && expert < max_expert)) {
-                max_val    = val;
-                max_expert = expert;
+            for (int i = 1; i < experts_per_thread; i++) {
+                const int expert = threadIdx.x + i * WARP_SIZE;
+                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
+                    max_val    = wt[i];
+                    max_val_s  = selection_wt[i];
+                    max_expert = expert;
+                }
+            }
+
+#pragma unroll
+            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+                const float val_s  = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
+                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+                if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
+                    max_val    = val;
+                    max_val_s  = val_s;
+                    max_expert = expert;
+                }
+            }
+
+            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+                selection_wt[max_expert / WARP_SIZE] = -INFINITY;
+            }
+        } else {
+#pragma unroll
+            for (int i = 1; i < experts_per_thread; i++) {
+                const int expert = threadIdx.x + i * WARP_SIZE;
+                if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+                    max_val    = wt[i];
+                    max_expert = expert;
+                }
+            }
+
+#pragma unroll
+            for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+                const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+                const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+                if (val > max_val || (val == max_val && expert < max_expert)) {
+                    max_val    = val;
+                    max_expert = expert;
+                }
+            }
+
+            if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+                wt[max_expert / WARP_SIZE] = -INFINITY;
             }
         }
 
@@ -130,16 +223,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
 
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
-            wt[max_expert / WARP_SIZE] = -INFINITY;
-
             ids[k] = max_expert;
-            if constexpr (with_norm) {
+            if (config.with_norm) {
                 wt_sum += max_val;
             }
         }
     }
 
-    if constexpr (with_norm) {
+    if (config.with_norm) {
         wt_sum              = warp_reduce_sum(wt_sum);
         wt_sum              = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;
@@ -149,7 +240,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
-    if constexpr (delayed_softmax) {
+    if (config.delayed_softmax) {
         softmax_warp_inplace(output_weights, n_expert_used, threadIdx.x);
     }
 
@@ -157,25 +248,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
         if (idx < n_expert_used) {
-            weights[idx] = output_weights[i];
+            weights[idx] = output_weights[i] * scale_val;
         }
     }
-
-    if (!with_norm) {
-        GGML_UNUSED(clamp_val);
-    }
 }
 
-template 
+template
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
                                  int32_t *                   ids,
+                                 float *                     bias,
                                  const int                   n_rows,
                                  const int                   n_expert,
                                  const int                   n_expert_used,
-                                 const float                 clamp_val) {
-    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+                                 const float                 clamp_val,
+                                 const float                 scale_val,
+                                 const topk_moe_config       config) {
+    GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
+                "delayed softmax is not supported with weight normalization");
     const int    rows_per_block = 4;
     dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
@@ -183,44 +274,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<1, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<2, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 4:
-            topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<4, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 8:
-            topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<8, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                   clamp_val, scale_val, config);
             break;
         case 16:
-            topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<16, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<32, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 64:
-            topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<64, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                    clamp_val, scale_val, config);
             break;
         case 128:
-            topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<128, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         case 256:
-            topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<256, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         case 512:
-            topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
+            topk_moe_cuda<512, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
+            break;
+        case 576:
+            topk_moe_cuda<576, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used,
+                                                                     clamp_val, scale_val, config);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -228,13 +323,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     }
 }
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           const ggml_tensor *         logits,
-                           ggml_tensor *               weights,
-                           ggml_tensor *               ids,
-                           const bool                  with_norm,
-                           const bool                  delayed_softmax,
-                           ggml_tensor *               clamp) {
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
+                           const ggml_tensor *             logits,
+                           ggml_tensor *                   weights,
+                           ggml_tensor *                   ids,
+                           const ggml_tensor *             clamp,
+                           const ggml_tensor *             scale,
+                           const ggml_tensor *             bias,
+                           const ggml_cuda_topk_moe_args & args) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -245,107 +341,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     const float * logits_d  = (const float *) logits->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
+    float *       bias_d    = bias ? (float *) bias->data : nullptr;
+
+    float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
 
     GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
 
     const int n_expert_used = weights->ne[1];
 
+    const bool with_norm = clamp != nullptr;
+
     float clamp_val = -INFINITY;
-    if (with_norm) {
-        if (clamp) {
-            clamp_val = ggml_get_op_params_f32(clamp, 0);
-        }
-        launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
+    if (clamp) {
+        clamp_val = ggml_get_op_params_f32(clamp, 0);
+    }
+
+    topk_moe_config config;
+    config.use_sigmoid     = args.sigmoid;
+    config.with_norm       = with_norm;
+    config.delayed_softmax = args.delayed_softmax;
+
+    if (bias) {
+        launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+                             scale_val, config);
     } else {
-        GGML_ASSERT(clamp == nullptr);
-        if (delayed_softmax) {
-            launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
-                                              clamp_val);
-        } else {
-            launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
-                                               clamp_val);
-        }
+        launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+                             scale_val, config);
     }
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
                                    const ggml_tensor * weights,
-                                   const ggml_tensor * get_rows,
-                                   const ggml_tensor * argsort,
-                                   const ggml_tensor * clamp,
-                                   int n_expert) {
-    ggml_tensor * probs = get_rows->src[0];
-    if (probs->op != GGML_OP_RESHAPE) {
-        return false;
-    }
-    probs = probs->src[0];
-    ggml_tensor * selection_probs = argsort->src[0];
-
-    if (probs != selection_probs) {
+                                   const ggml_tensor * logits,
+                                   const ggml_tensor * ids) {
+    const int n_expert = ids->nb[1] / ids->nb[0];
+    if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
         return false;
     }
 
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
-
-    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+    if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
         return false;
     }
 
-    if (scale != 1.0f || max_bias != 0.0f) {
-        return false;
-    }
+    if (gating_op->op == GGML_OP_SOFT_MAX) {
+        const ggml_tensor * softmax  = gating_op;
+        float               scale    = 1.0f;
+        float               max_bias = 0.0f;
 
-    // don't fuse when masks or sinks are present
-    if (softmax->src[1] || softmax->src[2]) {
-        return false;
-    }
+        memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+        memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
 
-    // n_expert must be a power of 2
-    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
-        return false;
-    }
-
-    if (clamp) {
-        if (clamp->op != GGML_OP_CLAMP) {
+        if (!ggml_is_contiguous(softmax->src[0])) {
             return false;
         }
-        float max_val = ggml_get_op_params_f32(clamp, 1);
 
-        if (max_val != INFINITY) {
+        if (scale != 1.0f || max_bias != 0.0f) {
+            return false;
+        }
+
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
+            return false;
+        }
+    } else if (gating_op->op == GGML_OP_UNARY) {
+        ggml_unary_op op = ggml_get_unary_op(gating_op);
+
+        if (op != GGML_UNARY_OP_SIGMOID) {
             return false;
         }
     }
 
-
     return true;
 }
-
-std::initializer_list ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
-    static std::initializer_list norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
-                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
-                                                            GGML_OP_RESHAPE };
-
-    static std::initializer_list no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
-                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };
-
-    static std::initializer_list delayed_softmax_ops = { GGML_OP_ARGSORT,  GGML_OP_VIEW,
-                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
-
-    GGML_ASSERT(!norm || !delayed_softmax);
-
-    if (delayed_softmax) {
-        return delayed_softmax_ops;
-    }
-
-    if (norm) {
-        return norm_ops;
-    }
-
-    return no_norm_ops;
-}
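
The NaN sanitization added above is easiest to see with a toy, host-side version of the iterative argmax: an argmax whose running maximum is seeded with NaN can never move off it, because every comparison against NaN is false, while -FLT_MAX orders normally and is still knocked out by the -INFINITY sentinel after selection. A small self-contained illustration (plain host code, using std::isnan in place of the device __isnanf):

    #include <cfloat>
    #include <cmath>
    #include <cstdio>

    int main() {
        float wt[4] = { nanf(""), 3.0f, -1.0f, 2.0f };

        // argmax seeded with wt[0], like each lane's first slot in the kernel
        float max_val = wt[0];
        int   max_idx = 0;
        for (int i = 1; i < 4; i++) {
            if (wt[i] > max_val) {   // always false while max_val is NaN
                max_val = wt[i];
                max_idx = i;
            }
        }
        printf("with NaN     : picked idx %d\n", max_idx);   // prints 0, not 1

        // sanitize the way the kernel does
        for (int i = 0; i < 4; i++) {
            if (std::isnan(wt[i])) {
                wt[i] = -FLT_MAX;
            }
        }
        max_val = wt[0];
        max_idx = 0;
        for (int i = 1; i < 4; i++) {
            if (wt[i] > max_val) {
                max_val = wt[i];
                max_idx = i;
            }
        }
        printf("with -FLT_MAX: picked idx %d\n", max_idx);   // prints 1
        return 0;
    }
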
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
index 6b6c13c5..243dc2f1 100644
--- a/ggml/src/ggml-cuda/topk-moe.cuh
+++ b/ggml/src/ggml-cuda/topk-moe.cuh
@@ -3,19 +3,25 @@
 
 #include 
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
-                           const ggml_tensor *         logits,
-                           ggml_tensor *               weights,
-                           ggml_tensor *               ids,
-                           const bool                  with_norm,
-                           const bool                  delayed_softmax = false,
-                           ggml_tensor *               weight_clamp    = nullptr);
+struct ggml_cuda_topk_moe_args {
+    bool sigmoid{};
+    bool softmax{};
+    bool delayed_softmax{};
+    bool prob_bias{};
+    bool norm{};
+    bool scale{};
+};
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context &     ctx,
+                           const ggml_tensor *             logits,
+                           ggml_tensor *                   weights,
+                           ggml_tensor *                   ids,
+                           const ggml_tensor *             clamp,
+                           const ggml_tensor *             scale,
+                           const ggml_tensor *             bias,
+                           const ggml_cuda_topk_moe_args & args);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
                                    const ggml_tensor * weights,
-                                   const ggml_tensor * get_rows,
-                                   const ggml_tensor * argsort,
-                                   const ggml_tensor * clamp,
-                                   int n_expert);
-
-std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
+                                   const ggml_tensor * logits,
+                                   const ggml_tensor * ids);
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index d4866067..4ad30fa1 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -560,3 +560,58 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
     }
 }
+
+/* fused unary + mul */
+
+template 
+static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+    // unary_node: UNARY op applied to unary_node->src[0]
+    // mul_node:   MUL(a, b) where one of a/b is unary_node
+    // Output goes to mul_node->data
+
+    const ggml_tensor * unary_src = unary_node->src[0];  // input to the unary op
+    const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous_1(unary_src));
+    GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src));
+    GGML_ASSERT(ggml_is_contiguous_1(other_src));
+    GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src));
+    GGML_ASSERT(ggml_are_same_shape(unary_src, other_src));
+
+    GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16);
+    GGML_ASSERT(unary_src->type == other_src->type);
+    GGML_ASSERT(unary_src->type == mul_node->type);
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t k  = ggml_nelements(mul_node);
+    const int64_t nc = unary_src->ne[0];
+    const int64_t unary_stride = unary_src->nb[1];
+    const int64_t other_stride = other_src->nb[1];
+
+    if (unary_src->type == GGML_TYPE_F16) {
+        unary_gated_cuda((const half *) unary_src->data, (const half *) other_src->data,
+                             (half *) mul_node->data, k, nc,
+                             unary_stride / sizeof(half), other_stride / sizeof(half), stream);
+    } else {
+        unary_gated_cuda((const float *) unary_src->data, (const float *) other_src->data,
+                             (float *) mul_node->data, k, nc,
+                             unary_stride / sizeof(float), other_stride / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) {
+    switch (ggml_get_unary_op(unary_node)) {
+        case GGML_UNARY_OP_SILU:
+            ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node);
+            break;
+        case GGML_UNARY_OP_SIGMOID:
+            ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node);
+            break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node);
+            break;
+        default:
+            GGML_ABORT("Unsupported unary op for fused unary+mul");
+    }
+}
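
The fused path above writes mul_node = unary(unary_src) * other_src in a single pass through unary_gated_cuda, which is defined elsewhere in unary.cu. As a reference for what the fusion computes, a naive per-element kernel with the same semantics might look like the sketch below (illustrative names only, not the actual launcher):

    #include <cstdint>

    __device__ __forceinline__ float silu_ref(float x) { return x / (1.0f + expf(-x)); }

    // op is a __device__ function bound as a non-type template argument,
    // matching the per-op dispatch in ggml_cuda_op_unary_mul
    template <float (*op)(float)>
    __global__ void unary_mul_ref(const float * a, const float * b, float * dst, int64_t n) {
        const int64_t i = blockIdx.x * (int64_t) blockDim.x + threadIdx.x;
        if (i < n) {
            dst[i] = op(a[i]) * b[i];
        }
    }

    // launch example:
    //   unary_mul_ref<silu_ref><<<(n + 255) / 256, 256, 0, stream>>>(a, b, dst, n);
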
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index 609046e5..f1dd2183 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -89,6 +89,8 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
 void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node);
+
 __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
     return x / (1.0f + expf(-x));
 }
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 6baab117..ab803aca 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
 #endif
 }
 
+static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
+    // the low 7 bits of v are the stored sign bits; the 8th sign equals their parity (popcount)
+    // xoring that parity into bit 7 also corrects any stray high bit of v, so no masking is needed
+    const uint32_t p = __popc(v) & 1;
+    const uint32_t s = v ^ p << 7;
+    // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
+    return s * 0x01010101;
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
     int sumi = 0;
 #pragma unroll
     for (int k0 = 0; k0 < 8; k0 += 2) {
-        const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
-        const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
+        const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
+        const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
 
-        const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
-        const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
         sumi = ggml_cuda_dp4a(grid0, u0, sumi);
 
-        const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
-        const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
         sumi = ggml_cuda_dp4a(grid1, u1, sumi);
     }
 
-    const int ls = aux32 >> 28;
-    sumi = (ls*sumi + sumi/2)/4;
+    const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
+    sumi = sumi * ls / 8;           // (sumi * scale + sumi / 2) / 4
     const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
     return d * sumi;
 }
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     int sumi1 = 0;
 #pragma unroll
     for (int l0 = 0; l0 < 8; l0 += 2) {
-        const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
-        const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l0/2] >> 9));
-
-        const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
-        const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+        const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
+        const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
 
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
 
         if (l0 < 4) {
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 #pragma unroll
     for (int l0 = 0; l0 < 8; l0 += 2) {
         const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
+        const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
 
-        const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
-
-        const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
-        const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
 
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
 
         sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
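
Two things happen in the vecdotq hunks: unpack_ksigns replaces the ksigns_iq2xs / ksigns64 table lookups with an arithmetic reconstruction (the 8th sign bit is the parity of the 7 stored bits, and the multiply by 0x01010101 broadcasts the sign byte so that 0x08040201 / 0x80402010 select one sign bit per byte lane for __vcmpne4), and the iq2_xxs scale is folded so that ls = (aux32 >> 27) | 1 equals 2*scale + 1, making sumi * ls / 8 the folded form of the old (scale*sumi + sumi/2) / 4, as the inline comments note. A host-side sketch of the sign unpacking for one sample value (uses the gcc/clang __builtin_popcount; names are illustrative):

    #include <cstdint>
    #include <cstdio>

    // host re-implementation of unpack_ksigns, for illustration only
    static uint32_t unpack_ksigns_host(uint8_t v) {
        const uint32_t p = __builtin_popcount(v) & 1;    // parity of the stored bits
        const uint32_t s = v ^ (p << 7);                 // == ksigns_iq2xs[v & 0x7F]
        return s * 0x01010101u;                          // replicate the byte into all 4 lanes
    }

    int main() {
        const uint8_t  v     = 0x2B;                     // 0b0101011: 4 bits set -> bit 7 stays 0
        const uint32_t signs = unpack_ksigns_host(v);
        printf("broadcast    = 0x%08x\n", signs);                // 0x2b2b2b2b
        printf("lo selector  = 0x%08x\n", signs & 0x08040201u);  // sign bits 0..3, one per byte
        printf("hi selector  = 0x%08x\n", signs & 0x80402010u);  // sign bits 4..7, one per byte
        return 0;
    }
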
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 016b04e5..35d1e1a0 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -138,6 +138,8 @@
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
+#define cudaFuncSetAttribute hipFuncSetAttribute
+#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
@@ -205,6 +207,14 @@
 #define RDNA3
 #endif // defined(__GFX11__)
 
+#if defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3_5
+#endif // defined(__gfx1150__) || defined(__gfx1151__)
+
+#if defined(RDNA3) && !defined(RDNA3_5)
+#define RDNA3_0
+#endif // defined(RDNA3) && !defined(RDNA3_5)
+
 #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
     defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
 #define RDNA2
diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt
index d58e2878..f3a58354 100644
--- a/ggml/src/ggml-hexagon/CMakeLists.txt
+++ b/ggml/src/ggml-hexagon/CMakeLists.txt
@@ -1,7 +1,29 @@
+file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}"   HEXAGON_SDK_ROOT)
+file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
+
+if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}")
+    message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.")
+endif()
+
+if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
+    message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json")
+    file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH)
+    string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path")
+    message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}")
+    set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}")
+    file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
+    if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
+        message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
+    endif()
+endif()
+
+message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
+
 include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
 include(ExternalProject)
 
 option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
+set(GGML_HEXAGON_HTP_CERT  "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
 set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
 
 add_library(htp_iface OBJECT
@@ -25,56 +47,71 @@ else()
     target_link_options(htp_iface PUBLIC -ldl)
 endif()
 
-link_custom_library(htp_iface cdsprpc)
-link_custom_library(htp_iface rpcmem)
-
 set(TARGET_NAME ggml-hexagon)
 ggml_add_backend_library(${TARGET_NAME}
-    ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h)
+    ggml-hexagon.cpp
+    htp-drv.cpp
+    htp-drv.h
+    libdl.h
+    ../../include/ggml-hexagon.h)
 
 target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
 target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
 
-# Build HTP bits
-set(HTP_CMAKE_ARGS
-    -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
-    -DCMAKE_BUILD_TYPE=Release
-    -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
-    -DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
-    -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
-    -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-    -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
+# Build HTP skels
+set(HTP_SKELS)
+function(build_htp_skel V)
+    ExternalProject_Add(htp-${V}
+        SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
+        BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so
+        CMAKE_ARGS
+            -DCMAKE_BUILD_TYPE=Release
+            -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
+            -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
+            -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
+            -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
+            -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
+            -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
+            -DDSP_VERSION=${V}
+            -DPREBUILT_LIB_DIR="toolv19_${V}")
+    list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
+    set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
+endfunction()
 
-ExternalProject_Add(htp-v68
-    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
-    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68")
-
-ExternalProject_Add(htp-v69
-    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
-    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69")
-
-ExternalProject_Add(htp-v73
-    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
-    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
-
-ExternalProject_Add(htp-v75
-    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
-    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
-
-ExternalProject_Add(htp-v79
-    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
-    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
-
-ExternalProject_Add(htp-v81
-    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
-    CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
+build_htp_skel(v68)
+build_htp_skel(v69)
+build_htp_skel(v73)
+build_htp_skel(v75)
+build_htp_skel(v79)
+build_htp_skel(v81)
 
 # Install Hexagon skels required at runtime
-install(FILES
-    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so
-    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so
-    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
-    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
-    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
-    ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so
-    TYPE LIB)
+install(FILES ${HTP_SKELS} TYPE LIB)
+
+if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT)
+    file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64"      WINSDK_BIN0_ARM64)
+    file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86"        WINSDK_BIN0_X86)
+    file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64)
+    file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86"   WINSDK_BIN1_X86)
+
+    set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86})
+
+    find_program(INF2CAT  NAMES inf2cat.exe  PATHS ${WINSDK_PATHS} REQUIRED)
+    find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED)
+
+    message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels")
+
+    set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat)
+    add_custom_target(libggml-htp-cat
+        BYPRODUCTS ${LIBGGML_HTP_CAT}
+        DEPENDS libggml-htp.inf ${HTP_SKELS}
+        COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR}
+        COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64
+        COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT}
+        COMMENT "generating and signing libggml-htp.cat file"
+        VERBATIM
+    )
+
+    add_dependencies(${TARGET_NAME} libggml-htp-cat)
+    install(FILES ${LIBGGML_HTP_CAT} TYPE LIB)
+endif()
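With the per-version builds folded into `build_htp_skel()`, each supported DSP generation still produces its own `libggml-htp-v<NN>.so`, and the backend loads the one matching the Hexagon arch it detects at runtime. A rough sketch of that name mapping (hypothetical helper; the actual selection logic lives in ggml-hexagon.cpp and is driven by `opt_arch`):

```cpp
#include <string>
#include <cstdio>

// Hypothetical helper: map a detected Hexagon architecture (e.g. 73, 75, 79, 81)
// to the skel library name produced by build_htp_skel() above.
static std::string htp_skel_name(int arch) {
    return "libggml-htp-v" + std::to_string(arch) + ".so";
}

int main() {
    for (int arch : { 68, 69, 73, 75, 79, 81 }) {
        printf("v%d -> %s\n", arch, htp_skel_name(arch).c_str());
    }
    return 0;
}
```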
diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index 365a24b4..19917cb1 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -14,9 +14,6 @@
 
 #ifdef _WIN32
 #    include 
-#    ifndef _WINDOWS
-#        define _WINDOWS
-#    endif
 #else
 #    include 
 #    include 
@@ -25,8 +22,6 @@
 #pragma clang diagnostic ignored "-Wnested-anon-types"
 #pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
 
-#include "htp-utils.h"
-
 #include 
 #include 
 #include 
@@ -40,14 +35,15 @@
 #include "op-desc.h"
 #include "htp-msg.h"
 #include "htp_iface.h"
+#include "htp-drv.h"
 
 static size_t opt_ndev         = 1;
-static size_t opt_nhvx         = 0;  // use all
-static int    opt_arch         = 0;  // autodetect
+static size_t opt_nhvx         = 0; // use all
+static int    opt_arch         = 0; // autodetect
 static int    opt_etm          = 0;
 static int    opt_verbose      = 0;
 static int    opt_profile      = 0;
-static int    opt_hostbuf      = 1;
+static int    opt_hostbuf      = 1; // hostbuf ON by default
 static int    opt_experimental = 0;
 
 // Enable all stages by default
@@ -143,16 +139,16 @@ struct ggml_hexagon_session {
 };
 
 void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) {
-    // Bump pending flag (cleared in the session::flush once we get the responce)
+    // Bump pending flag (cleared in the session::flush once we get the response)
     this->op_pending++;  // atomic inc
 
     int err = dspqueue_write(this->queue,
                              0,                       // flags - the framework will autoset this
                              n_bufs,                  // number of buffers
                              bufs,                    // buffer references
-                             sizeof(req),
+                             sizeof(req),             // Message length
                              (const uint8_t *) &req,  // Message
-                             1000000                  // Timeout
+                             DSPQUEUE_TIMEOUT         // Timeout
     );
 
     if (err != 0) {
@@ -182,13 +178,13 @@ void ggml_hexagon_session::flush() {
 
         // Read response packet from queue
         int err = dspqueue_read(q, &flags,
-                                   HTP_MAX_PACKET_BUFFERS,  // Maximum number of buffer references
-                                   &n_bufs,                 // Number of buffer references
-                                   bufs,                    // Buffer references
-                                   sizeof(rsp),             // Max message length
-                                   &rsp_size,               // Message length
-                                   (uint8_t *) &rsp,
-                                   1000000);                // Timeout
+                                HTP_MAX_PACKET_BUFFERS,  // Maximum number of buffer references
+                                &n_bufs,                 // Number of buffer references
+                                bufs,                    // Buffer references
+                                sizeof(rsp),             // Max message length
+                                &rsp_size,               // Message length
+                                (uint8_t *) &rsp,        // Message
+                                DSPQUEUE_TIMEOUT);       // Timeout
 
         if (err == AEE_EEXPIRED) {
             // TODO: might need to bail out if the HTP is stuck on something
@@ -269,13 +265,7 @@ struct ggml_backend_hexagon_buffer_context {
     ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
         size += 4 * 1024;  // extra page for padding
 
-        if (rpcmem_alloc2) {
-            this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
-        } else {
-            GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str());
-            this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
-        }
-
+        this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
         if (!this->base) {
             GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
             throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
@@ -412,6 +402,7 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi
 static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
     static const int qk = QK_Q4_0x4x2;
     const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+    const int        nloe = k % qk;           // leftovers
 
     const int dblk_size = 8 * 2;              // 8x __fp16
     const int qblk_size = qk / 2;             // int4
@@ -445,15 +436,17 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
         unpack_q4_0_quants(qs, &x[i * 8 + 6], 6);
         unpack_q4_0_quants(qs, &x[i * 8 + 7], 7);
 
+        bool partial = (nloe && i == nb-1);
+
         uint8_t * q = y_q + (i * qblk_size);
         for (int j = 0; j < qk / 2; j++) {
-            q[j] = (qs[j + 128] << 4) | qs[j];
+            q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000];
         }
     }
 
     // Repack the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Repack the scales
         ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
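The `partial` path only changes how nibbles are paired in the final block of a row whose length is not a multiple of 256: full blocks pair element j with element j+128, while the tail block pairs adjacent elements 2j and 2j+1 so the valid data stays contiguous at the front of the truncated block. A stand-alone round-trip sketch of the two layouts (assuming 4-bit values and 256-element blocks as in QK_Q4_0x4x2; not the backend code itself):

```cpp
#include <cstdint>
#include <cstdio>

// Full blocks: byte j packs qs[j] (low nibble) with qs[j+128] (high nibble).
// Partial tail block: byte j packs qs[2j] (low) with qs[2j+1] (high).
static void pack_block(uint8_t * out, const uint8_t * qs, bool partial) {
    for (int j = 0; j < 128; ++j) {
        out[j] = partial ? (uint8_t)((qs[j*2+1] << 4) | qs[j*2+0])
                         : (uint8_t)((qs[j+128] << 4) | qs[j]);
    }
}

static void unpack_block(uint8_t * qs, const uint8_t * in, bool partial) {
    for (int j = 0; j < 128; ++j) {
        if (partial) {
            qs[j*2+0] = in[j] & 0xf;
            qs[j*2+1] = in[j] >> 4;
        } else {
            qs[j]       = in[j] & 0xf;
            qs[j + 128] = in[j] >> 4;
        }
    }
}

int main() {
    uint8_t qs[256], packed[128], back[256];
    for (int i = 0; i < 256; ++i) qs[i] = i & 0xf;
    for (int partial = 0; partial <= 1; ++partial) {
        pack_block(packed, qs, partial != 0);
        unpack_block(back, packed, partial != 0);
        bool ok = true;
        for (int i = 0; i < 256; ++i) ok &= (back[i] == qs[i]);
        printf("partial=%d round-trip %s\n", partial, ok ? "ok" : "FAILED");
    }
    return 0;
}
```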
@@ -477,6 +470,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
 static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
     static const int qk = QK_Q4_0x4x2;
     const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+    const int        nloe = k % qk;           // leftovers
 
     const int dblk_size = 8 * 2;              // 8x __fp16
     const int qblk_size = qk / 2;             // int4
@@ -495,10 +489,17 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
     for (int i = 0; i < nb; i++) {
         uint8_t qs[QK_Q4_0x4x2];  // unpacked quants
 
+        bool partial = (nloe && i == nb-1);
+
         const uint8_t * q = y_q + (i * qblk_size);
         for (int j = 0; j < qk / 2; j++) {
-            qs[j]       = q[j] & 0xf;
-            qs[j + 128] = q[j] >> 4;
+            if (partial) {
+                qs[j*2+0] = q[j] & 0xf;
+                qs[j*2+1] = q[j] >> 4;
+            } else {
+                qs[j+000] = q[j] & 0xf;
+                qs[j+128] = q[j] >> 4;
+            }
         }
 
         pack_q4_0_quants(&x[i * 8 + 0], qs, 0);
@@ -513,7 +514,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
 
     // Repack the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Unpack the scales
         const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
@@ -562,7 +563,7 @@ static void init_row_q4x4x2(block_q4_0 * x, int64_t k) {
 
     // Init the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Unpack the scales
         x[i * 8 + 0].d = 0;
@@ -780,7 +781,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
 
     // Repack the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Repack the scales
         ggml_half * d = (ggml_half *) (y_d + i * dblk_size);
@@ -839,7 +840,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
 
     // Repack the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Unpack the scales
         const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size);
@@ -888,7 +889,7 @@ static void init_row_q8x4x2(block_q8_0 * x, int64_t k) {
 
     // Init the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Unpack the scales
         x[i * 8 + 0].d = 0;
@@ -1088,6 +1089,7 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int
 static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) {
     static const int qk = QK_MXFP4x4x2;
     const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+    const int        nloe = k % qk;           // leftovers
 
     const int eblk_size = 8 * 1;              // 8x E8M0
     const int qblk_size = qk / 2;             // int4
@@ -1122,15 +1124,17 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k)
         unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6);
         unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7);
 
+        bool partial = (nloe && i == nb-1);
+
         uint8_t * q = y_q + (i * qblk_size);
         for (int j = 0; j < qk / 2; j++) {
-            q[j] = (qs[j + 128] << 4) | qs[j];
+            q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000];
         }
     }
 
     // Repack the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Repack the scales
         uint8_t * e = (uint8_t *) (y_e + i * eblk_size);
@@ -1154,6 +1158,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k)
 static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) {
     static const int qk = QK_MXFP4x4x2;
     const int        nb = (k + qk - 1) / qk;  // number of blocks (padded)
+    const int        nloe = k % qk;           // leftovers
 
     const int eblk_size = 8 * 1;              // 8x E8M0
     const int qblk_size = qk / 2;             // int4
@@ -1172,10 +1177,17 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k)
     for (int i = 0; i < nb; i++) {
         uint8_t qs[QK_MXFP4x4x2];  // unpacked quants
 
+        bool partial = (nloe && i == nb-1);
+
         const uint8_t * q = y_q + (i * qblk_size);
         for (int j = 0; j < qk / 2; j++) {
-            qs[j]       = q[j] & 0xf;
-            qs[j + 128] = q[j] >> 4;
+            if (partial) {
+                qs[j*2+0] = q[j] & 0xf;
+                qs[j*2+1] = q[j] >> 4;
+            } else {
+                qs[j+000] = q[j] & 0xf;
+                qs[j+128] = q[j] >> 4;
+            }
         }
 
         pack_mxfp4_quants(&x[i * 8 + 0], qs, 0);
@@ -1190,7 +1202,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k)
 
     // Repack the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Unpack the scales
         const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size);
@@ -1239,7 +1251,7 @@ static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) {
 
     // Init the scales
     // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2)
-    // the last block is truncated and overriden by the scales.
+    // the last block is truncated and overridden by the scales.
     for (int i = 0; i < nb; i++) {
         // Unpack the scales
         x[i * 8 + 0].e = 0;
@@ -1753,26 +1765,12 @@ static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b)
 }
 
 static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) {
+    if (!opt_hostbuf) {
+        return ggml_backend_buffer_is_hexagon(b);
+    }
     return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
 }
 
-static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
-    if (x->ne[0] != y->ne[0]) {
-        return false;
-    }
-    if (x->ne[1] != y->ne[1]) {
-        return false;
-    }
-    if (x->ne[2] != y->ne[2]) {
-        return false;
-    }
-    if (x->ne[3] != y->ne[3]) {
-        return false;
-    }
-
-    return true;
-}
-
 static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
     const struct ggml_tensor * src0 = op->src[0];
     const struct ggml_tensor * src1 = op->src[1];
@@ -1804,43 +1802,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
     return opt_experimental;
 }
 
-static bool hex_supported_src0_type(ggml_type t) {
-    return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src1_type(ggml_type t) {
-    return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src2_type(ggml_type t) {
-    return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src1_type2(ggml_type t) {
-    return t == GGML_TYPE_F16;
-}
-
-static bool hex_supported_src1_type3(ggml_type t) {
-    return t == GGML_TYPE_I32;
-}
-
-static bool hex_supported_dst_type(ggml_type t) {
-    return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
-    // TODO: support broadcast for ne[2 and 3]
-    if (x->ne[0] != y->ne[0]) {
-        return false;
-    }
-    if (x->ne[2] != y->ne[2]) {
-        return false;
-    }
-    if (x->ne[3] != y->ne[3]) {
-        return false;
-    }
-    return true;
-}
 
 static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
@@ -1862,12 +1823,12 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
                 return false;
             }
 
-            if (src0->ne[1] > 16 * 1024) {
+            if (ggml_nrows(src0) > 16 * 1024) {
                 return false;  // typically the lm-head which would be too large for VTCM
             }
 
-            if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
-                return false;
+            if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) {
+                return false;  // no huge batches or broadcasting (for now)
             }
 
             // src0 (weights) must be repacked
@@ -1881,6 +1842,9 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
                 GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n");
                 return false;
             }
+            if (ggml_nrows(src1) > 1024) {
+                return false;  // no huge batches (for now)
+            }
             break;
 
         default:
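Switching the cap from `src0->ne[1]` to `ggml_nrows(src0)` makes it count rows across the batch dimensions as well. For reference, `ggml_nrows()` in stock ggml is equivalent to:

```cpp
#include <cstdint>

// ggml_nrows() is ne[1] * ne[2] * ne[3] (everything except the innermost dim),
// so a [4096, 32000, 1, 1] lm-head-style weight has 32000 rows and is rejected
// by the 16 * 1024 cap above, while a [4096, 4096, 1, 1] projection passes.
static inline int64_t nrows(const int64_t ne[4]) {
    return ne[1] * ne[2] * ne[3];
}
```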
@@ -1926,24 +1890,30 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
     const struct ggml_tensor * src1 = op->src[1];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
-        return false;
+    if (src0->type == GGML_TYPE_F32) {
+        if (src1->type != GGML_TYPE_F32) {
+            return false;
+        }
+        if (dst->type != GGML_TYPE_F32) {
+            return false;
+        }
     }
-    if (!hex_supported_src1_type(src1->type)) {
-        return false;
+    else if (src0->type == GGML_TYPE_F16) {
+        if (src1->type != GGML_TYPE_F16) {
+            return false;
+        }
+        if (dst->type != GGML_TYPE_F16) {
+            return false;
+        }
     }
-    if (!hex_supported_dst_type(dst->type)) {
-        return false;
-    }
-    if (!hex_supported_dims2(src0, dst)) {
-        return false;
-    }
-    if (!ggml_can_repeat(src1, src0)) {
+    else {
         return false;
     }
 
-    // TODO: add support for non-contigiuos tensors
-    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+    if (!ggml_are_same_shape(src0, dst)) {
+        return false;
+    }
+    if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {
         return false;
     }
 
@@ -1955,16 +1925,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
     const struct ggml_tensor * src1 = op->src[1];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_src1_type(src1->type)) {
+    if (src1->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dims2(src0, dst)) {
+    if (!ggml_are_same_shape(src0, dst)) {
         return false;
     }
 
@@ -1980,13 +1950,32 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
     const struct ggml_tensor * src0 = op->src[0];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dims2(src0, dst)) {
+    if (!ggml_are_same_shape(src0, dst)) {
+        return false;
+    }
+
+    // TODO: add support for non-contiguous tensors
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * dst  = op;
+
+    if (src0->type != GGML_TYPE_F32) {
+        return false;
+    }
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
 
@@ -2004,10 +1993,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
     const struct ggml_tensor * src1 = op->src[1];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
 
@@ -2016,10 +2005,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
     }
 
     if (src1) {
-        if (!hex_supported_src1_type(src1->type)) {
+        if (src1->type != GGML_TYPE_F32) {
             return false;
         }
-        if (!hex_supported_dims2(src0, src1)) {
+        if (!ggml_are_same_shape(src0, src1)) {
             return false;
         }
         if (!ggml_is_contiguous(src1)) {
@@ -2040,15 +2029,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
         return false;  // FIXME: add support for sinks
     }
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
 
     if (src1) {
-        if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
+        if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
             return false;
         }
         if (src0->ne[0] != src1->ne[0]) {
@@ -2118,6 +2107,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
     return true;
 }
 
+static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0]; // values
+    const struct ggml_tensor * dst  = op;         // indices
+
+    if (src0->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (dst->type != GGML_TYPE_I32) {
+        return false;
+    }
+
+    if (src0->ne[0] > (16*1024)) {
+        // reject tensors with huge rows for now
+        return false;
+    }
+
+    return true;
+}
+
 static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
     const int32_t * op_params = &op->op_params[0];
 
@@ -2135,17 +2144,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
     const struct ggml_tensor * src2 = op->src[2];
     const struct ggml_tensor * dst  = op;
 
-    if (!hex_supported_src0_type(src0->type)) {
+    if (src0->type != GGML_TYPE_F32) {
         return false;  // FIXME: add support for GGML_TYPE_F16 for src0
     }
-    if (!hex_supported_dst_type(dst->type)) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
-    if (!hex_supported_src1_type3(src1->type)) {
+    if (src1->type != GGML_TYPE_I32) {
         return false;
     }
     if (src2) {
-        if (!hex_supported_src2_type(src2->type)) {
+        if (src2->type != GGML_TYPE_F32) {
             return false;
         }
         int n_dims = op_params[1];
@@ -2168,6 +2177,44 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
     return true;
 }
 
+static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * dst  = op;
+
+    // Only support FP32 for now
+    if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    // Check IO tensor shapes and dims
+    if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) {
+        return false; // src0 should be effectively 3D
+    }
+
+    const int d_conv = src1->ne[0];
+    const int d_inner = src0->ne[1];
+    const int n_t = dst->ne[1];
+    const int n_s = dst->ne[2];
+
+    if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) {
+        return false;
+    }
+    if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) {
+        return false;
+    }
+    if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) {
+        return false;
+    }
+
+    // TODO: add support for non-contiguous tensors
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    return true;
+}
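As a concrete instance of the shape contract enforced above (the dimension values below are illustrative, roughly Mamba-style, and not taken from the patch):

```cpp
#include <cstdio>

// Illustrative SSM_CONV shapes matching the checks above:
// src0 holds a sliding window of d_conv-1+n_t steps per channel per sequence,
// src1 holds one d_conv-wide kernel per channel, dst has one row per token.
int main() {
    const int d_conv = 4, d_inner = 1536, n_t = 32, n_s = 1;  // assumed example values

    const int src0_ne[3] = { d_conv - 1 + n_t, d_inner, n_s };  // { 35, 1536, 1 }
    const int src1_ne[2] = { d_conv, d_inner };                 // { 4, 1536 }
    const int dst_ne[3]  = { d_inner, n_t, n_s };               // { 1536, 32, 1 }

    printf("src0 = [%d, %d, %d]\n", src0_ne[0], src0_ne[1], src0_ne[2]);
    printf("src1 = [%d, %d]\n",     src1_ne[0], src1_ne[1]);
    printf("dst  = [%d, %d, %d]\n", dst_ne[0],  dst_ne[1],  dst_ne[2]);
    return 0;
}
```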
+
 enum dspqbuf_type {
     DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0,
     DSPQBUF_TYPE_CPU_WRITE_DSP_READ,
@@ -2285,6 +2332,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
         case GGML_OP_SUB:
             req->op = HTP_OP_SUB;
             break;
+        case GGML_OP_DIV:
+            req->op = HTP_OP_DIV;
+            break;
         default:
             GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
             break;
@@ -2302,6 +2352,16 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
     return n_bufs;
 }
 
+static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_CPY;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     req->op = HTP_OP_GET_ROWS;
 
@@ -2313,6 +2373,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
     return n_bufs;
 }
 
+static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_ARGSORT;
+    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 template 
 static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     switch (t->op) {
@@ -2367,6 +2438,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
             supported = true;
             break;
 
+        case GGML_OP_SQR:
+            req->op   = HTP_OP_SQR;
+            supported = true;
+            break;
+
+        case GGML_OP_SQRT:
+            req->op   = HTP_OP_SQRT;
+            supported = true;
+            break;
+
         case GGML_OP_UNARY:
             if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
                 req->op   = HTP_OP_UNARY_SILU;
@@ -2384,6 +2465,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
             } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
                 req->op   = HTP_OP_GLU_SWIGLU_OAI;
                 supported = true;
+            } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
+                req->op   = HTP_OP_GLU_GEGLU;
+                supported = true;
             }
             break;
 
@@ -2408,6 +2492,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
     return n_bufs;
 }
 
+static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+    req->op = HTP_OP_SUM_ROWS;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
     req->op = HTP_OP_ROPE;
@@ -2436,6 +2531,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf
     return n_bufs;
 }
 
+static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_SSM_CONV;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
     auto sess = static_cast<ggml_hexagon_session *>(backend->context);
     return sess->name.c_str();
@@ -2448,12 +2554,12 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
 }
 
 static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
-    return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type));
+    return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
 }
 
 static inline bool is_compute_op(ggml_tensor *node)
 {
-    return !(ggml_op_is_empty(node->op) || ggml_is_empty(node));
+    return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
 }
 
 // scan the graph and figure out last compute op index
@@ -2475,7 +2581,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
 
     const int last = last_compute_op(graph);
 
-    const struct ggml_tensor * prev_quant_op = nullptr;  // prev executed op with quantizer
+    const struct ggml_tensor * prev_op = nullptr;  // prev executed op
 
     for (int i = 0; i < graph->n_nodes; ++i) {
         ggml_tensor * node = graph->nodes[i];
@@ -2487,10 +2593,12 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
         uint32_t flags = 0;
 
         // skip quantizer if src1 is reused
-        if (op_reuse_src1(node, prev_quant_op)) {
+        if (op_reuse_src1(node, prev_op)) {
             flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
         }
 
+        prev_op = node;
+
         // ask for early notification for the last Op
         if (i == last) {
             flags |= HTP_OPFLAGS_EARLY_WAKEUP;
@@ -2503,7 +2611,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 } else {
                     ggml_hexagon_dispatch_op>(sess, node, flags);
                 }
-                prev_quant_op = node;
                 break;
             case GGML_OP_MUL_MAT_ID:
                 if (ggml_is_quantized(node->src[0]->type)) {
@@ -2511,11 +2618,11 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 } else {
                     ggml_hexagon_dispatch_op>(sess, node, flags);
                 }
-                prev_quant_op = node;
                 break;
             case GGML_OP_MUL:
             case GGML_OP_ADD:
             case GGML_OP_SUB:
+            case GGML_OP_DIV:
                 ggml_hexagon_dispatch_op>(sess, node, flags);
                 break;
             case GGML_OP_ADD_ID:
@@ -2525,6 +2632,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             case GGML_OP_SCALE:
                 ggml_hexagon_dispatch_op(sess, node, flags);
                 break;
+            case GGML_OP_SQR:
+            case GGML_OP_SQRT:
+                ggml_hexagon_dispatch_op(sess, node, flags);
+                break;
+            case GGML_OP_SUM_ROWS:
+                ggml_hexagon_dispatch_op(sess, node, flags);
+                break;
             case GGML_OP_UNARY:
                 if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
                         (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
@@ -2533,7 +2647,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 break;
             case GGML_OP_GLU:
                 if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
-                        (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
+                        (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
+                        (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
                     ggml_hexagon_dispatch_op(sess, node, flags);
                 }
                 break;
@@ -2557,6 +2672,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_dispatch_op(sess, node, flags);
                 break;
 
+            case GGML_OP_CPY:
+                ggml_hexagon_dispatch_op(sess, node, flags);
+                break;
+
+            case GGML_OP_ARGSORT:
+                ggml_hexagon_dispatch_op(sess, node, flags);
+                break;
+
+            case GGML_OP_SSM_CONV:
+                ggml_hexagon_dispatch_op(sess, node, flags);
+                break;
+
             default:
                 GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
         }
@@ -2632,7 +2759,7 @@ static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<ggml_tensor *> & nodes)
+static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * dst  = op;
+
+    // for now we can do f32 -> f16 and f16 -> f32 (without reshaping)
+    if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
+    if ( dst->type != GGML_TYPE_F32 &&  dst->type != GGML_TYPE_F16) return false;
+
+    const bool sametype   = (src0->type == dst->type);
+    const bool transposed = ggml_is_transposed(src0) || ggml_is_transposed(dst);
+    const bool sameshape  = !transposed && ggml_are_same_shape(src0, dst);
+
+    // can handle any shape and any same-type (pretty slow if reshaping is required)
+    if (sametype) return true;
+
+    // cannot handle re-shaping and type conversion at the same time
+    if (!sameshape) return false;
+
+    return true;
+}
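The checks above boil down to: both endpoints must be F32 or F16; same-type copies are always accepted (including reshaping ones, which may be slow); and a type conversion is accepted only when no reshaping or transposition is involved. A small stand-alone model of that rule, for readability only (hypothetical types and helper, not the backend code):

```cpp
#include <cassert>

enum class t { f32, f16, other };

static bool cpy_supported(t src, t dst, bool same_shape, bool transposed) {
    if (src == t::other || dst == t::other) return false;
    if (src == dst) return true;        // same-type copy: any shape, even reshaped
    return same_shape && !transposed;   // type conversion only without reshaping
}

int main() {
    assert( cpy_supported(t::f32, t::f32, false, false));  // F32 -> F32 reshape: accepted
    assert( cpy_supported(t::f32, t::f16, true,  false));  // pure F32 -> F16 conversion: accepted
    assert(!cpy_supported(t::f16, t::f32, false, false));  // conversion + reshape: rejected
    return 0;
}
```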
+
 static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
     auto sess = static_cast<ggml_hexagon_session *>(dev->context);
 
@@ -2888,6 +3036,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
         case GGML_OP_MUL:
         case GGML_OP_ADD:
         case GGML_OP_SUB:
+        case GGML_OP_DIV:
             supp = ggml_hexagon_supported_binary(sess, op);
             break;
 
@@ -2900,6 +3049,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_unary(sess, op);
             break;
 
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+            supp = ggml_hexagon_supported_unary(sess, op);
+            break;
+
+        case GGML_OP_SUM_ROWS:
+            supp = ggml_hexagon_supported_sum_rows(sess, op);
+            break;
+
         case GGML_OP_SOFT_MAX:
             supp = ggml_hexagon_supported_softmax(sess, op);
             break;
@@ -2915,7 +3073,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
         case GGML_OP_GLU:
             {
                 const auto glu_op = ggml_get_glu_op(op);
-                if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
+                if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
                     supp = ggml_hexagon_supported_activations(sess, op);
                 }
                 break;
@@ -2936,6 +3094,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_get_rows(sess, op);
             break;
 
+        case GGML_OP_CPY:
+            supp = ggml_hexagon_supported_cpy(sess, op);
+            break;
+
+        case GGML_OP_ARGSORT:
+            supp = ggml_hexagon_supported_argsort(sess, op);
+            break;
+
+        case GGML_OP_SSM_CONV:
+            supp = ggml_hexagon_supported_ssm_conv(sess, op);
+            break;
+
         default:
             break;
     }
@@ -3010,10 +3180,12 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
         }
     }
 
+#if defined(__ANDROID__)
     if (opt_arch < 75) {
         opt_ndev = 1;
         GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
     }
+#endif
 
     GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
 
@@ -3061,7 +3233,7 @@ static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t
 }
 
 static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) {
-    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
+    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0 && opt_hostbuf) {
         ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type;
         return (void *) fct;
     }
@@ -3078,34 +3250,31 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
     static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
                   "please update hexagon_type to match ggml_type");
 
+    const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
     const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
     const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF");
+    const char * str_opmask  = getenv("GGML_HEXAGON_OPMASK");
+    const char * str_opsync  = getenv("GGML_HEXAGON_OPSYNC");
+    const char * str_profile = getenv("GGML_HEXAGON_PROFILE");
+    const char * str_etm     = getenv("GGML_HEXAGON_ETM");
+    const char * str_nhvx    = getenv("GGML_HEXAGON_NHVX");
+    const char * str_ndev    = getenv("GGML_HEXAGON_NDEV");
+    const char * str_arch    = getenv("GGML_HEXAGON_ARCH");
 
+    opt_experimental = str_experimental ? atoi(str_experimental) : 0;
     opt_verbose      = str_verbose ? atoi(str_verbose) : 0;
-    opt_profile      = getenv("GGML_HEXAGON_PROFILE") != nullptr;
-    opt_etm          = getenv("GGML_HEXAGON_ETM") != nullptr;
-    opt_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL") != nullptr;
+    opt_hostbuf      = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
+    opt_opmask       = str_opmask  ? strtoul(str_opmask, NULL, 0) : opt_opmask;
+    opt_opsync       = str_opsync  ? atoi(str_opsync)  : 0;
+    opt_profile      = str_profile ? atoi(str_profile) : 0;
+    opt_etm          = str_etm     ? atoi(str_etm) : 0;
+    opt_nhvx         = str_nhvx    ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
+    opt_ndev         = str_ndev    ? strtoul(str_ndev, NULL, 0) : opt_ndev;
 
-    const char * str_opmask = getenv("GGML_HEXAGON_OPMASK");
-    if (str_opmask != nullptr) {
-        opt_opmask = strtoul(str_opmask, NULL, 0);
-    }
-    opt_opsync = getenv("GGML_HEXAGON_OPSYNC") != nullptr;
-
-    const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
-    if (str_ndev) {
-        opt_ndev = strtoul(str_ndev, NULL, 0);
-        if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
-            opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
-        }
+    if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) {
+        opt_ndev = GGML_HEXAGON_MAX_SESSIONS;
     }
 
-    const char * str_nhvx = getenv("GGML_HEXAGON_NHVX");
-    if (str_nhvx) {
-        opt_nhvx = strtoul(str_nhvx, NULL, 0);
-    }
-
-    const char * str_arch = getenv("GGML_HEXAGON_ARCH");
     if (str_arch) {
         if (str_arch[0] == 'v') {
             str_arch++;
@@ -3139,6 +3308,11 @@ ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
         static std::mutex           mutex;
         std::lock_guard lock(mutex);
         if (!initialized) {
+            auto nErr = htpdrv_init();
+            if (nErr != AEE_SUCCESS) {
+                return NULL;
+            }
+
             ggml_hexagon_init(&reg);
         }
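One behavioural detail of the consolidated option parsing above: `GGML_HEXAGON_OPMASK`, `GGML_HEXAGON_NHVX` and `GGML_HEXAGON_NDEV` go through `strtoul(..., 0)`, so decimal, 0x-prefixed hex and 0-prefixed octal are all accepted, and NDEV is then clamped to `GGML_HEXAGON_MAX_SESSIONS`. A tiny sketch of the base-0 semantics (hypothetical stand-alone program, not backend code):

```cpp
#include <cstdlib>
#include <cstdio>

int main() {
    // strtoul with base 0 auto-detects the radix, matching the GGML_HEXAGON_OPMASK parsing.
    const char * inputs[] = { "7", "0x7", "010" };
    for (const char * s : inputs) {
        printf("%-4s -> %lu\n", s, strtoul(s, nullptr, 0));  // 7, 7, 8 (octal)
    }
    return 0;
}
```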
 
diff --git a/ggml/src/ggml-hexagon/htp-drv.cpp b/ggml/src/ggml-hexagon/htp-drv.cpp
new file mode 100644
index 00000000..4c376b5f
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp-drv.cpp
@@ -0,0 +1,418 @@
+// Hexagon driver interface shim: loads libcdsprpc at runtime and forwards the rpcmem / dspqueue / FastRPC calls to it
+
+#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
+#pragma clang diagnostic ignored "-Wmissing-prototypes"
+#pragma clang diagnostic ignored "-Wsign-compare"
+
+#include 
+#include 
+#include 
+#include 
+#ifdef _WIN32
+#   define WIN32_LEAN_AND_MEAN
+#   ifndef NOMINMAX
+#       define NOMINMAX
+#   endif
+#   include 
+#   include 
+#else
+#    include 
+#    include 
+#endif
+#include "ggml-impl.h"
+#include "htp-drv.h"
+#include "libdl.h"
+
+#include 
+
+//
+// Driver API types
+//
+
+typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);
+typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);
+typedef void   (*rpcmem_free_pfn_t)(void * po);
+typedef int    (*rpcmem_to_fd_pfn_t)(void * po);
+
+typedef AEEResult (*dspqueue_create_pfn_t)(int                 domain,
+                                           uint32_t            flags,
+                                           uint32_t            req_queue_size,
+                                           uint32_t            resp_queue_size,
+                                           dspqueue_callback_t packet_callback,
+                                           dspqueue_callback_t error_callback,
+                                           void *              callback_context,
+                                           dspqueue_t *        queue);
+typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);
+typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);
+typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,
+                                          uint32_t num_buffers,
+                                          struct dspqueue_buffer *buffers,
+                                          uint32_t message_length,
+                                          const uint8_t *message,
+                                          uint32_t timeout_us);
+typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,
+                                         uint32_t max_buffers, uint32_t *num_buffers,
+                                         struct dspqueue_buffer *buffers,
+                                         uint32_t max_message_length,
+                                         uint32_t *message_length, uint8_t *message,
+                                         uint32_t timeout_us);
+
+typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);
+typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);
+
+typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);
+typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);
+typedef int (*remote_handle64_close_pfn_t)(remote_handle64 h);
+typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);
+typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);
+typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);
+
+//
+// Driver API pfns
+//
+
+rpcmem_alloc_pfn_t  rpcmem_alloc_pfn  = nullptr;
+rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;
+rpcmem_free_pfn_t   rpcmem_free_pfn   = nullptr;
+rpcmem_to_fd_pfn_t  rpcmem_to_fd_pfn  = nullptr;
+
+fastrpc_mmap_pfn_t   fastrpc_mmap_pfn   = nullptr;
+fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;
+
+dspqueue_create_pfn_t dspqueue_create_pfn = nullptr;
+dspqueue_close_pfn_t  dspqueue_close_pfn  = nullptr;
+dspqueue_export_pfn_t dspqueue_export_pfn = nullptr;
+dspqueue_write_pfn_t  dspqueue_write_pfn  = nullptr;
+dspqueue_read_pfn_t   dspqueue_read_pfn   = nullptr;
+
+remote_handle64_open_pfn_t    remote_handle64_open_pfn    = nullptr;
+remote_handle64_invoke_pfn_t  remote_handle64_invoke_pfn  = nullptr;
+remote_handle64_close_pfn_t   remote_handle64_close_pfn   = nullptr;
+remote_handle_control_pfn_t   remote_handle_control_pfn   = nullptr;
+remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;
+remote_session_control_pfn_t  remote_session_control_pfn  = nullptr;
+
+//
+// Driver API
+//
+
+void * rpcmem_alloc(int heapid, uint32_t flags, int size) {
+    return rpcmem_alloc_pfn(heapid, flags, size);
+}
+
+void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {
+    if (rpcmem_alloc2_pfn) {
+        return rpcmem_alloc2_pfn(heapid, flags, size);
+    } else {
+        GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n");
+        return rpcmem_alloc_pfn(heapid, flags, size);
+    }
+}
+
+void rpcmem_free(void * po) {
+    return rpcmem_free_pfn(po);
+}
+
+int rpcmem_to_fd(void * po) {
+    return rpcmem_to_fd_pfn(po);
+}
+
+HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {
+    return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);
+}
+
+HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {
+    return fastrpc_munmap_pfn(domain, fd, addr, length);
+}
+
+AEEResult dspqueue_create(int                 domain,
+                          uint32_t            flags,
+                          uint32_t            req_queue_size,
+                          uint32_t            resp_queue_size,
+                          dspqueue_callback_t packet_callback,
+                          dspqueue_callback_t error_callback,
+                          void *              callback_context,
+                          dspqueue_t *        queue) {
+    return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,
+                               callback_context, queue);
+}
+
+AEEResult dspqueue_close(dspqueue_t queue) {
+    return dspqueue_close_pfn(queue);
+}
+
+AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {
+    return dspqueue_export_pfn(queue, queue_id);
+}
+
+AEEResult dspqueue_write(dspqueue_t               queue,
+                         uint32_t                 flags,
+                         uint32_t                 num_buffers,
+                         struct dspqueue_buffer * buffers,
+                         uint32_t                 message_length,
+                         const uint8_t *          message,
+                         uint32_t                 timeout_us) {
+    return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);
+}
+
+AEEResult dspqueue_read(dspqueue_t               queue,
+                        uint32_t *               flags,
+                        uint32_t                 max_buffers,
+                        uint32_t *               num_buffers,
+                        struct dspqueue_buffer * buffers,
+                        uint32_t                 max_message_length,
+                        uint32_t *               message_length,
+                        uint8_t *                message,
+                        uint32_t                 timeout_us) {
+    return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,
+                             message, timeout_us);
+}
+
+HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {
+    return remote_handle64_open_pfn(name, ph);
+}
+
+HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {
+    return remote_handle64_invoke_pfn(h, dwScalars, pra);
+}
+
+HTPDRV_API int remote_handle64_close(remote_handle64 h) {
+    return remote_handle64_close_pfn(h);
+}
+
+HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {
+    return remote_handle_control_pfn(req, data, datalen);
+}
+
+HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {
+    return remote_handle64_control_pfn(h, req, data, datalen);
+}
+
+HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {
+    return remote_session_control_pfn(req, data, datalen);
+}
+
+#ifdef _WIN32
+
+static std::string wstr_to_str(std::wstring_view wstr) {
+    std::string result;
+    if (wstr.empty()) {
+        return result;
+    }
+    auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
+                                            wstr.data(), (int) wstr.size(),
+                                            nullptr, 0, nullptr, nullptr);
+    if (bytes_needed == 0) {
+        GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
+        throw std::runtime_error("Invalid wstring input");
+    }
+
+    result.resize(bytes_needed, '\0');
+    int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
+                                            wstr.data(), (int) wstr.size(),
+                                            result.data(), bytes_needed,
+                                            nullptr, nullptr);
+    if (bytes_written == 0) {
+        GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
+        throw std::runtime_error("Wstring conversion failed");
+    }
+    return result;
+}
+
+static std::string get_driver_path() {
+    std::wstring serviceName = L"qcnspmcdm";
+    std::string result;
+
+    // Get a handle to the SCM database.
+    SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ);
+    if (nullptr == schSCManager) {
+        GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError());
+        return result;
+    }
+
+    // Get a handle to the service.
+    SC_HANDLE schService = OpenServiceW(schSCManager,           // SCM database
+                                        serviceName.c_str(),    // name of service
+                                        SERVICE_QUERY_CONFIG);  // need query config access
+
+    if (nullptr == schService) {
+        GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError());
+        CloseServiceHandle(schSCManager);
+        return result;
+    }
+
+    // Store the size of buffer used as an output.
+    DWORD bufferSize;
+    if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) &&
+        (GetLastError() != ERROR_INSUFFICIENT_BUFFER)) {
+        GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
+        CloseServiceHandle(schService);
+        CloseServiceHandle(schSCManager);
+        return result;
+    }
+    // Get the configuration of the service.
+    LPQUERY_SERVICE_CONFIGW serviceConfig =
+        static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize));
+    if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) {
+        fprintf(stderr, "ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
+        LocalFree(serviceConfig);
+        CloseServiceHandle(schService);
+        CloseServiceHandle(schSCManager);
+        return result;
+    }
+
+    // Read the driver file path and take its parent directory
+    std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName);
+    driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\"));
+
+    // Clean up resources
+    LocalFree(serviceConfig);
+    CloseServiceHandle(schService);
+    CloseServiceHandle(schSCManager);
+
+    // The driver path typically starts with an unresolved prefix, like:
+    // \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37
+    // "\SystemRoot" must be replaced with the actual Windows directory (e.g. C:\Windows)
+    const std::wstring systemRootPlaceholder = L"\\SystemRoot";
+    if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {
+        GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n");
+        return result;
+    }
+
+    // Replace \SystemRoot with an absolute path from system ENV windir
+    const std::wstring systemRootEnv = L"windir";
+
+    // Query the number of wide characters this variable requires
+    DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0);
+    if (numWords == 0) {
+        GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n");
+        return result;
+    }
+
+    // Query the actual system root name from environment variable
+    std::vector<wchar_t> systemRoot(numWords + 1);
+    numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1);
+    if (numWords == 0) {
+        GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n");
+        return result;
+    }
+    driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data()));
+
+    return wstr_to_str(driverPath);
+}
+
+#endif
+
+using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+
+int htpdrv_init() {
+    static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;
+    static bool initialized = false;
+#ifdef _WIN32
+    std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll";
+#else
+    std::string drv_path = "libcdsprpc.so";
+#endif
+    if (initialized) {
+        GGML_LOG_INFO("ggml-hex: Driver already loaded\n");
+        return AEE_SUCCESS;
+    }
+    GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str());
+
+    fs::path path{ drv_path.c_str() };
+    dl_handle_ptr handle { dl_load_library(path) };
+    if (!handle) {
+        GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error());
+        return AEE_EUNABLETOLOAD;
+    }
+
+#define dlsym(drv, type, pfn, symbol, ignore)                               \
+    do {                                                                    \
+        pfn = (type) dl_get_sym(drv, #symbol);                              \
+        if (!ignore && nullptr == pfn) {                                    \
+            GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol);      \
+            return AEE_EUNABLETOLOAD;                                       \
+        }                                                                   \
+    } while (0)
+
+    dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);
+    dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);
+    dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);
+    dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);
+    dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);
+    dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);
+    dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);
+    dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);
+    dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);
+    dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);
+    dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);
+    dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);
+    dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);
+    dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);
+    dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);
+    dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);
+    dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);
+
+    lib_cdsp_rpc_handle = std::move(handle);
+    initialized         = true;
+
+    return AEE_SUCCESS;
+}
+
+domain * get_domain(int domain_id) {
+    int i    = 0;
+    int size = sizeof(supported_domains) / sizeof(domain);
+
+    for (i = 0; i < size; i++) {
+        if (supported_domains[i].id == domain_id) {
+            return &supported_domains[i];
+        }
+    }
+
+    return NULL;
+}
+
+int get_hex_arch_ver(int domain, int * arch) {
+    if (!remote_handle_control_pfn) {
+        GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
+        return AEE_EUNSUPPORTEDAPI;
+    }
+
+    struct remote_dsp_capability arch_ver;
+    arch_ver.domain       = (uint32_t) domain;
+    arch_ver.attribute_ID = ARCH_VER;
+    arch_ver.capability   = (uint32_t) 0;
+
+    int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
+    if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
+        GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
+        return AEE_EUNSUPPORTEDAPI;
+    }
+
+    if (err != AEE_SUCCESS) {
+        GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
+        return err;
+    }
+
+    switch (arch_ver.capability & 0xff) {
+        case 0x68:
+            *arch = 68;
+            return 0;
+        case 0x69:
+            *arch = 69;
+            return 0;
+        case 0x73:
+            *arch = 73;
+            return 0;
+        case 0x75:
+            *arch = 75;
+            return 0;
+        case 0x79:
+            *arch = 79;
+            return 0;
+        case 0x81:
+            *arch = 81;
+            return 0;
+    }
+    return -1;
+}
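For reference, a minimal caller sketch (not part of the patch) showing how these entry points fit together; CDSP_DOMAIN_ID and AEE_SUCCESS are assumed to come from the Hexagon SDK remote headers, and probe_hexagon_arch is a hypothetical helper:

    #include "htp-drv.h"

    // Hypothetical probe: load libcdsprpc and report the Hexagon arch (e.g. 73, 75, 79).
    static int probe_hexagon_arch(void) {
        int err = htpdrv_init();                        // dlopen the driver, resolve symbols
        if (err != AEE_SUCCESS) {
            return -1;                                  // driver missing or unloadable
        }
        int arch = 0;
        err = get_hex_arch_ver(CDSP_DOMAIN_ID, &arch);  // FastRPC capability query
        return (err == 0) ? arch : -1;
    }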
diff --git a/ggml/src/ggml-hexagon/htp-drv.h b/ggml/src/ggml-hexagon/htp-drv.h
new file mode 100644
index 00000000..6eba7ba1
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp-drv.h
@@ -0,0 +1,121 @@
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef _WIN32
+#    pragma clang diagnostic ignored "-Wignored-attributes"
+#endif
+
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_WIN32) && !defined(__MINGW32__)
+#    ifdef GGML_BACKEND_BUILD
+#        define HTPDRV_API __declspec(dllexport) extern
+#    else
+#        define HTPDRV_API __declspec(dllimport) extern
+#    endif
+#else
+#    define HTPDRV_API __attribute__ ((visibility ("default"))) extern
+#endif
+
+/* Offset to differentiate HLOS and Hexagon error codes.
+   Stores the value of AEE_EOFFSET for Hexagon. */
+#ifndef DSP_OFFSET
+#    define DSP_OFFSET 0x80000400
+#endif
+
+/* Errno for connection reset by peer. */
+#ifndef ECONNRESET
+#    ifdef __hexagon__
+#        define ECONNRESET 104
+#    endif
+#endif
+
+/* Abstraction of different OS specific sleep APIs.
+   SLEEP accepts input in seconds. */
+#ifndef SLEEP
+#    ifdef __hexagon__
+#        define SLEEP(x)                      \
+            { /* Do nothing for simulator. */ \
+            }
+#    else
+#        ifdef _WIN32
+#            define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
+#        else
+#            define SLEEP(x) sleep(x)        /* sleep accepts input in seconds. */
+#        endif
+#    endif
+#endif
+
+/* Include windows specific header files. */
+#ifdef _WIN32
+#    include 
+#    include 
+#    define _CRT_SECURE_NO_WARNINGS         1
+#    define _WINSOCK_DEPRECATED_NO_WARNINGS 1
+#endif
+
+/* Includes and defines for all HLOS except windows */
+#if !defined(__hexagon__) && !defined(_WIN32)
+#    include "unistd.h"
+
+#    include 
+#endif
+
+/* Includes and defines for Hexagon and all HLOS except Windows. */
+#if !defined(_WIN32)
+/* Weak reference to remote symbol for compilation. */
+#    pragma weak remote_session_control
+#    pragma weak remote_handle_control
+#    pragma weak remote_handle64_control
+#    pragma weak fastrpc_mmap
+#    pragma weak fastrpc_munmap
+#    pragma weak rpcmem_alloc2
+#endif
+
+#if !defined(_WIN32)
+#    pragma weak remote_system_request
+#endif
+
+#ifdef _WIN32
+#     define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE
+#else
+#     define DSPQUEUE_TIMEOUT 1000000
+#endif
+
+/**
+ * htpdrv_init API: driver interface entry point
+ *
+ * @return      Returns an AEE error code as defined in the Hexagon SDK.
+ */
+HTPDRV_API int htpdrv_init(void);
+
+/**
+ * get_domain API: get domain struct from domain value.
+ *
+ * @param[in]  domain_id value of a domain
+ * @return     Returns domain struct of the domain if it is supported or else
+ *             returns NULL.
+ *
+ */
+HTPDRV_API domain * get_domain(int domain_id);
+
+/**
+ * get_hex_arch_ver API: query the Hexagon processor architecture version information
+ *
+ * @param[in]   domain value of a domain
+ * @param[out]  arch architecture version (73, 75, ...)
+ * @return      0 if query is successful.
+ *              non-zero on error; the return value identifies the error.
+ *
+ */
+HTPDRV_API int get_hex_arch_ver(int domain, int * arch);
+
+#ifdef __cplusplus
+}
+#endif
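As a note on the error-code convention declared above, a hedged helper sketch (name hypothetical, not part of the patch) that uses DSP_OFFSET to tell Hexagon-side errors from HLOS ones:

    /* Errors at or above DSP_OFFSET (AEE_EOFFSET on the Hexagon side) originated on the DSP;
       smaller values came from the host OS. The casts keep the comparison unsigned. */
    static inline int htpdrv_err_is_dsp(int err) {
        return (unsigned int) err >= (unsigned int) DSP_OFFSET;
    }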
diff --git a/ggml/src/ggml-hexagon/htp-utils.c b/ggml/src/ggml-hexagon/htp-utils.c
deleted file mode 100644
index 3f335bf7..00000000
--- a/ggml/src/ggml-hexagon/htp-utils.c
+++ /dev/null
@@ -1,454 +0,0 @@
-
-#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
-#pragma clang diagnostic ignored "-Wmissing-prototypes"
-#pragma clang diagnostic ignored "-Wsign-compare"
-
-#define GGML_COMMON_IMPL_C
-#include "ggml-backend-impl.h"
-#include "ggml-common.h"
-#include "ggml-hexagon.h"
-#include "ggml-impl.h"
-
-#include "htp-utils.h"
-
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-
-domain * get_domain(int domain_id) {
-    int i    = 0;
-    int size = sizeof(supported_domains) / sizeof(domain);
-
-    for (i = 0; i < size; i++) {
-        if (supported_domains[i].id == domain_id) {
-            return &supported_domains[i];
-        }
-    }
-
-    return NULL;
-}
-
-bool is_valid_domain_id(int domain_id, int compute_only) {
-    int i    = 0;
-    int size = sizeof(supported_domains) / sizeof(domain);
-
-    if (compute_only) {
-        return is_CDSP(domain_id);
-    }
-
-    for (i = 0; i < size; i++) {
-        if (supported_domains[i].id == domain_id) {
-            return true;
-        }
-    }
-
-    return false;
-}
-
-int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
-    int nErr    = AEE_SUCCESS;
-    int ss_info = 0;
-    if (domain_type != NULL) {
-        if (strcmp(domain_type, "LPASS") == 0) {
-            ss_info = FASTRPC_LPASS;
-        } else if (strcmp(domain_type, "HPASS") == 0) {
-            ss_info = FASTRPC_HPASS;
-        } else {
-            ss_info = FASTRPC_NSP;
-        }
-    }
-    system_req_payload req  = { 0 };
-    req.id                  = FASTRPC_GET_DOMAINS;
-    req.sys.domains         = NULL;
-    fastrpc_domain * domain = NULL;
-    if (ss_info != 0) {
-        req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
-    } else {
-        req.sys.flags = 0;
-    }
-#ifdef _WIN32
-    nErr = AEE_EUNSUPPORTED;
-    goto bail;
-#endif
-    if (remote_system_request) {
-        nErr = remote_system_request(&req);
-        if (nErr != AEE_SUCCESS) {
-            GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
-            goto bail;
-        }
-        // Allocate memory for domain-info array
-        req.sys.max_domains = req.sys.num_domains;
-        if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
-            nErr = AEE_ENOMEMORY;
-            GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
-            goto bail;
-        }
-
-        nErr = remote_system_request(&req);
-        if (nErr != AEE_SUCCESS) {
-            GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
-            goto bail;
-        }
-
-        for (int i = 0; i < req.sys.num_domains; i++) {
-            // Verify that only requested type domains were returned
-            domain = &req.sys.domains[i];
-            if (domain->type != ss_info && domain_type != NULL) {
-                nErr = -1;
-                GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
-                goto bail;
-            }
-        }
-        *domains_info = req.sys.domains;
-        *num_domains  = req.sys.num_domains;
-    } else {
-        nErr = AEE_EUNSUPPORTED;
-        goto bail;
-    }
-bail:
-    if (nErr && !req.sys.domains) {
-        free(req.sys.domains);
-    }
-    return nErr;
-}
-
-int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
-    int                              err  = 0;
-    remote_rpc_effective_domain_id_t sess = { 0 };
-
-    sess.domain_name     = domain_name;
-    sess.domain_name_len = strlen(domain_name);
-    sess.session_id      = session_id;
-
-    err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
-    if (err) {
-        GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
-               session_id);
-        return err;
-    }
-
-    *effec_domain_id = sess.effective_domain_id;
-    return err;
-}
-
-int get_dsp_support(int * domain) {
-    int nErr = AEE_SUCCESS;
-    *domain  = CDSP_DOMAIN_ID;  // DSP domain default value is CDSP_DOMAIN_ID
-
-    if (remote_handle_control) {
-        struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
-        nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
-        if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
-            GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
-            goto bail;
-        }
-
-        if (dsp_capability_domain.capability == 0) {
-            dsp_capability_domain.domain       = ADSP_DOMAIN_ID;  // Check for ADSP support.
-            dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
-            dsp_capability_domain.capability   = 0;
-            nErr                               = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
-                                                                       sizeof(struct remote_dsp_capability));
-            if (dsp_capability_domain.capability) {
-                *domain = ADSP_DOMAIN_ID;  // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
-            }
-        }
-
-        if (nErr != AEE_SUCCESS) {
-            GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
-            goto bail;
-        }
-    } else {
-        nErr = AEE_EUNSUPPORTEDAPI;
-        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
-    }
-
-bail:
-    return nErr;
-}
-
-int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
-    int nErr    = AEE_SUCCESS;
-    *capability = 0;
-
-    if (attr == VTCM_PAGE || attr == VTCM_COUNT) {
-    } else {
-        nErr = AEE_EBADPARM;
-        GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
-        goto bail;
-    }
-    if (remote_handle_control) {
-        if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
-            /*
-            * Query the DSP for VTCM information
-            * Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
-            */
-            struct remote_dsp_capability dsp_capability_vtcm_dsp;
-            dsp_capability_vtcm_dsp.domain       = (uint32_t) domain;
-            dsp_capability_vtcm_dsp.attribute_ID = attr;
-            dsp_capability_vtcm_dsp.capability   = (uint32_t) 0;
-            nErr                                 = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
-                                                                         sizeof(struct remote_dsp_capability));
-            if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
-                GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
-                GGML_LOG_ERROR("Running the usecase without checking the capability\n");
-                nErr = AEE_SUCCESS;
-                goto bail;
-            } else if (nErr == AEE_SUCCESS) {
-                *capability = dsp_capability_vtcm_dsp.capability;
-            } else {
-                GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
-                goto bail;
-            }
-        } else {
-            nErr = AEE_EUNSUPPORTED;
-            GGML_LOG_ERROR("Unsupported domain %d\n", domain);
-            goto bail;
-        }
-    } else {
-        nErr = AEE_EUNSUPPORTEDAPI;
-        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
-    }
-
-bail:
-    return nErr;
-}
-
-bool is_unsignedpd_supported(int domain_id) {
-    int nErr = AEE_SUCCESS;
-    if (remote_handle_control) {
-        struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
-        nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
-        if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
-            GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
-            return false;
-        }
-        if (nErr) {
-            GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
-            return false;
-        }
-        if (dsp_capability_domain.capability == 1) {
-            return true;
-        }
-    } else {
-        nErr = AEE_EUNSUPPORTEDAPI;
-        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
-        return false;
-    }
-    return false;
-}
-
-bool get_unsignedpd_support(void) {
-    return is_unsignedpd_supported(CDSP_DOMAIN_ID);
-}
-
-bool is_async_fastrpc_supported(int domain) {
-    int nErr = AEE_SUCCESS;
-    if (remote_handle_control) {
-        if (domain == CDSP_DOMAIN_ID) {
-            /*
-            * Query the DSP for ASYNC_FASTRPC_SUPPORT information
-            * Async fastrpc is supported only on CDSP
-            */
-            struct remote_dsp_capability dsp_capability_async_support;
-            dsp_capability_async_support.domain       = (uint32_t) domain;
-            dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
-            dsp_capability_async_support.capability   = (uint32_t) 0;
-            nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
-                                         sizeof(struct remote_dsp_capability));
-            if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
-                GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
-                GGML_LOG_ERROR("Running the usecase without checking the capability\n");
-                nErr = AEE_SUCCESS;
-                goto bail;
-            } else if (dsp_capability_async_support.capability == 1) {
-                return true;
-            }
-            if (nErr != AEE_SUCCESS) {
-                GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
-                goto bail;
-            }
-        } else {
-            nErr = AEE_EUNSUPPORTED;
-            GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
-            goto bail;
-        }
-    } else {
-        nErr = AEE_EUNSUPPORTEDAPI;
-        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
-    }
-
-bail:
-    return false;
-}
-
-bool is_status_notification_supported(int domain) {
-    int nErr = AEE_SUCCESS;
-
-    if (remote_handle_control) {
-        /*
-        * Query the DSP for STATUS_NOTIFICATION_SUPPORT information
-        * DSP User PD status notification Support
-        */
-        struct remote_dsp_capability dsp_capability_status_notification_support;
-        dsp_capability_status_notification_support.domain       = (uint32_t) domain;
-        dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
-        dsp_capability_status_notification_support.capability   = (uint32_t) 0;
-        nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
-                                     sizeof(struct remote_dsp_capability));
-        if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
-            GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
-            GGML_LOG_ERROR("Running the usecase without checking the capability\n");
-            nErr = AEE_SUCCESS;
-            goto bail;
-        } else if (dsp_capability_status_notification_support.capability == 1) {
-            return true;
-        }
-        if (nErr != AEE_SUCCESS) {
-            GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
-            goto bail;
-        }
-    } else {
-        nErr = AEE_EUNSUPPORTEDAPI;
-        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
-    }
-
-bail:
-    return false;
-}
-
-int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
-    int nErr    = AEE_SUCCESS;
-    *capability = 0;
-
-    if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
-        nErr = AEE_EBADPARM;
-        GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
-        goto bail;
-    }
-    if (remote_handle_control) {
-        if (domain == CDSP_DOMAIN_ID) {
-            /*
-            * Query the DSP for HMX SUPPORT information
-            * HMX is supported on CDSP only
-            */
-            struct remote_dsp_capability dsp_capability_hmx_dsp;
-            dsp_capability_hmx_dsp.domain       = (uint32_t) domain;
-            dsp_capability_hmx_dsp.attribute_ID = attr;
-            dsp_capability_hmx_dsp.capability   = (uint32_t) 0;
-            nErr                                = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
-                                                                        sizeof(struct remote_dsp_capability));
-            if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
-                GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
-                GGML_LOG_ERROR("Running the usecase without checking the capability\n");
-                nErr = AEE_SUCCESS;
-                goto bail;
-            } else if (nErr == AEE_SUCCESS) {
-                *capability = dsp_capability_hmx_dsp.capability;
-            } else {
-                GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
-                goto bail;
-            }
-        } else {
-            nErr = AEE_EUNSUPPORTED;
-            GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
-            goto bail;
-        }
-    } else {
-        nErr = AEE_EUNSUPPORTEDAPI;
-        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
-    }
-
-bail:
-    return nErr;
-}
-
-int get_hex_arch_ver(int domain, int * arch) {
-    if (!remote_handle_control) {
-        GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
-        return AEE_EUNSUPPORTEDAPI;
-    }
-
-    struct remote_dsp_capability arch_ver;
-    arch_ver.domain       = (uint32_t) domain;
-    arch_ver.attribute_ID = ARCH_VER;
-    arch_ver.capability   = (uint32_t) 0;
-
-    int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
-    if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
-        GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
-        return AEE_EUNSUPPORTEDAPI;
-    }
-
-    if (err != AEE_SUCCESS) {
-        GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
-        return err;
-    }
-
-    switch (arch_ver.capability & 0xff) {
-        case 0x68:
-            *arch = 68;
-            return 0;
-        case 0x69:
-            *arch = 69;
-            return 0;
-        case 0x73:
-            *arch = 73;
-            return 0;
-        case 0x75:
-            *arch = 75;
-            return 0;
-        case 0x79:
-            *arch = 79;
-            return 0;
-        case 0x81:
-            *arch = 81;
-            return 0;
-    }
-    return -1;
-}
-
-int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
-    int nErr    = AEE_SUCCESS;
-    *capability = 0;
-
-    if (remote_handle_control) {
-        if (domain == CDSP_DOMAIN_ID) {
-            /*
-            * Query the DSP for HVX SUPPORT information
-            * HVX is supported on CDSP only
-            */
-            struct remote_dsp_capability dsp_capability_hvx_dsp;
-            dsp_capability_hvx_dsp.domain       = (uint32_t) domain;
-            dsp_capability_hvx_dsp.attribute_ID = attr;
-            dsp_capability_hvx_dsp.capability   = (uint32_t) 0;
-            nErr                                = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
-                                                                        sizeof(struct remote_dsp_capability));
-            if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
-                GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
-                GGML_LOG_ERROR("Running the usecase without checking the capability\n");
-                nErr = AEE_SUCCESS;
-                goto bail;
-            } else if (nErr == AEE_SUCCESS) {
-                *capability = dsp_capability_hvx_dsp.capability;
-            } else {
-                GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
-                goto bail;
-            }
-        } else {
-            nErr = AEE_EUNSUPPORTED;
-            GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
-            goto bail;
-        }
-    } else {
-        nErr = AEE_EUNSUPPORTEDAPI;
-        GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
-    }
-
-bail:
-    return nErr;
-}
diff --git a/ggml/src/ggml-hexagon/htp-utils.h b/ggml/src/ggml-hexagon/htp-utils.h
deleted file mode 100644
index 7bbae3a0..00000000
--- a/ggml/src/ggml-hexagon/htp-utils.h
+++ /dev/null
@@ -1,221 +0,0 @@
-#ifndef HTP_UTILS_H
-#define HTP_UTILS_H
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#include 
-#include 
-#include 
-#include 
-#include 
-
-/* Offset to differentiate HLOS and Hexagon error codes.
-   Stores the value of AEE_EOFFSET for Hexagon. */
-#ifndef DSP_OFFSET
-#    define DSP_OFFSET 0x80000400
-#endif
-
-/* Errno for connection reset by peer. */
-#ifndef ECONNRESET
-#    ifdef __hexagon__
-#        define ECONNRESET 104
-#    endif
-#endif
-
-/* Abstraction of different OS specific sleep APIs.
-   SLEEP accepts input in seconds. */
-#ifndef SLEEP
-#    ifdef __hexagon__
-#        define SLEEP(x)                      \
-            { /* Do nothing for simulator. */ \
-            }
-#    else
-#        ifdef _WINDOWS
-#            define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
-#        else
-#            define SLEEP(x) sleep(x)        /* sleep accepts input in seconds. */
-#        endif
-#    endif
-#endif
-
-/* Include windows specific header files. */
-#ifdef _WINDOWS
-#    include 
-#    include 
-#    define _CRT_SECURE_NO_WARNINGS         1
-#    define _WINSOCK_DEPRECATED_NO_WARNINGS 1
-/* Including this file for custom implementation of getopt function. */
-#    include "getopt_custom.h"
-#endif
-
-/* Includes and defines for all HLOS except windows */
-#if !defined(__hexagon__) && !defined(_WINDOWS)
-#    include "unistd.h"
-
-#    include 
-#endif
-
-/* Includes and defines for Hexagon and all HLOS except Windows. */
-#if !defined(_WINDOWS)
-/* Weak reference to remote symbol for compilation. */
-#    pragma weak remote_session_control
-#    pragma weak remote_handle_control
-#    pragma weak remote_handle64_control
-#    pragma weak fastrpc_mmap
-#    pragma weak fastrpc_munmap
-#    pragma weak rpcmem_alloc2
-#endif
-
-#if !defined(_WINDOWS)
-#    pragma weak remote_system_request
-#endif
-/**
- * Wrapper for FastRPC Capability API: query DSP support.
- *
- * @param[out]  domain pointer to supported domain.
- * @return      0          if query is successful.
- *              non-zero   if error, return value points to the error.
- */
-int get_dsp_support(int * domain);
-
-/**
- * Wrapper for FastRPC Capability API: query VTCM information.
- *
- * @param[in]   domain value of domain in the queried.
- * @param[out]  capability capability value of the attribute queried.
- * @param[in]   attr value of the attribute to the queried.
- * @return      0          if query is successful.
- *              non-zero   if error, return value points to the error.
- */
-int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
-
-/**
- * Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
- *
- * @return      true          if unsigned pd is supported.
- *              false         if unsigned pd is not supported, capability query failed.
- */
-
-bool get_unsignedpd_support(void);
-
-/**
- * Wrapper for FastRPC Capability API: query unsigned pd support.
- *
- * @param[in]   domain value of domain in the queried.
- * @return      true          if unsigned pd is supported.
- *              false         if unsigned pd is not supported, capability query failed.
- */
-
-bool is_unsignedpd_supported(int domain_id);
-
-/**
- * is_valid_domain_id API: query a domain id is valid.
- *
- * @param[in]   domain value of domain in the queried.
- * @param[in]   compute_only value of domain is only compared with CDSP domains supported by the target when enabled.
- * @return      true          if value of domain is valid.
- *              false         if value of domain is not valid.
- */
-
-bool is_valid_domain_id(int domain_id, int compute_only);
-
-/**
- * get_domain API: get domain struct from domain value.
- *
- * @param[in]  domain value of a domain
- * @return     Returns domain struct of the domain if it is supported or else
- *             returns NULL.
- *
- */
-
-domain * get_domain(int domain_id);
-
-/**
- * get_domains_info API: get information for all the domains available on the device
- *
- * @param[in]  domain_type pointer to domain type
- * @param[in]  num_domains pointer to number of domains
- * @param[in]  domains_info pointer to save discovered domains information.
- * @return     0 if query is successful.
- *              non-zero if error, return value points to the error.
- *
- * It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application.
- *
- */
-
-int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
-
-/**
- * get_effective_domain_id API: get effective domain id for given session id
- *
- * @param[in]  domain_name pointer to domain name
- * @param[in]  session_id
- * @param[in]  effec_domain_id pointer to save obtained effective domain id.
- * @return     0 if query is successful.
- *              non-zero if error, return value points to the error.
- *
- */
-
-int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
-
-/**
- * is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not
- *
- * @param[in]  domain_id value of a domain
- * @return     Returns true or false stating support of Async FastRPC
- *
- */
-
-bool is_async_fastrpc_supported(int domain_id);
-
-/**
- * is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
- *
- * @param[in]  domain_id value of a domain
- * @return     Returns true or false stating status notification support information
- *
- */
-bool is_status_notification_supported(int domain_id);
-
-/**
- * get_hmx_support_info API: query the DSP for HMX SUPPORT information
- *
- * @param[in]   domain_id value of a domain
- * @param[out]  capability capability value of the attribute queried.
- * @param[in]   attr value of the attribute to the queried.
- * @return      0 if query is successful.
- *              non-zero if error, return value points to the error.
- *
- */
-int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
-
-/**
- * get_hex_arch_ver API: query the Hexagon processor architecture version information
- *
- * @param[in]   domain_id value of a domain
- * @param[out]  Arch version (73, 75, ...)
- * @return      0 if query is successful.
- *              non-zero if error, return value points to the error.
- *
- */
-int get_hex_arch_ver(int domain, int * arch);
-
-/**
- * get_hvx_support_info API: query the DSP for HVX SUPPORT information
- *
- * @param[in]   domain_id value of a domain
- * @param[out]  capability capability value of the attribute queried.
- * @param[in]   attr value of the attribute to the queried.
- * @return      0 if query is successful.
- *              non-zero if error, return value points to the error.
- *
- */
-int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif  //DSP_CAPABILITIES_UTILS_H
diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
index 6a34a215..02d07a50 100644
--- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt
+++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
 include_directories(
     ${HEXAGON_SDK_ROOT}/incs
     ${HEXAGON_SDK_ROOT}/incs/stddef
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../../include
     ${CMAKE_CURRENT_SOURCE_DIR}/../..
     ${CMAKE_CURRENT_SOURCE_DIR}/..
     ${CMAKE_CURRENT_SOURCE_DIR}
@@ -17,24 +18,25 @@ add_library(${HTP_LIB} SHARED
     main.c
     htp_iface_skel.c
     worker-pool.c
-    htp-dma.c
-    hvx-sigmoid.c
-    hvx-inverse.c
-    hvx-exp.c
-    hvx-utils.c
+    hex-dma.c
     matmul-ops.c
     binary-ops.c
     unary-ops.c
+    sum-rows-ops.c
     softmax-ops.c
     act-ops.c
     rope-ops.c
     flash-attn-ops.c
     set-rows-ops.c
     get-rows-ops.c
+    cpy-ops.c
+    argsort-ops.c
+    ssm-conv.c
 )
 
 target_compile_definitions(${HTP_LIB} PRIVATE
     $,HTP_DEBUG=1,NDEBUG=1>
+    $,FARF_HIGH=1,>
     FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
 
 build_idl(htp_iface.idl ${HTP_LIB})
diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c
index 88bd2ddc..d8b92498 100644
--- a/ggml/src/ggml-hexagon/htp/act-ops.c
+++ b/ggml/src/ggml-hexagon/htp/act-ops.c
@@ -2,27 +2,20 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
+
 #include 
-#include 
 #include 
 
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
-#include "htp-dma.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
 
 #define htp_act_preamble3              \
     const uint32_t ne00 = src0->ne[0]; \
@@ -76,27 +69,45 @@
     const uint32_t nb2 = dst->nb[2];   \
     const uint32_t nb3 = dst->nb[3];
 
-static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
-                                       const struct htp_tensor * src1,
-                                       struct htp_tensor *       dst,
-                                       const int32_t *           op_params,
-                                       struct htp_spad *         src0_spad,
-                                       struct htp_spad *         src1_spad,
-                                       struct htp_spad *         dst_spad,
-                                       uint32_t                  nth,
-                                       uint32_t                  ith,
-                                       uint32_t                  src0_nrows_per_thread,
-                                       dma_queue *               dma_queue) {
+struct htp_act_context {
+    struct htp_ops_context *  octx;
+
+    // Precomputed values
+    const uint8_t *           data_src0;
+    const uint8_t *           data_src1;
+    uint8_t *                 data_dst;
+
+    size_t                    src0_row_size;
+    size_t                    src1_row_size;
+    size_t                    dst_row_size;
+
+    size_t                    src0_row_size_aligned;
+    size_t                    src1_row_size_aligned;
+    size_t                    dst_row_size_aligned;
+
+    size_t                    src0_spad_half_size;
+    size_t                    src1_spad_half_size;
+    size_t                    dst_spad_half_size;
+
+    uint32_t                  block;
+    uint32_t                  src0_nrows;
+    uint32_t                  src0_nrows_per_thread;
+    int                       nc;
+};
+
+static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * src1 = &actx->octx->src1;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble3;
 
-    size_t src0_row_size = nb01;
-    size_t src1_row_size = nb11;
-    size_t dst_row_size  = nb1;
-
-
-
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    size_t src0_row_size = actx->src0_row_size;
+    size_t src1_row_size = actx->src1_row_size;
+    size_t dst_row_size  = actx->dst_row_size;
 
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
 
@@ -108,43 +119,34 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+    const uint8_t * restrict data_src0 = actx->data_src0;
+    const uint8_t * restrict data_src1 = actx->data_src1;
+    uint8_t * restrict data_dst        = actx->data_dst;
 
-    const bool src1_valid = src1->ne[0];
-    const int  nc         = (src1_valid) ? ne00 : ne00 / 2;
-    if (!src1_valid) {
-        const int32_t swapped = op_params[1];
-        data_src1             = data_src0;
-        src1_row_size         = src0_row_size;
+    const int  nc = actx->nc;
 
-        const size_t nc_in_bytes = nc * SIZEOF_FP32;
-        data_src0 += swapped ? nc_in_bytes : 0;
-        data_src1 += swapped ? 0 : nc_in_bytes;
-    }
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
-    const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
-    const size_t dst_row_size_aligned  = htp_round_up(dst_row_size, VLEN);
+    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
 
-    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
-    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_spad->size_per_thread);
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t src1_spad_half_size = actx->src1_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread / 2;
-
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned;  // How many rows can we process in one block
+    const int BLOCK = actx->block;
     if (BLOCK == 0) {
         FARF(ERROR,
              "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
-             src0_spad->size_per_thread, src0_row_size_aligned);
+             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
 
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
         const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -175,9 +177,9 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
             float *       dst_spad_ptr  = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
 
             //swiglu(x) = x1 * sigmoid(x0)
-            hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
-            hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
-                                (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
+            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc);
+            hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
+                                (const uint8_t *) src1_spad_ptr, nc);
         }
 
         dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
@@ -203,27 +205,22 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
-                                           const struct htp_tensor * src1,
-                                           struct htp_tensor *       dst,
-                                           const int32_t *           op_params,
-                                           struct htp_spad *         src0_spad,
-                                           struct htp_spad *         src1_spad,
-                                           struct htp_spad *         dst_spad,
-                                           uint32_t                  nth,
-                                           uint32_t                  ith,
-                                           uint32_t                  src0_nrows_per_thread,
-                                           dma_queue *               dma_queue) {
+static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * src1 = &actx->octx->src1;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble3;
 
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    size_t src0_row_size = nb01;
-    size_t src1_row_size = nb11;
-    size_t dst_row_size  = nb1;
+    size_t src0_row_size = actx->src0_row_size;
+    size_t src1_row_size = actx->src1_row_size;
+    size_t dst_row_size  = actx->dst_row_size;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -233,45 +230,36 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
         return;
     }
 
-    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+    const uint8_t * restrict data_src0 = actx->data_src0;
+    const uint8_t * restrict data_src1 = actx->data_src1;
+    uint8_t * restrict data_dst        = actx->data_dst;
 
-    const bool src1_valid = src1->ne[0];
-    const int  nc         = (src1_valid) ? ne00 : ne00 / 2;
-    if (!src1_valid) {
-        const int32_t swapped = op_params[1];
-        data_src1             = data_src0;
-        src1_row_size         = src0_row_size;
+    const int nc = actx->nc;
 
-        const size_t nc_in_bytes = nc * SIZEOF_FP32;
-        data_src0 += swapped ? nc_in_bytes : 0;
-        data_src1 += swapped ? 0 : nc_in_bytes;
-    }
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
-    const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
-    const size_t dst_row_size_aligned  = htp_round_up(dst_row_size, VLEN);
+    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
 
-    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
-    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_spad->size_per_thread);
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t src1_spad_half_size = actx->src1_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread / 2;
-
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned;  // How many rows can we process in one block
+    const int BLOCK = actx->block;
     if (BLOCK == 0) {
         FARF(ERROR,
              "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
              "%zu\n",
-             src0_spad->size_per_thread, src0_row_size_aligned);
+             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
-    const float alpha = ((const float *) (op_params))[2];
-    const float limit = ((const float *) (op_params))[3];
+    const float alpha = ((const float *) (actx->octx->op_params))[2];
+    const float limit = ((const float *) (actx->octx->op_params))[3];
+
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
 
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
@@ -304,18 +292,18 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
             float *       dst_spad_ptr  = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
 
             // x (src0_spad_data) = std::min(src0_p[k], limit);
-            hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc);
+            hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc);
             // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
-            hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc);
+            hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc);
             // y (src1_spad_data)  = y1 + 1.f
-            hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc);
+            hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc);
             // x1 (dst_spad_data) = alpha * (x)
-            hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc);
+            hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc);
             // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))
-            hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
+            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);
             // out = x * sigmoid(alpha * x) * (y + 1.f)
-            hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
-                                (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
+            hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
+                                (const uint8_t *) src1_spad_ptr, nc);
         }
 
         dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
@@ -342,26 +330,22 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
 }
 
 
-static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
-                                       struct htp_tensor *       dst,
-                                       const int32_t *           op_params,
-                                       struct htp_spad *         src0_spad,
-                                       struct htp_spad *         dst_spad,
-                                       uint32_t                  nth,
-                                       uint32_t                  ith,
-                                       uint32_t                  src0_nrows_per_thread,
-                                       dma_queue *               dma_queue) {
+static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble2;
 
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const size_t src0_row_size = nb01;
-    const size_t dst_row_size  = nb1;
-    const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
-    const size_t dst_row_size_aligned  = htp_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size = actx->src0_row_size;
+    const size_t dst_row_size  = actx->dst_row_size;
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -371,25 +355,29 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
         return;
     }
 
-    const uint8_t * data_src0 = (const uint8_t *) src0->data;
-    uint8_t * data_dst        = (uint8_t *) dst->data;
+    const uint8_t * data_src0 = actx->data_src0;
+    uint8_t * data_dst        = actx->data_dst;
 
-    uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * dst_spad_data  = dst_spad->data  + (ith * dst_spad->size_per_thread);
+    // nc/ne0 matches.
+    const int ne0_val = actx->nc; // == dst->ne[0]
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread  / 2;
+    uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_data  = actx->octx->dst_spad.data  + (ith * actx->octx->dst_spad.size_per_thread);
+
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
 
     // In gelu = x*sigmoid(x*1.702)
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+    const int BLOCK = actx->block;
 
     if (BLOCK == 0) {
         FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
-                src0_spad->size_per_thread, src0_row_size_aligned);
+                actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
 
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
         const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -415,9 +403,9 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
             float* dst_spad_ptr        = dst_spad  + ib * (dst_row_size_aligned  / sizeof(float));
 
             // gelu = x * sigmoid(1.702 * x) // current implementation
-            hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, (float) 1.702, (uint8_t *) dst_spad_ptr, ne0);
-            hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
-            hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
+            hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
+            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
+            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
         }
 
         dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -442,34 +430,23 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
          ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
-                               octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
 
-
-
-static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
-                                       struct htp_tensor *       dst,
-                                       const int32_t *           op_params,
-                                       struct htp_spad *         src0_spad,
-                                       struct htp_spad *         dst_spad,
-                                       uint32_t                  nth,
-                                       uint32_t                  ith,
-                                       uint32_t                  src0_nrows_per_thread,
-                                       dma_queue *               dma_queue) {
+static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * dst  = &actx->octx->dst;
     htp_act_preamble2;
 
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    const size_t src0_row_size = nb01;
-    const size_t dst_row_size  = nb1;
-    const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
-    const size_t dst_row_size_aligned  = htp_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size = actx->src0_row_size;
+    const size_t dst_row_size  = actx->dst_row_size;
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -479,24 +456,27 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
         return;
     }
 
-    const uint8_t * data_src0 = (const uint8_t *) src0->data;
-    uint8_t * data_dst        = (uint8_t *) dst->data;
+    const uint8_t * data_src0 = actx->data_src0;
+    uint8_t * data_dst        = actx->data_dst;
 
-    uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
-    uint8_t * dst_spad_data  = dst_spad->data  + (ith * dst_spad->size_per_thread);
+    const int ne0_val = actx->nc; // == dst->ne[0]
 
-    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
-    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
-    size_t dst_spad_half_size  = dst_spad->size_per_thread  / 2;
+    uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_data  = actx->octx->dst_spad.data  + (ith * actx->octx->dst_spad.size_per_thread);
 
-    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
+
+    const int BLOCK = actx->block;
 
     if (BLOCK == 0) {
         FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
-                src0_spad->size_per_thread, src0_row_size_aligned);
+                actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
         return;
     }
 
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
     // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
     for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
         const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
@@ -522,8 +502,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
             float* dst_spad_ptr        = dst_spad  + ib * (dst_row_size_aligned  / sizeof(float));
 
             // silu = x * sigmoid(x)
-            hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
-            hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
+            hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
+            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
         }
 
         dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -548,27 +528,130 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
          ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
-                               octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+static const float GELU_COEF_A     = 0.044715f;
+static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
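+// GEGLU: geglu(x, g) = gelu(x) * g, with gelu evaluated via the tanh approximation that
+// uses the constants above; rows are streamed through VTCM with double-buffered DMA.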
+static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_act_context * actx = (struct htp_act_context *) data;
+    const struct htp_tensor * src0 = &actx->octx->src0;
+    const struct htp_tensor * src1 = &actx->octx->src1;
+    const struct htp_tensor * dst  = &actx->octx->dst;
+    htp_act_preamble3;
+
+    size_t src0_row_size = actx->src0_row_size;
+    size_t src1_row_size = actx->src1_row_size;
+    size_t dst_row_size  = actx->dst_row_size;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t src0_nrows = actx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    const uint8_t * restrict data_src0 = actx->data_src0;
+    const uint8_t * restrict data_src1 = actx->data_src1;
+    uint8_t * restrict data_dst        = actx->data_dst;
+
+    const int nc = actx->nc;
+
+    const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+    const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+    const size_t dst_row_size_aligned  = actx->dst_row_size_aligned;
+
+    uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+    uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+    uint8_t * restrict dst_spad_data  = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
+
+    size_t src0_spad_half_size = actx->src0_spad_half_size;
+    size_t src1_spad_half_size = actx->src1_spad_half_size;
+    size_t dst_spad_half_size  = actx->dst_spad_half_size;
+
+    const int BLOCK = actx->block;
+    if (BLOCK == 0) {
+        FARF(ERROR,
+             "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+             actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
+        return;
+    }
+
+    dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
+    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+            dst_row_size, dst_row_size_aligned, 0);
+
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+            src0_row_size_aligned, src0_row_size, block_size);
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
+            src1_row_size_aligned, src1_row_size, block_size);
+    }
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;
+        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+        for (uint32_t ib = 0; ib < block_size; ib++) {
+            const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
+            const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
+            uint8_t *       dst_spad_ptr  = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
+
+            // geglu tanh implementation
+            // geglu(x, g) = gelu(x) * g
+            // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc);                       // res = x*x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc);   // res = res * GELU_COEF_A
+            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc);          // res = res + 1.0f
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc);       // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = res * SQRT_2_OVER_PI
+            hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);         // res = tanh(res)
+            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc);           // res = res + 1.0f
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc);       // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc);          // res = res * 0.5f
+            hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc);       // res = res * g
+        }
+
+        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
+                                   dst_row_size_aligned, block_size);
+
+        // prefetch N+2 loop iteration if any
+        const uint32_t pref_block = (ir + BLOCK * 2);
+        if (pref_block < src0_end_row) {
+            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+                                       src0_row_size_aligned, src0_row_size, pref_block_size);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
+                                       src1_row_size_aligned, src1_row_size, pref_block_size);
+        }
+    }
+
+    dma_queue_flush(dma_queue);
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
-                               &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-    glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
-                                   &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-static int execute_op_activations_fp32(struct htp_ops_context * octx) {
-    int err = HTP_STATUS_OK;
-
+static int execute_op_activations_f32(struct htp_ops_context * octx) {
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
     struct htp_tensor *       dst  = &octx->dst;
@@ -583,30 +666,35 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
 
     switch (octx->op) {
         case HTP_OP_UNARY_SILU:
-            act_op_func = unary_silu_fp32;
+            act_op_func = (worker_callback_t)unary_silu_f32_per_thread;
             op_type     = "silu-f32";
             break;
 
         case HTP_OP_GLU_SWIGLU:
-            act_op_func = glu_swiglu_fp32;
+            act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;
             op_type     = "swiglu-f32";
             break;
 
         case HTP_OP_GLU_SWIGLU_OAI:
-            act_op_func = glu_swiglu_oai_fp32;
+            act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;
             op_type     = "swiglu-oai-f32";
             break;
         case HTP_OP_UNARY_GELU:
-            act_op_func = unary_gelu_fp32;
+            act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;
             op_type     = "gelu-f32";
             break;
+
+        case HTP_OP_GLU_GEGLU:
+            act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;
+            op_type     = "geglu-f32";
+            break;
         default:
             FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
             return HTP_STATUS_NO_SUPPORT;
     }
 
-    const uint32_t n_threads  = octx->n_threads;
     const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);
 
     size_t src0_row_size = src0->nb[1];
     size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
@@ -617,9 +705,9 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
         src1_row_size = src0_row_size;
     }
 
-    const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
-    const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN);
-    const size_t dst_row_size_aligned  = htp_round_up(dst_row_size, VLEN);
+    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+    const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
     // VTCM scratchpads for all tensors
     // N rows per thread, padded to HVX vector size
 
@@ -656,13 +744,56 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
              octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
     }
 
-    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs = MIN(n_threads, src0_nrows);
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
+    if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        return HTP_STATUS_OK;
     }
 
-    return err;
+    // Prepare context
+    struct htp_act_context actx;
+    actx.octx = octx;
+
+    actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
+
+    actx.src0_row_size = src0_row_size;
+    actx.src1_row_size = src1_row_size;
+    actx.dst_row_size  = dst_row_size;
+
+    actx.src0_row_size_aligned = src0_row_size_aligned;
+    actx.src1_row_size_aligned = src1_row_size_aligned;
+    actx.dst_row_size_aligned  = dst_row_size_aligned;
+
+    actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
+    actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
+    actx.dst_spad_half_size  = octx->dst_spad.size_per_thread / 2;
+
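+    // rows per block: how many src0 rows fit in one half of the per-thread ping-pong scratchpad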
+    actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
+    actx.src0_nrows = src0_nrows;
+
+    actx.nc = dst->ne[0];
+
+    // Pointers and GLU logic
+    const uint8_t * data_src0 = (const uint8_t *) src0->data;
+    const uint8_t * data_src1 = (const uint8_t *) src1->data;
+
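+    // Fused GLU input: when src1 is absent, each src0 row holds both halves back to back;
+    // op_params[1] (swapped) selects which half each pointer starts at, so one of the two
+    // pointers is offset by nc floats.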
+    if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
+         const int32_t swapped = octx->op_params[1];
+         data_src1 = data_src0;
+         actx.src1_row_size = actx.src0_row_size;
+
+         size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
+         if (swapped) {
+             data_src0 += nc_in_bytes;
+         } else {
+             data_src1 += nc_in_bytes;
+         }
+    }
+
+    actx.data_src0 = data_src0;
+    actx.data_src1 = data_src1;
+    actx.data_dst  = (uint8_t *) dst->data;
+
+    worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads);
+    return HTP_STATUS_OK;
 }
 
 int op_activations(struct htp_ops_context * octx) {
@@ -670,7 +801,7 @@ int op_activations(struct htp_ops_context * octx) {
 
     switch (octx->src0.type) {
         case HTP_TYPE_F32:
-            err = execute_op_activations_fp32(octx);
+            err = execute_op_activations_f32(octx);
             break;
 
         default:
diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c
new file mode 100644
index 00000000..170220e8
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c
@@ -0,0 +1,281 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "ggml.h"
+
+#include "hvx-utils.h"
+#include "hex-dma.h"
+
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+struct htp_argsort_context {
+    struct htp_ops_context * octx;
+    uint32_t                 nrows_per_thread;
+};
+
+static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
+{
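+    // A 128-byte HVX vector holds 32 f32 lanes; the reduction counts the lanes where x > y,
+    // so the result is true only when every lane satisfies the comparison.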
+    const HVX_Vector one  = Q6_V_vsplat_R(1);
+    const HVX_Vector zero = Q6_V_vzero();
+
+    HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
+    HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
+    HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
+    return hvx_vec_get_i32(sum) == 32;
+}
+
+// Sorts values and mirrors swaps to indices.
+static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
+    if (left >= right) return;
+
+    int pivot_idx = (left + right) / 2;
+    float pivot = values[pivot_idx];
+    int i = left;
+    int j = right;
+
+    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
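+    // Hoare-style partition: the vector scans skip whole 32-element blocks that are entirely
+    // on the correct side of the pivot, and the scalar checks handle the boundary elements.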
+    while (i <= j) {
+        // Vectorized scan for i
+        while (i <= j) {
+            // Check if we have at least one full vector
+            if (i + 32 <= j) {
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
+                if (all_greater_f32(pivot_vec, vals_vec)) {
+                    // If all elements are < pivot, we can skip this whole block
+                    i += 32;
+                    continue;
+                }
+            }
+
+            // Scalar fallback / cleanup
+            if (values[i] < pivot) {
+                i++;
+            } else {
+                break;
+            }
+        }
+
+        // Vectorized scan for j
+        while (i <= j) {
+            if (j - 32 >= i) {
+                // Load 32 elements ending at j.
+                // Since we want `values[j] > pivot`, let's load from j-31 to j.
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
+                if (all_greater_f32(vals_vec, pivot_vec)) {
+                    j -= 32;
+                    continue;
+                }
+            }
+
+            if (values[j] > pivot) {
+                j--;
+            } else {
+                break;
+            }
+        }
+
+        if (i <= j) {
+            float tmp_val = values[i];
+            values[i] = values[j];
+            values[j] = tmp_val;
+
+            int32_t tmp_idx = indices[i];
+            indices[i] = indices[j];
+            indices[j] = tmp_idx;
+            i++;
+            j--;
+        }
+    }
+
+    if (left < j) quicksort_values_indices_asc(values, indices, left, j);
+    if (i < right) quicksort_values_indices_asc(values, indices, i, right);
+}
+
+static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
+    if (left >= right) return;
+
+    int pivot_idx = (left + right) / 2;
+    float pivot = values[pivot_idx];
+    int i = left;
+    int j = right;
+
+    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
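+    // Same hybrid vector/scalar partition as the ascending variant, with the comparisons flipped.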
+
+    while (i <= j) {
+        // Vectorized scan for i (values[i] > pivot)
+        while (i <= j) {
+            if (i + 32 <= j) {
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
+                if (all_greater_f32(vals_vec, pivot_vec)) {
+                    i += 32;
+                    continue;
+                }
+            }
+
+            if (values[i] > pivot) {
+                i++;
+            } else {
+                break;
+            }
+        }
+
+        // Vectorized scan for j (values[j] < pivot)
+        while (i <= j) {
+            if (j - 32 >= i) {
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
+                if (all_greater_f32(pivot_vec, vals_vec)) {
+                    j -= 32;
+                    continue;
+                }
+            }
+
+            if (values[j] < pivot) {
+                j--;
+            } else {
+                break;
+            }
+        }
+
+        if (i <= j) {
+            float tmp_val = values[i];
+            values[i] = values[j];
+            values[j] = tmp_val;
+
+            int32_t tmp_idx = indices[i];
+            indices[i] = indices[j];
+            indices[j] = tmp_idx;
+            i++;
+            j--;
+        }
+    }
+
+    if (left < j) quicksort_values_indices_desc(values, indices, left, j);
+    if (i < right) quicksort_values_indices_desc(values, indices, i, right);
+}
+
+static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
+    struct htp_ops_context * octx = actx->octx;
+
+    // Unpack context
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * dst = &octx->dst;
+
+    // Scratchpad memory
+    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
+
+    // Dimensions
+    uint32_t ne00 = src0->ne[0];
+    uint32_t ne01 = src0->ne[1];
+    uint32_t ne02 = src0->ne[2];
+    uint32_t ne03 = src0->ne[3];
+
+    uint32_t nb01 = src0->nb[1];
+    //uint32_t nb02 = src0->nb[2];
+    //uint32_t nb03 = src0->nb[3];
+
+    uint32_t nb1 = dst->nb[1];
+    //uint32_t nb2 = dst->nb[2];
+    //uint32_t nb3 = dst->nb[3];
+
+    // Sort order
+    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
+
+    // Rows to process
+    uint32_t total_rows = ne01 * ne02 * ne03;
+    uint32_t rows_per_thread = actx->nrows_per_thread;
+    uint32_t start_row = rows_per_thread * i;
+    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
+
+    // Scratchpad layout:
+    // We need space for one row of float data (values) and one row of int32 indices.
+    // values: ne00 * sizeof(float)
+    // indices: ne00 * sizeof(int32_t)
+    // Padded to 128 bytes.
+
+    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
+    float * values_buf = (float *) spad;
+    int32_t * indices_buf = (int32_t *) (spad + values_size);
+
+    for (uint32_t r = start_row; r < end_row; r++) {
+        uint32_t src_offset = r * nb01;
+        uint32_t dst_offset = r * nb1;
+
+        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
+        uint8_t * dst_ptr = (uint8_t *) dst->data  + dst_offset;
+
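+        // Prefetch the source row into L2, then stage it into the VTCM scratch buffer for sorting.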
+        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
+        hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
+
+        // Initialize indices
+        for (uint32_t j = 0; j < ne00; j++) {
+            indices_buf[j] = j;
+        }
+
+        // Sort values and mirror swaps to indices
+        if (order == GGML_SORT_ORDER_ASC) {
+            quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
+        } else {
+            quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
+        }
+
+        // Copy indices back to DDR
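+        // (indices are 4-byte int32 values, so the f32 copy helper moves them verbatim)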
+        hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
+    }
+}
+
+int op_argsort(struct htp_ops_context * octx) {
+    // Check supported types
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
+    const uint32_t n_threads = MIN(total_rows, octx->n_threads);
+
+    // Allocate scratchpad
+    // We need 1 row of float + 1 row of int32 per thread.
+    uint32_t ne00 = octx->src0.ne[0];
+    size_t values_size  = hex_round_up(ne00 * sizeof(float), 128);
+    size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
+    size_t spad_per_thread = values_size + indices_size;
+
+    // Make sure we round up to 256 for alignment requirements
+    spad_per_thread = hex_round_up(spad_per_thread, 256);
+
+    size_t total_spad_size = spad_per_thread * n_threads;
+
+    if (octx->ctx->vtcm_size < total_spad_size) {
+        FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src0_spad.size = total_spad_size;
+    octx->src0_spad.size_per_thread = spad_per_thread;
+
+    FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
+         octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
+         octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
+         octx->src0.data, octx->dst.data);
+
+    struct htp_argsort_context actx;
+    actx.octx = octx;
+    actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads;
+
+    // Run jobs
+    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads);
+
+    return HTP_STATUS_OK;
+}
diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c
index 8ed7f67d..ec90f22d 100644
--- a/ggml/src/ggml-hexagon/htp/binary-ops.c
+++ b/ggml/src/ggml-hexagon/htp/binary-ops.c
@@ -2,41 +2,52 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
-
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
+
 #include 
-#include 
 #include 
 
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
-#include "htp-dma.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
 
-typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
-                                      const uint8_t * src1,
-                                      uint8_t *       data_dst,
-                                      const int       num_elems);
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
 
-static hvx_elemwise_f32_func func_table_HVX[]     = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
-static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
+// Context for binary operations
+struct htp_binary_context {
+    struct htp_ops_context * octx;
+    struct fastdiv_values dim1_div;
+    struct fastdiv_values dim2_div;
+    struct fastdiv_values dim12_div;
+
+    struct fastdiv_values src1_dim1_div; // ne11
+    struct fastdiv_values src1_dim2_div; // ne12
+    struct fastdiv_values src1_dim3_div; // ne13
+
+    uint32_t nrows_per_thread;
+    bool split_at_ne01;
+    bool split_at_ne02;
+
+    // Precomputed values
+    uint32_t block_max;
+    size_t   src0_row_size_aligned;
+    size_t   src1_row_size_aligned;
+    size_t   dst_row_size_aligned;
+    uint32_t src1_fetch_rows; // 1 or block_max
+    uint32_t src1_dma_stride; // 0 or stride
+};
 
 #define htp_binary_preamble            \
     const struct htp_tensor * src0 = &octx->src0; \
     const struct htp_tensor * src1 = &octx->src1; \
-    const struct htp_tensor * src2 = &octx->src2; \
     struct htp_tensor *       dst  = &octx->dst;  \
                                        \
     const uint32_t ne00 = src0->ne[0]; \
@@ -49,272 +60,752 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f
     const uint32_t ne12 = src1->ne[2]; \
     const uint32_t ne13 = src1->ne[3]; \
                                        \
-    const uint32_t ne0 = dst->ne[0];   \
-    const uint32_t ne1 = dst->ne[1];   \
-    const uint32_t ne2 = dst->ne[2];   \
-    const uint32_t ne3 = dst->ne[3];   \
-                                       \
-    const uint32_t nb00 = src0->nb[0]; \
     const uint32_t nb01 = src0->nb[1]; \
     const uint32_t nb02 = src0->nb[2]; \
     const uint32_t nb03 = src0->nb[3]; \
                                        \
-    const uint32_t nb10 = src1->nb[0]; \
     const uint32_t nb11 = src1->nb[1]; \
     const uint32_t nb12 = src1->nb[2]; \
     const uint32_t nb13 = src1->nb[3]; \
                                        \
-    const uint32_t nb0 = dst->nb[0];   \
     const uint32_t nb1 = dst->nb[1];   \
     const uint32_t nb2 = dst->nb[2];   \
-    const uint32_t nb3 = dst->nb[3];   \
-                                       \
-    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+    const uint32_t nb3 = dst->nb[3];
 
-static void binary_job_f32_per_thread(struct htp_ops_context * octx,
-                                      uint8_t *                spad_data,
-                                      uint32_t                 nth,
-                                      uint32_t                 ith,
-                                      enum htp_op              op) {
+static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
+                                uint32_t ne01, uint32_t ne02) {
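+    // Returns how many consecutive rows, starting at ir, can be processed as one block:
+    // capped at block_max, at the rows left for this thread, and, when splitting is required,
+    // at the current ne01 row boundary or ne02*ne01 plane boundary.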
+    uint32_t i03, i02, i01, rem;
+    i03 = fastdiv(ir, &bctx->dim12_div);
+    rem = ir - i03 * (ne02 * ne01);
+    i02 = fastdiv(rem, &bctx->dim1_div);
+    i01 = rem - i02 * ne01;
+
+    uint32_t rows_left = end_row - ir;
+    uint32_t block_limit = rows_left;
+
+    if (bctx->split_at_ne01) {
+        block_limit = MIN(block_limit, ne01 - i01);
+    }
+    if (bctx->split_at_ne02) {
+         uint32_t rows_in_plane = (ne02 * ne01) - rem;
+         block_limit = MIN(block_limit, rows_in_plane);
+    }
+
+    return MIN(bctx->block_max, block_limit);
+}
+
+// Macro for scalar op switch (for f32, DIV is lowered to a multiply by the reciprocal of the scalar)
+#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \
+    if(TYPE == HTP_TYPE_F32) { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+            case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+            case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+            case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \
+            default: break; \
+        } \
+    } \
+    else { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+            case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+            case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+            case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+            default: break; \
+        } \
+    }
+
+// Macro for vector op switch (All Aligned)
+#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \
+    if(TYPE == HTP_TYPE_F32) { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    } \
+    else { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    }
+
+// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
+#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \
+    if(TYPE == HTP_TYPE_F32) { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    } \
+    else { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    }
+
+// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
+#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \
+    if(TYPE == HTP_TYPE_F32) { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    } \
+    else { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    }
+
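+// Specialized per-thread jobs for the different src1 broadcast patterns follow: scalar src1,
+// same shape / simple broadcast, single-row broadcast, complex broadcast and element repeat.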
+// 1. Scalar src1 (ne10 == 1)
+static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
-    const size_t src0_row_size = nb01;
-    const size_t src1_row_size = nb11;
-    const size_t dst_row_size  = nb1;
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
-    const uint32_t src1_nrows = ne11 * ne12 * ne13;  // src1 rows
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half    = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half     = octx->dst_spad.size_per_thread  / 2;
 
-    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
-    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
 
-    // no work for this thread
-    if (src0_start_row >= src0_end_row) {
-        return;
+    // Preamble: queue up to two blocks ahead so DMA transfers overlap with compute (ping-pong buffers)
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
     }
 
-    uint64_t t1, t2;
-    t1 = HAP_perf_get_qtimer_count();
+    // Main loop
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
 
-    int is_aligned = 1;
-    int opt_path   = 0;
-    if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
-        (0 == htp_is_aligned((void *) dst->data, VLEN))) {
-        FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
-        is_aligned = 0;
-    }
-    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
-    }
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
 
-    hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
 
-    uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
+        // src1 indices (broadcast/repeat)
+        uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+        uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+        uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);
 
-    const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
-    uint8_t * restrict dst_ptr        = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
+        uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+        uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;
 
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-
-    const uint32_t ne02_ne01 = ne02 * ne01;
-
-    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
-        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
-        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
-        const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
-
-        const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
-        const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
-        const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
-
-        const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
-
-        if (ir + 1 < src0_end_row) {
-            htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
-            if (src1_row_size == src0_row_size) {
-                htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
-            }
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+            COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00);
+            src1_ptr += s1_stride;
         }
 
-        const uint32_t nr0 = ne00 / ne10;
-        if (nr0 > 1) {
-            if ((1 == is_aligned) && (nr0 == ne00)) {
-                hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
-            } else {
-                for (uint32_t r = 0; r < nr0; r++) {
-                    memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
-                }
-            }
-            func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
-        } else {
-            func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
+             ir_prefetch += next_block_size;
         }
-
-        src0_ptr += src0_row_size;
-        dst_ptr += dst_row_size;
+        ir += current_block_size;
     }
-
-    t2 = HAP_perf_get_qtimer_count();
-
-    FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
-         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
-         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+    dma_queue_flush(q);
 }
 
-static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
-                                             uint8_t *                spad_data,
-                                             uint32_t                 nth,
-                                             uint32_t                 ith,
-                                             hvx_elemwise_f32_func    func_HVX) {
+// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast
+static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
-    const size_t src0_row_size = nb01;
-    const size_t src1_row_size = nb11;
-    const size_t dst_row_size  = nb1;
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
 
-    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
-    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
 
-    // no work for this thread
-    if (src0_start_row >= src0_end_row) {
-        return;
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint32_t i13 = (ne13 == 1) ? 0 : i03;
+        uint32_t i12 = (ne12 == 1) ? 0 : i02;
+        uint32_t i11 = (ne11 == 1) ? 0 : i01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
+        dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
     }
 
-    uint64_t t1, t2;
-    t1 = HAP_perf_get_qtimer_count();
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+        uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
 
-    if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
-        (0 == htp_is_aligned((void *) dst->data, VLEN))) {
-        FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
+            uint8_t * r_dst  = d_spad  + r * bctx->dst_row_size_aligned;
+            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
+        }
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+
+             uint32_t p13 = (ne13 == 1) ? 0 : p03;
+             uint32_t p12 = (ne12 == 1) ? 0 : p02;
+             uint32_t p11 = (ne11 == 1) ? 0 : p01;
+
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
+
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
+             dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);
+
+             ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
+    }
+    dma_queue_flush(q);
+}
+
+// 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1)
+static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
+    htp_binary_preamble;
+
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    void * s1_ptr = (void *) src1_spad;
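+    // src1 is a single row held in the per-thread scratchpad; it is reused unchanged for every
+    // output row below.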
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
     }
 
-    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
 
-    const uint32_t ne02_ne01  = ne02 * ne01;
-    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
-        // src0 indices
-        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
-        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
-        const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
+        }
 
-        // src1 indices
-        const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
-        assert(i11 >= 0 && i11 < ne11);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
 
-        float * restrict dst_ptr        = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
-        const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
-        const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
+             ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
+    }
+    dma_queue_flush(q);
+}
 
-        if (ir + 1 < src0_end_row) {
-            htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
-            if (src1_row_size == src0_row_size) {
-                htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
+// 4. Vector Complex (ne10 == ne00, complex broadcast)
+static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
+    htp_binary_preamble;
+
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
+    }
+
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint32_t r_i01 = i01 + r;
+            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
+
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+
+            // Read src1 from DDR (unaligned)
+            COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00);
+        }
+
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
+             ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
+    }
+    dma_queue_flush(q);
+}
+
+// 5. Element Repeat (ne10 != ne00)
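+// Here the src1 row is shorter than the src0 row (ne10 != ne00), so it is applied repeatedly
+// across each src0 row in chunks of MIN(ne10, ne00 - c) using the fully unaligned (UUU) compute
+// variant; the DMA pipeline is the same as in the kernels above.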
+static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
+    htp_binary_preamble;
+
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
+    const uint32_t row_size_bytes = ne00 * elem_size_bytes;
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
+    }
+
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint32_t r_i01 = i01 + r;
+            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
+
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+
+            // Repeat src1 row
+            for (uint32_t c = 0; c < ne00; c += ne10) {
+                uint32_t len = MIN(ne10, ne00 - c);
+                // Use UUU for speed and simplicity
+                COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len);
             }
         }
 
-        const uint32_t nr0 = ne00 / ne10;
-        if (nr0 > 1) {
-            for (uint32_t r = 0; r < nr0; r++) {
-                memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
-            }
-            func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
-        } else {
-            func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
+             ir_prefetch += next_block_size;
         }
+        ir += current_block_size;
     }
-
-    t2 = HAP_perf_get_qtimer_count();
-
-    FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
-         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
-         src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
-         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+    dma_queue_flush(q);
 }
 
-static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+// 6. ADD_ID (src1 gathered via src2 indices)
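+// For ADD_ID every output row adds a runtime-selected bias row: an int32 index is read from
+// src2 and picks the src1 row to add to the matching src0 row (F32 only here, via
+// hvx_add_f32_aau); src1 stays in DDR while src0/dst go through the VTCM pipeline.
+// Rough per-row sketch of what the inner loop computes:
+//   idx = *(const int32_t *)((const char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
+//   for (c = 0; c < ne00; ++c) dst_row[c] = src0_row[c] + src1_row(idx)[c];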
+static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
 
-    switch (octx->op) {
-        case HTP_OP_MUL:
-        case HTP_OP_ADD:
-        case HTP_OP_SUB:
-            binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
-            break;
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    const struct htp_tensor * src2 = &octx->src2;
+    struct htp_tensor *       dst  = &octx->dst;
 
-        case HTP_OP_ADD_ID:
-            binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
-            break;
+    const uint32_t ne00 = src0->ne[0];
+    const uint32_t ne01 = src0->ne[1];
+    const uint32_t ne02 = src0->ne[2];
+    const uint32_t ne03 = src0->ne[3];
+    const uint32_t ne11 = src1->ne[1]; // for bounds check
 
-        default:
-            FARF(ERROR, "Unknown Binary Op %u", octx->op);
-            break;
+    const uint32_t nb01 = src0->nb[1];
+    const uint32_t nb02 = src0->nb[2];
+    const uint32_t nb03 = src0->nb[3];
+    const uint32_t nb11 = src1->nb[1]; // src1 row stride
+    const uint32_t nb1 = dst->nb[1];
+    const uint32_t nb2 = dst->nb[2];
+    const uint32_t nb3 = dst->nb[3];
+
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
     }
+
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
+
+            const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);
+
+            uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+
+            hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);
+        }
+
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
+    }
+    dma_queue_flush(q);
 }
 
-static int execute_op_binary_f32(struct htp_ops_context * octx) {
-    int err = HTP_STATUS_OK;
-
+static int execute_op_binary(struct htp_ops_context * octx) {
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
     struct htp_tensor *       dst  = &octx->dst;
 
-    worker_callback_t binary_op_func;
-    const char *      op_type = NULL;
+    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);
 
-    switch (octx->op) {
-        case HTP_OP_MUL:
-            binary_op_func = binary_job_dispatcher_f32;
-            op_type        = "mul-f32";
-            break;
+    // Use packed row sizes for VTCM allocation
+    const uint32_t src0_type = octx->src0.type;
+    const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
+    const size_t src0_row_size = src0->ne[0] * elem_size;
+    const size_t src1_row_size = src1->ne[0] * elem_size;
+    const size_t dst_row_size  = dst->ne[0] * elem_size;
 
-        case HTP_OP_ADD:
-            binary_op_func = binary_job_dispatcher_f32;
-            op_type        = "add-f32";
-            break;
+    // Align to VLEN
+    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+    size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
 
-        case HTP_OP_SUB:
-            binary_op_func = binary_job_dispatcher_f32;
-            op_type        = "sub-f32";
-            break;
+    bool is_add_id = (octx->op == HTP_OP_ADD_ID);
+    bool is_scalar = !is_add_id && (src1->ne[0] == 1);
 
-        case HTP_OP_ADD_ID:
-            binary_op_func = binary_job_dispatcher_f32;
-            op_type        = "add-id-f32";
-            break;
+    // Determine which kernel will be used, so the VTCM scratchpads can be sized and the right job dispatched
+    bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
+               (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
+               (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
+               (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
 
-        default:
-            FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
-            return HTP_STATUS_NO_SUPPORT;
+    bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
+    bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
+    bool use_repeat  = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
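+    // Dispatch implied by the flags above:
+    //   is_scalar       -> binary_job_scalar               (ne10 == 1, one src1 value per row)
+    //   is_row_bcast    -> binary_job_vector_row_broadcast (one src1 row reused for all src0 rows)
+    //   use_vector_same -> binary_job_vector_same_shape    (same row length, VLEN-aligned src0 rows)
+    //   use_complex     -> binary_job_vector_complex       (same row length, unaligned/general layout)
+    //   use_repeat      -> binary_job_element_repeat       (ne10 != ne00, src1 row tiled across src0 row)
+    //   is_add_id       -> binary_job_add_id               (src1 rows gathered via int32 indices in src2)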
+
+    size_t spad_row_total;
+    if (is_scalar) {
+        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+    } else if (is_row_bcast) {
+        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+    } else if (use_vector_same) {
+        spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
+    } else if (is_add_id) {
+        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
+    } else {
+        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
     }
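+    // The factor of 2 above reserves two halves per thread so DMA into one half can overlap
+    // with compute on the other (ping-pong buffering, see spad_idx ^= 1 in the kernels);
+    // src1 is left out whenever the selected kernel reads it directly from DDR.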
 
-    const int      n_threads  = octx->n_threads;
-    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
+    // Adjust for static src1 in row_bcast case
+    if (is_row_bcast) {
+        size_t needed_static = src1_row_size_aligned;
+        if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;
+        size_t avail = octx->ctx->vtcm_size - needed_static;
+        rows_per_buffer = avail / (n_threads * spad_row_total);
+    }
 
-    const size_t src0_row_size = src0->nb[1];
-    const size_t src1_row_size = src1->nb[1];
-    const size_t dst_row_size  = dst->nb[1];
+    if (rows_per_buffer < 1) {
+         FARF(ERROR, "binary: VTCM too small\n");
+         return HTP_STATUS_VTCM_TOO_SMALL;
+    }
 
-    // VTCM scratchpads for all tensors
-    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * n_threads;
-    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
-    octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
+    octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
+    octx->dst_spad.size_per_thread  = rows_per_buffer * 2 * dst_row_size_aligned;
 
-    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+    if (is_scalar || use_complex || use_repeat || is_add_id) {
+        octx->src1_spad.size_per_thread = 0;
+    } else if (is_row_bcast) {
+        octx->src1_spad.size_per_thread = 0;
+    } else {
+        octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
+    }
 
-    FARF(HIGH,
-         "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
-         op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
-         src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
-         octx->dst_spad.size);
+    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
+    if (is_row_bcast) {
+        octx->src1_spad.size = src1_row_size_aligned;
+    } else {
+        octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
+    }
+    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;
 
-    // Make sure the reserved vtcm size is sufficient
-    if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
-             octx->ctx->vtcm_size, spad_size);
+    if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
@@ -322,39 +813,79 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
     octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
     octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
-    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs = MIN(n_threads, src0_nrows);
-
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-
-        octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
-        octx->src0_div3  = init_fastdiv_values(src0->ne[3]);
-        octx->src0_div2  = init_fastdiv_values(src0->ne[2]);
-        octx->src0_div1  = init_fastdiv_values(src0->ne[1]);
-
-        octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
-        octx->src1_div3  = init_fastdiv_values(src1->ne[3]);
-        octx->src1_div2  = init_fastdiv_values(src1->ne[2]);
-        octx->src1_div1  = init_fastdiv_values(src1->ne[1]);
-
-        worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
+    if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        return HTP_STATUS_OK;
     }
 
-    return err;
+    dma_queue * q = octx->ctx->dma[0];
+    if (is_row_bcast) {
+        dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1);
+    }
+
+    struct htp_binary_context bctx;
+    bctx.octx = octx;
+    bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
+    bctx.block_max = rows_per_buffer;
+    bctx.src0_row_size_aligned = src0_row_size_aligned;
+    bctx.src1_row_size_aligned = src1_row_size_aligned;
+    bctx.dst_row_size_aligned  = dst_row_size_aligned;
+
+    bctx.dim1_div = init_fastdiv_values(src0->ne[1]);
+    bctx.dim2_div = init_fastdiv_values(src0->ne[2]);
+    bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
+
+    bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
+    bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
+    bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
+
+    bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
+    bool dst_contig_dim1  = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
+
+    bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
+    bool dst_contig_dim2  = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
+
+    bctx.split_at_ne01 = (src0->ne[2] > 1) &&
+                         ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
+
+    bctx.split_at_ne02 = (src0->ne[3] > 1) &&
+                         ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
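+    // These flags tell calc_block_size() not to let a block run past an ne01 (or ne02) boundary
+    // when crossing it would change the addressing, i.e. when src1 actually varies along those
+    // dims or when src0/dst are not contiguous there; within a block rows can then be treated
+    // as linear (see the r_i01 = i01 + r computation in the kernels).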
+
+    // Precompute specific kernel parameters
+    if (use_vector_same) {
+        bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
+        bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
+    }
+
+    worker_callback_t worker_func;
+    if (is_add_id) worker_func = binary_job_add_id;
+    else if (is_scalar) worker_func = binary_job_scalar;
+    else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
+    else if (use_vector_same) worker_func = binary_job_vector_same_shape;
+    else if (use_complex) worker_func = binary_job_vector_complex;
+    else worker_func = binary_job_element_repeat;
+
+    if (is_row_bcast) {
+        dma_queue_pop(q);
+    }
+
+    worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads);
+
+    return HTP_STATUS_OK;
 }
 
 int op_binary(struct htp_ops_context * octx) {
-    int err = HTP_STATUS_OK;
 
-    switch (octx->src0.type) {
-        case HTP_TYPE_F32:
-            err = execute_op_binary_f32(octx);
-            break;
-
-        default:
-            err = HTP_STATUS_NO_SUPPORT;
-            break;
+    // Does not support permutations of src1
+    const struct htp_tensor * src1 = &octx->src1;
+    if (src1->nb[1] < src1->nb[0]) {
+        return HTP_STATUS_NO_SUPPORT;
     }
 
-    return err;
+    const uint32_t src0_type = octx->src0.type;
+    if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
+        return execute_op_binary(octx);
+    }
+
+    return HTP_STATUS_NO_SUPPORT;
 }
+
diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c
new file mode 100644
index 00000000..a40d866b
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c
@@ -0,0 +1,252 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include 
+#include 
+
+#include 
+#include 
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+
+struct htp_copy_context {
+    struct htp_ops_context * octx;
+
+    uint32_t          src0_type_size;
+    uint32_t          src0_block_size;
+
+    uint32_t          dst_type_size;
+    uint32_t          dst_block_size;
+
+    uint32_t          src0_blocks_per_row;
+    uint32_t          dst_blocks_per_row;
+
+    uint32_t          src0_nrows_per_thread;
+
+    void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith);
+};
+
+#define cpy_preamble                       \
+    struct htp_tensor *src0 = &octx->src0; \
+    struct htp_tensor *dst  = &octx->dst;  \
+                                           \
+    const uint32_t ne00 = src0->ne[0];     \
+    const uint32_t ne01 = src0->ne[1];     \
+    const uint32_t ne02 = src0->ne[2];     \
+    const uint32_t ne03 = src0->ne[3];     \
+                                           \
+    const uint32_t nb00 = src0->nb[0];     \
+    const uint32_t nb01 = src0->nb[1];     \
+    const uint32_t nb02 = src0->nb[2];     \
+    const uint32_t nb03 = src0->nb[3];     \
+                                           \
+    const uint32_t  ne0 = dst->ne[0];      \
+    const uint32_t  ne1 = dst->ne[1];      \
+    const uint32_t  ne2 = dst->ne[2];      \
+    const uint32_t  ne3 = dst->ne[3];      \
+                                           \
+    const uint32_t  nb0 = dst->nb[0];      \
+    const uint32_t  nb1 = dst->nb[1];      \
+    const uint32_t  nb2 = dst->nb[2];      \
+    const uint32_t  nb3 = dst->nb[3];      \
+                                           \
+    const uint32_t   nr = ne01;
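+// Naming in cpy_preamble: ne0x/nb0x are the src0 extents and byte strides, ne*/nb* are the dst
+// equivalents, and nr (= ne01) is the row count along dim 1 used to partition work across threads.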
+
+static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
+    cpy_preamble;
+
+    // parallelize by src0 rows
+    const uint32_t dr  = ct->src0_nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
+
+    // copy by rows
+    for (uint32_t i03 = 0; i03 < ne03; i03++) {
+        for (uint32_t i02 = 0; i02 < ne02; i02++) {
+            #pragma unroll(2)
+            for (uint32_t i01 = ir0; i01 < ir1; i01++) {
+                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;
+                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2);
+                hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size);
+            }
+        }
+    }
+}
+
+static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) {
+    cpy_preamble;
+
+    // parallelize by src0 rows
+    const uint32_t dr  = ct->src0_nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
+
+    // dst counters
+    int64_t k10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    // number of blocks in a row
+    const int64_t nk00 = ct->src0_blocks_per_row;
+    const int64_t nk0  = ct->dst_blocks_per_row;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            k10 += nk00 * ir0;
+            while (k10 >= nk0) {
+                k10 -= nk0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+            for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                for (int64_t k00 = 0; k00 < nk00; k00++) {
+                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+                    memcpy(dst_ptr, src0_ptr, ct->dst_type_size);
+
+                    if (++k10 == nk0) {
+                        k10 = 0;
+                        if (++i11 == ne1) {
+                            i11 = 0;
+                            if (++i12 == ne2) {
+                                i12 = 0;
+                                if (++i13 == ne3) {
+                                    i13 = 0;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            k10 += nk00 * (ne01 - ir1);
+            while (k10 >= nk0) {
+                k10 -= nk0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
+    cpy_preamble;
+
+    // parallelize by src0 rows
+    const uint32_t dr  = ct->src0_nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
+
+    // copy by rows
+    for (uint32_t i03 = 0; i03 < ne03; i03++) {
+        for (uint32_t i02 = 0; i02 < ne02; i02++) {
+            #pragma unroll(2)
+            for (uint32_t i01 = ir0; i01 < ir1; i01++) {
+                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;
+                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2);
+                hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
+            }
+        }
+    }
+}
+
+static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) {
+    cpy_preamble;
+
+    // parallelize by src0 rows
+    const uint32_t dr  = ct->src0_nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr;
+
+    // copy by rows
+    for (uint32_t i03 = 0; i03 < ne03; i03++) {
+        for (uint32_t i02 = 0; i02 < ne02; i02++) {
+            #pragma unroll(2)
+            for (uint32_t i01 = ir0; i01 < ir1; i01++) {
+                uint8_t* dst_ptr  = (uint8_t*) dst->data  + i01*nb1  + i02*nb2  + i03*nb3;
+                uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2);
+                hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00);
+            }
+        }
+    }
+}
+
+static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
+    struct htp_copy_context *ct = (struct htp_copy_context *) data;
+    ct->copy(ct, ct->octx, n, i);
+}
+
+int op_cpy(struct htp_ops_context * octx) {
+    cpy_preamble;
+
+    const uint32_t n_threads = MIN(nr, octx->n_threads);
+
+    struct htp_copy_context ct;
+    ct.octx = octx;
+
+    switch (src0->type) {
+    case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
+    case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break;
+    default:
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    switch (dst->type) {
+    case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
+    case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break;
+    default:
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    const bool sametype   = (src0->type == dst->type);
+    const bool transposed = (nb00 > nb01) || (nb0 > nb1);
+    const bool sameshape  = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
+
+    ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
+
+    if (sametype && sameshape) {
+        ct.copy = cpy_thread_sametype_sameshape;
+    } else if (sameshape) {
+        /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32)
+            ct.copy = cpy_thread_f16_f32_sameshape;
+        else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16)
+            ct.copy = cpy_thread_f32_f16_sameshape;
+        else
+            return HTP_STATUS_NO_SUPPORT;
+    } else if (sametype) {
+        ct.copy = cpy_thread_sametype_reshape;
+    } else {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads);
+
+    return HTP_STATUS_OK;
+}
diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
index 04a7b843..6dc978dd 100644
--- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
+++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
@@ -2,78 +2,30 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
+#include 
 #include 
-#include 
 #include 
-#include 
-#include 
 #include 
 #include 
 
+#include "hex-dma.h"
+#include "hvx-utils.h"
+#include "hvx-dump.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
-#include "htp-dma.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
 
-// Dot product of FP32 and FP16 vectors, accumulating to float
-static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
-    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
-    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+// Must be multiple of 32
+#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
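+// Keeping the block size a multiple of VLEN_FP32 (32 floats per HVX vector) lets each K/V block
+// be processed as whole 32-wide score sub-blocks (see sb_scores[] in the main attention loop).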
 
-    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
-    uint32_t nloe = n % VLEN_FP16; // leftover elements
-
-    const HVX_Vector zero = Q6_V_vsplat_R(0);
-    HVX_Vector       rsum = Q6_V_vsplat_R(0);
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (i = 0; i < nvec; i++) {
-        // Load y (fp32) and convert into fp16
-        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
-        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
-        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
-
-        // Load x (fp16)
-        HVX_Vector x_hf  = vx[i];
-
-        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
-
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
-    }
-
-    if (nloe) {
-        // Load y (fp32) and convert into fp16
-        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
-        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
-        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
-
-        // Load x (fp16)
-        HVX_Vector x_hf  = vx[i];
-
-        // Zero-out unused elements
-        // Note that we need to clear both x and y because they may contain NANs
-        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
-        x_hf = Q6_V_vand_QV(bmask, x_hf);
-        y_hf = Q6_V_vand_QV(bmask, y_hf);
-
-        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
-
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
-    }
-
-    rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
-
-    hvx_vec_store_u(r, 4, rsum);
+// This is a bit of a hack: the compiler struggles to properly inline the default
+// hvx_vec_f32_to_f16 when the result is written into a local array.
+static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
+{
+    *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
 }
 
 // Dot product of two F16 vectors, accumulating to float
@@ -84,84 +36,254 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
     uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
     uint32_t nloe = n % VLEN_FP16; // leftover elements
 
-    const HVX_Vector zero = Q6_V_vsplat_R(0);
-    HVX_Vector       rsum = Q6_V_vsplat_R(0);
+    HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
 
     uint32_t i = 0;
 
     #pragma unroll(4)
     for (i = 0; i < nvec; i++) {
-        HVX_Vector y_hf = vy[i];
-        HVX_Vector x_hf = vx[i];
-
-        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
-
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+        rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
     }
 
     if (nloe) {
-        HVX_Vector y_hf = vy[i];
-
-        // Load x (fp16) and zero-out unused elements
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
-        HVX_Vector      x_hf = Q6_V_vand_QV(bmask, vx[i]);
+        HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
+        HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
 
-        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
-
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+        rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
     }
 
-    rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+    HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
+    rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));
     hvx_vec_store_u(r, 4, rsum);
 }
 
-// MAD: y (F32) += x (F16) * v (float)
-static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
-    const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
-    HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
+static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
+                                                const uint8_t * restrict x,
+                                                const size_t stride_x,
+                                                const size_t nvec,
+                                                const size_t nloe) {
+    const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x;                   // fp16
+    const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x);      // fp16
+    const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2);  // fp16
+    const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3);  // fp16
+    const HVX_Vector * restrict vy  = (const HVX_Vector * restrict) y;                   // fp16
+
+    HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+    HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+    HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+    HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
+
+    uint32_t i = 0;
+
+    for (i = 0; i < nvec; i++) {
+        HVX_Vector y_hf  = vy[i];
+        HVX_Vector x0_hf = vx0[i];
+        HVX_Vector x1_hf = vx1[i];
+        HVX_Vector x2_hf = vx2[i];
+        HVX_Vector x3_hf = vx3[i];
+
+        rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
+        rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
+        rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
+        rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
+    }
+
+    if (nloe) {
+        // Load x (fp16) and zero-out unused elements
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        HVX_Vector     y_hf  = Q6_V_vand_QV(bmask, vy[i]);
+        HVX_Vector     x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
+        HVX_Vector     x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
+        HVX_Vector     x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
+        HVX_Vector     x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
+
+        rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
+        rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
+        rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
+        rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
+    }
+
+    HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
+    HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
+    HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));
+    HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));
+
+    HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
+    return hvx_vec_reduce_sum_f32x4(rsum0123);
+}
+
+static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
+                                                 const uint8_t * restrict x,
+                                                 const size_t stride_x,
+                                                 const size_t n,
+                                                 float        s) {
+
+    const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+    const size_t nloe = n % VLEN_FP16; // leftover elements
+
+    HVX_Vector   sums;  // fully written on the first iteration (j == 0), so no explicit init needed
+    const size_t stride_x_4 = stride_x * 4;
+    for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
+        HVX_Vector     sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
+        HVX_VectorPred pred    = Q6_Q_vsetq_R(j * SIZEOF_FP32);
+        sums                   = Q6_V_vmux_QVV(pred, sums, sums_x4);
+        x += stride_x_4;
+    }
+
+    sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
+    return Q6_Vsf_equals_Vqf32(sums);
+}
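+// hvx_dot_f16_f16_aa_rx32 assembles one full HVX vector of 32 scaled dot products: each rx4 call
+// reduces four consecutive K rows against the same Q row, and the Q6_Q_vsetq_R predicate keeps the
+// lanes accumulated so far while taking the freshly computed ones, so after eight iterations all
+// 32 lanes hold scores which are then scaled by s.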
+
+// MAD: y (F32) += x (F16) * s (F16)
+static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) {
+    const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
+
+    HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
+    HVX_Vector * restrict vy = (HVX_Vector *) y;
 
     uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
     uint32_t nloe = n % VLEN_FP16; // leftover elements
 
-    HVX_Vector S = hvx_vec_splat_fp16(s);
+    HVX_Vector S0 = hvx_vec_splat_f16(*s);
 
     uint32_t i = 0;
-    #pragma unroll(4)
+
+    #pragma unroll(2)
     for (i = 0; i < nvec; ++i) {
-        // Multiply x * s -> pair of F32 vectors
-        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
-        ptr_y[i*2]   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
-        ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
+        vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
     }
 
     if (nloe) {
-        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+        HVX_VectorPair xy_p = vy_p[i];
+        xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
 
-        HVX_Vector xs = Q6_V_lo_W(xs_p);
-        i = 2 * i; // index for ptr_y
+        HVX_Vector xy = Q6_V_lo_W(xy_p);
+        i = 2 * i;  // index for vy
 
-        if (nloe >= 32) {
-            ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
-            nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
+        if (nloe >= VLEN_FP32) {
+            vy[i] = xy;
+            nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
         }
 
         if (nloe) {
-            HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
-            hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
+            hvx_vec_store_a(&vy[i], nloe * 4, xy);
         }
     }
 }
 
-#define FLASH_ATTN_BLOCK_SIZE 128
+// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
+static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
+                                          const __fp16 * restrict s0, const __fp16 * restrict s1, int n) {
+    const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
+    const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
 
-static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
+    HVX_VectorPair * restrict vy_p  = (HVX_VectorPair *) y;
+    HVX_Vector * restrict vy        = (HVX_Vector *) y;
+
+    uint32_t nvec = n / VLEN_FP16;  // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16;  // leftover elements
+
+    HVX_Vector S0 = hvx_vec_splat_f16(*s0);
+    HVX_Vector S1 = hvx_vec_splat_f16(*s1);
+
+    uint32_t i = 0;
+
+    #pragma unroll(2)
+    for (i = 0; i < nvec; ++i) {
+        vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
+        vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
+    }
+
+    if (nloe) {
+        HVX_VectorPair xy_p = vy_p[i];
+        xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
+        xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
+
+        HVX_Vector xy = Q6_V_lo_W(xy_p);
+        i = 2 * i;  // index for vy
+
+        if (nloe >= VLEN_FP32) {
+            vy[i] = xy;
+            nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
+        }
+
+        if (nloe) {
+            hvx_vec_store_a(&vy[i], nloe * 4, xy);
+        }
+    }
+}
+
+struct htp_fa_context {
+    const struct htp_ops_context * octx;
+
+    struct fastdiv_values src0_div21;
+    struct fastdiv_values src0_div1;
+
+    struct fastdiv_values broadcast_rk2;
+    struct fastdiv_values broadcast_rk3;
+    struct fastdiv_values broadcast_rv2;
+    struct fastdiv_values broadcast_rv3;
+
+    struct fastdiv_values src3_div2;
+    struct fastdiv_values src3_div3;
+
+    float scale;
+    float max_bias;
+    float logit_softcap;
+
+    uint32_t n_head_log2;
+    float m0;
+    float m1;
+
+    uint32_t n_blocks;
+
+    size_t size_q_row_padded;
+    size_t size_k_row_padded;
+    size_t size_v_row_padded;
+
+    size_t size_k_block;
+    size_t size_v_block;
+    size_t size_m_block;
+
+    uint32_t qrows;
+    uint32_t qrows_per_thread;
+
+    bool is_q_fp32;
+
+    uint64_t t_start;
+};
+
+static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
+    assert((size_t) dst % 128 == 0);
+    assert((size_t) src % 128 == 0);
+
+    const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
+    HVX_Vector * restrict vdst       = (HVX_Vector * restrict) dst;
+
+    const uint32_t nvec = n / VLEN_FP32;
+    const uint32_t nloe = n % VLEN_FP32;
+
+    uint32_t i = 0;
+    #pragma unroll(4)
+    for (; i < nvec; ++i) {
+        vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
+    }
+    if (nloe) {
+        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+        hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
+    }
+}
+
+static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_fa_context * factx = (struct htp_fa_context *) data;
+    const struct htp_ops_context * octx = factx->octx;
     const struct htp_tensor * q = &octx->src0;
     const struct htp_tensor * k = &octx->src1;
     const struct htp_tensor * v = &octx->src2;
     const struct htp_tensor * mask  = (octx->src3.data) ? &octx->src3 : NULL;
     const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
-    struct htp_tensor * dst = &octx->dst;
+    const struct htp_tensor * dst = &octx->dst;
 
     const uint32_t neq0 = q->ne[0];
     const uint32_t neq1 = q->ne[1];
@@ -198,22 +320,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
     const uint32_t nb2 = dst->nb[2];
     const uint32_t nb3 = dst->nb[3];
 
-    float scale         = 1.0f;
-    float max_bias      = 0.0f;
-    float logit_softcap = 0.0f;
-
-    memcpy(&scale,         (float *) octx->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) octx->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
-
-    if (logit_softcap != 0) {
-        scale /= logit_softcap;
-    }
-
     // total rows in q
-    const uint32_t nr = neq1*neq2*neq3;
-
-    const uint32_t dr = (nr + nth - 1) / nth;
+    const uint32_t nr = factx->qrows;
+    const uint32_t dr = factx->qrows_per_thread;
     const uint32_t ir0 = dr * ith;
     const uint32_t ir1 = MIN(ir0 + dr, nr);
 
@@ -225,18 +334,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
     const uint32_t DV = nev0;
 
     const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
-    const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
-
     const size_t size_k_row = DK * sizeof(__fp16);
     const size_t size_v_row = DV * sizeof(__fp16);
-    const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
-
-    const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
-    const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
-
-    const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
-    const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
-    const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
 
     // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
     uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
@@ -245,72 +344,79 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
     uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
     uint8_t * spad_a = octx->dst_spad.data  + octx->dst_spad.size_per_thread  * ith;
 
-    const uint32_t n_head = neq2;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+    const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
 
     for (uint32_t ir = ir0; ir < ir1; ++ir) {
-        const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
-        const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
+        const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
+        const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
         const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
 
-        const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
-        const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
+        const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
+        const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
 
-        const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
-        const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
+        const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
+        const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
 
         // Fetch Q row
         const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
-        dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
+        dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
 
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
-
-        float S = 0.0f;      // sum
-        float M = -INFINITY; // maximum KQ value
-
-        // Clear accumulator
-        float * VKQ32 = (float *) spad_a;
-        memset(VKQ32, 0, DV * sizeof(float));
+        // FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row,
+        //                 (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
 
         const __fp16 * mp_base = NULL;
         if (mask) {
-            const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
-            const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
+            const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
+            const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
             mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
         }
 
-        const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
-
         // Prefetch first two blocks
-        for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
+        for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
             const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
             const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
 
             // K
             const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
-            uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
-            dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
+            uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
+            dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
 
             // V
             const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
-            uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
-            dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
+            uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
+            dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
 
             // Mask
             if (mask) {
                 const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
-                uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
+                uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
                 // Mask is 1D contiguous for this row
                 dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
             }
+
+            // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
+            //             ith, ir, ib, iq1, iq2, iq3,
+            //             size_k_row, size_v_row, current_block_size,
+            //             (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
         }
 
-        const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
+        const uint32_t h = iq2; // head index
+        const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
 
-        for (uint32_t ib = 0; ib < n_blocks; ++ib) {
+        HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
+        HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
+
+        // Clear accumulator
+        hvx_splat_f32_a(spad_a, 0, DV);
+        float * VKQ32 = (float *) (spad_a + 0);
+
+        uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
+        if (factx->is_q_fp32) {
+            hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK);  // inplace convert f32 to f16
+        }
+
+        const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
+        for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
             const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
             const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
 
@@ -319,156 +425,166 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
             uint8_t * v_base = dma_queue_pop(dma).dst; // V
             __fp16  * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
 
+            // FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u",
+            //              ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm,
+            //             (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
+
             // Inner loop processing the block from VTCM
             uint32_t ic = 0;
 
-            // Process in blocks of 32 (VLEN_FP32)
-            for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
+            // Process in sub-blocks of 32 (VLEN_FP32)
+            HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
+            HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
+            for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
                 // 1. Compute scores
-                float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
-                for (int j = 0; j < VLEN_FP32; ++j) {
-                    const uint32_t cur_ic = ic + j;
-                    const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
-                    if (q->type == HTP_TYPE_F32) {
-                        hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
-                    } else {
-                        hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
-                    }
-                }
-
-                HVX_Vector scores = *(HVX_Vector *) scores_arr;
+                HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);
 
                 // 2. Softcap
-                if (logit_softcap != 0.0f) {
-                    scores = hvx_vec_tanh_fp32(scores);
-                    scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
+                if (factx->logit_softcap != 0.0f) {
+                    scores = hvx_vec_tanh_f32(scores);
+                    scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
                     scores = Q6_Vsf_equals_Vqf32(scores);
                 }
 
                 // 3. Mask
                 if (mask) {
                     const __fp16 * mp = m_base + ic;
-                    HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
-
-                    HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
-                    HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
-
-                    HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
-
-                    HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
-                    HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
-                    scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
+                    HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
+                    HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
+                    HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
+                    scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
                     scores = Q6_Vsf_equals_Vqf32(scores);
                 }
 
-                // 4. Online Softmax Update
-                HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
-                float m_block = hvx_vec_get_fp32(v_max);
-
-                float M_old = M;
-                float M_new = (m_block > M) ? m_block : M;
-                M = M_new;
-
-                float ms = expf(M_old - M_new);
-
-                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
-                S = S * ms;
-
-                HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
-                HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
-                HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
-
-                HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
-                float p_sum = hvx_vec_get_fp32(p_sum_vec);
-                S += p_sum;
-
-                // 5. Accumulate V
-                float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
-                *(HVX_Vector*)p_arr = P;
-
-                for (int j = 0; j < VLEN_FP32; ++j) {
-                    const uint32_t cur_ic = ic + j;
-                    const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
-                    hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
-                }
+                sb_scores[iv] = scores;
+                v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
             }
 
-            // Leftover
-            for (; ic < current_block_size; ++ic) {
-                float s_val;
-                const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
+            {
+                // 4. Online Softmax Update
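+                // Streaming-softmax recurrence implemented below:
+                //   M_new = max(M_old, block_max)
+                //   VKQ  *= exp(M_old - M_new)
+                //   S     = S * exp(M_old - M_new) + sum_j exp(score_j - M_new)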
+                HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
+                HVX_Vector diff_vec  = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));
+                HVX_Vector ms_vec    = hvx_vec_exp_f32(diff_vec);
+                M_vec = M_new_vec;
 
-                if (q->type == HTP_TYPE_F32) {
-                    hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
-                } else {
-                    hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+                hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
+
+                HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
+                for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
+                    HVX_Vector scores = sb_scores[iv];
+                    HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
+                    HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
+
+                    p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
+
+                    // 5. Accumulate V
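+                    // VKQ32 += P[j] * V_row(ic2 + j), two V rows per call, with P pre-converted to f16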
+                    __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
+                    hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));
+
+                    for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
+                        const uint32_t  cur_ic = ic2 + j;
+                        const uint8_t * v_ptr  = v_base + cur_ic * factx->size_v_row_padded;
+                        hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
+                    }
                 }
 
-                if (logit_softcap != 0.0f) {
-                    s_val = logit_softcap * tanhf(s_val);
+                p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
+                S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
+            }
+
+            if (ic < current_block_size) {
+                // Sync scalars for leftover/next block if needed
+                float M = hvx_vec_get_f32(M_vec);
+                float S = hvx_vec_get_f32(S_vec);
+
+                // Leftover
+                for (; ic < current_block_size; ++ic) {
+                    float s_val;
+                    const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
+                    hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
+                    if (factx->logit_softcap != 0.0f) {
+                        s_val = factx->logit_softcap * tanhf(s_val);
+                    }
+
+                    if (mask) {
+                        const float m_val = m_base[ic];
+                        s_val += slope * m_val;
+                    }
+
+                    const float Mold = M;
+                    __fp16 vs = 1.0f;
+
+                    if (s_val > M) {
+                        M = s_val;
+                        HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
+                        HVX_Vector ms_vec   = hvx_vec_exp_f32(diff_vec);
+                        hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
+
+                        float ms = hvx_vec_get_f32(ms_vec);
+                        S = S * ms + vs;
+                    } else {
+                        HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
+                        vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
+                        S += vs;
+                    }
+
+                    const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
+
+                    hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);
                 }
 
-                if (mask) {
-                    const float m_val = m_base[ic];
-                    s_val += slope * m_val;
-                }
-
-                const float Mold = M;
-                float ms = 1.0f;
-                float vs = 1.0f;
-
-                if (s_val > M) {
-                    M = s_val;
-                    ms = expf(Mold - M);
-                    hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
-                } else {
-                    vs = expf(s_val - M);
-                }
-
-                const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
-
-                hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
-
-                S = S * ms + vs;
+                M_vec = hvx_vec_splat_f32(M);
+                S_vec = hvx_vec_splat_f32(S);
             }
 
             // Issue DMA for next+1 block (if exists)
-            if (ib + 2 < n_blocks) {
+            if (ib + 2 < factx->n_blocks) {
                 const uint32_t next_ib = ib + 2;
                 const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
                 const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
 
                 // K
                 const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
-                dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
+                dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
 
                 // V
                 const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
-                dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
+                dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
 
                 // Mask
                 if (mask) {
                     const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
                     dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
                 }
+
+                // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
+                //         ith, ir, next_ib, iq1, iq2, iq3,
+                //         size_k_row, size_v_row, next_block_size,
+                //         (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start));
             }
         }
 
         // sinks
+        float M = hvx_vec_get_f32(M_vec);
+        float S = hvx_vec_get_f32(S_vec);
+
         if (sinks) {
             const float s = ((float *)((char *) sinks->data))[h];
 
-            float ms = 1.0f;
             float vs = 1.0f;
 
             if (s > M) {
-                ms = expf(M - s);
-                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
-            } else {
-                vs = expf(s - M);
-            }
+                HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
+                HVX_Vector ms_vec   = hvx_vec_exp_f32(diff_vec);
+                hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
 
-            S = S * ms + vs;
+                float ms = hvx_vec_get_f32(ms_vec);
+                S = S * ms + vs;
+            } else {
+                HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
+                vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
+                S += vs;
+            }
         }
 
         const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
@@ -484,60 +600,91 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
         uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
 
         if (dst->type == HTP_TYPE_F32) {
-            hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+            hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
         } else if (dst->type == HTP_TYPE_F16) {
-            hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+            hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
         }
     }
 }
 
-static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-    flash_attn_ext_f16_thread(octx, i, n);
-}
-
 int op_flash_attn_ext(struct htp_ops_context * octx) {
     const struct htp_tensor * q = &octx->src0;
     const struct htp_tensor * k = &octx->src1;
     const struct htp_tensor * v = &octx->src2;
-    const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
-    struct htp_tensor * dst = &octx->dst;
+    const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
+    const struct htp_tensor * dst = &octx->dst;
 
     // Check support
-    if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
-        k->type != HTP_TYPE_F16 ||
-        v->type != HTP_TYPE_F16) {
+    if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
         return HTP_STATUS_NO_SUPPORT;
     }
 
-    octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
-    octx->src0_div1  = init_fastdiv_values(q->ne[1]);
+    struct htp_fa_context factx;
+    factx.octx = octx;
 
-    octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
-    octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
-    octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
-    octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
+    factx.t_start = HAP_perf_get_qtimer_count();
+
+    factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
+    factx.src0_div1  = init_fastdiv_values(q->ne[1]);
+
+    factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
+    factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
+    factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
+    factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
 
     if (mask) {
-        octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
-        octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
+        factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
+        factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
     }
 
-    size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
-    size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
-    size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
+    factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
+    factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
+    factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
+    factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
 
-    size_t size_q_block = size_q_row_padded * 1; // single row for now
-    size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
-    size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
-    size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+    size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
+    factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+    factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+    factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
 
-    size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
+    factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) octx->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) octx->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
+
+    factx.scale = scale;
+    factx.max_bias = max_bias;
+    factx.logit_softcap = logit_softcap;
+
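+    // ALiBi slope parameters: head h uses m0^(h+1) if h < n_head_log2, otherwise m1^(2*(h - n_head_log2) + 1)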
+    uint32_t n_head = q->ne[2];
+    factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+    factx.m0 = powf(2.0f, -(max_bias       ) / factx.n_head_log2);
+    factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
+
+    // total rows in q
+    const uint32_t neq0 = q->ne[0];
+    const uint32_t neq1 = q->ne[1];
+    const uint32_t neq2 = q->ne[2];
+    const uint32_t neq3 = q->ne[3];
+
+    factx.qrows = neq1*neq2*neq3;
+    factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads;
+
+    size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
 
     octx->src0_spad.size_per_thread = size_q_block * 1;
-    octx->src1_spad.size_per_thread = size_k_block * 2;
-    octx->src2_spad.size_per_thread = size_v_block * 2;
-    octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
+    octx->src1_spad.size_per_thread = factx.size_k_block * 2;
+    octx->src2_spad.size_per_thread = factx.size_v_block * 2;
+    octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
     octx->dst_spad.size_per_thread  = size_vkq_acc;
 
     octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -559,7 +706,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
     octx->dst_spad.data  = octx->src3_spad.data + octx->src3_spad.size;
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
+        worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
     }
 
     return HTP_STATUS_OK;
diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c
index 54321421..047d2850 100644
--- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c
+++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c
@@ -2,14 +2,9 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
 #include 
-#include 
 #include 
-#include 
-#include 
+
 #include 
 #include 
 
@@ -19,7 +14,13 @@
 #include "htp-msg.h"
 #include "htp-ops.h"
 #include "hvx-utils.h"
-#include "ops-utils.h"
+
+struct get_rows_context {
+    struct htp_ops_context * octx;
+    uint32_t src1_nrows_per_thread;
+    struct fastdiv_values get_rows_div_ne10;
+    struct fastdiv_values get_rows_div_ne10_ne11;
+};
 
 #define get_rows_preamble \
     const uint32_t ne00 = octx->src0.ne[0]; \
@@ -45,20 +46,22 @@
                                             \
     const uint32_t nr = ne10 * ne11 * ne12;
 
-static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
+    struct get_rows_context * grctx = (struct get_rows_context *)data;
+    struct htp_ops_context * octx = grctx->octx;
     get_rows_preamble;
 
     // parallelize by src1 elements (which correspond to dst rows)
-    const uint32_t dr  = octx->src1_nrows_per_thread;
+    const uint32_t dr  = grctx->src1_nrows_per_thread;
     const uint32_t ir0 = dr * ith;
     const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 
     const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
 
     for (uint32_t i = ir0; i < ir1; ++i) {
-        const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
+        const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
         const uint32_t rem = i - i12 * ne11 * ne10;
-        const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
+        const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
         const uint32_t i10 = rem - i11 * ne10;
 
         const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@@ -72,19 +75,15 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
 
         const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
         const uintptr_t dst_ptr  = octx->dst.data  + i10*nb1  + i11*nb2  + i12*nb3;
-        hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+        hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
     }
-
-    return HTP_STATUS_OK;
-}
-
-static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
-    get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
 }
 
 int op_get_rows(struct htp_ops_context * octx) {
     get_rows_preamble;
 
+    const uint32_t n_threads = MIN(nr, octx->n_threads);
+
     if (octx->src0.type != HTP_TYPE_F32) {
         return HTP_STATUS_NO_SUPPORT;
     }
@@ -101,12 +100,13 @@ int op_get_rows(struct htp_ops_context * octx) {
         return HTP_STATUS_OK;
     }
 
-    octx->get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
-    octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
+    struct get_rows_context grctx;
+    grctx.octx = octx;
+    grctx.get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
+    grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
 
-    const uint32_t n_jobs = MIN(nr, octx->n_threads);
-    octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+    grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;
 
-    worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
+    worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads);
     return HTP_STATUS_OK;
 }
diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.c b/ggml/src/ggml-hexagon/htp/hex-dma.c
similarity index 98%
rename from ggml/src/ggml-hexagon/htp/htp-dma.c
rename to ggml/src/ggml-hexagon/htp/hex-dma.c
index 880c4542..44e1be40 100644
--- a/ggml/src/ggml-hexagon/htp/htp-dma.c
+++ b/ggml/src/ggml-hexagon/htp/hex-dma.c
@@ -1,4 +1,4 @@
-#include "htp-dma.h"
+#include "hex-dma.h"
 
 #include 
 #include 
diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h
similarity index 84%
rename from ggml/src/ggml-hexagon/htp/htp-dma.h
rename to ggml/src/ggml-hexagon/htp/hex-dma.h
index 32fd06e7..350ab9d9 100644
--- a/ggml/src/ggml-hexagon/htp/htp-dma.h
+++ b/ggml/src/ggml-hexagon/htp/hex-dma.h
@@ -2,7 +2,6 @@
 #define HTP_DMA_H
 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -103,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q,
     dmlink(q->tail, desc);
     q->tail = desc;
 
-    // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
+    // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
     q->push_idx = (q->push_idx + 1) & q->idx_mask;
     return true;
 }
@@ -145,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
 
     dptr = q->dptr[q->pop_idx];
 
-    // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
+    // FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
     q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
     return dptr;
 }
 
+static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {
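+    // Non-blocking pop: returns a zero-initialized dma_ptr (dst == NULL) when the queue is empty.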
+    dma_ptr dptr  = { NULL };
+
+    if (q->push_idx == q->pop_idx) {
+        return dptr;
+    }
+
+    dptr = q->dptr[q->pop_idx];
+
+    // FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
+    q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
+    return dptr;
+}
+
+static inline bool dma_queue_empty(dma_queue * q) {
+    return q->push_idx == q->pop_idx;
+}
+
+static inline uint32_t dma_queue_depth(dma_queue * q) {
+    return (q->push_idx - q->pop_idx) & q->idx_mask;
+}
+
+static inline uint32_t dma_queue_capacity(dma_queue * q) {
+    return q->capacity;
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/ggml/src/ggml-hexagon/htp/hex-dump.h b/ggml/src/ggml-hexagon/htp/hex-dump.h
new file mode 100644
index 00000000..e3badb57
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hex-dump.h
@@ -0,0 +1,77 @@
+#ifndef HEX_DUMP_H
+#define HEX_DUMP_H
+
+#include 
+
+static inline void hex_dump_int8_line(char * pref, const int8_t * x, int n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%d, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%d, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%d, ", (int) x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f16_line(char * pref, const __fp16 * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f32_line(char * pref, const float * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n && p < p_end; i++) {
+        p += snprintf(p, p_end - p, "%.6f, ", x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
+static inline void hex_dump_f32(char * pref, const float * x, uint32_t n) {
+    uint32_t n0 = n / 16;
+    uint32_t n1 = n % 16;
+
+    uint32_t i = 0;
+    for (; i < n0; i++) {
+        hex_dump_f32_line(pref, x + (16 * i), 16);
+    }
+    if (n1) {
+        hex_dump_f32_line(pref, x + (16 * i), n1);
+    }
+}
+
+static inline void hex_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
+    uint32_t n0 = n / 16;
+    uint32_t n1 = n % 16;
+
+    uint32_t i = 0;
+    for (; i < n0; i++) {
+        hex_dump_f16_line(pref, x + (16 * i), 16);
+    }
+    if (n1) {
+        hex_dump_f16_line(pref, x + (16 * i), n1);
+    }
+}
+
+#endif /* HEX_DUMP_H */
diff --git a/ggml/src/ggml-hexagon/htp/hex-fastdiv.h b/ggml/src/ggml-hexagon/htp/hex-fastdiv.h
new file mode 100644
index 00000000..b7b58675
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hex-fastdiv.h
@@ -0,0 +1,37 @@
+#ifndef HEX_FASTDIV_H
+#define HEX_FASTDIV_H
+
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
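+//
+// e.g. d = 3: L = 2, mp = 0x55555556; n = 10 -> (mulhi(10, mp) + 10) >> 2 = (3 + 10) >> 2 = 3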
+struct fastdiv_values {
+    uint32_t mp;
+    uint32_t l;
+};
+
+static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
+    struct fastdiv_values result = { 0, 0 };
+    // compute L = ceil(log2(d));
+    while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
+        ++(result.l);
+    }
+
+    result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
+    return result;
+}
+
+static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
+    // Compute high 32 bits of n * mp
+    const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32);  // mulhi(n, mp)
+    // add n, apply bit shift
+    return (hi + n) >> vals->l;
+}
+
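+// n mod d, reusing the fastdiv values precomputed for the same divisor d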
+static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
+    return n - fastdiv(n, vals) * d;
+}
+
+#endif /* HEX_FASTDIV_H */
diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h
new file mode 100644
index 00000000..fb8a25a3
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hex-utils.h
@@ -0,0 +1,51 @@
+#ifndef HEX_UTILS_H
+#define HEX_UTILS_H
+
+#include 
+#include 
+
+#include "hexagon_types.h"
+
+#include "hex-fastdiv.h"
+#include "hex-dump.h"
+
+#ifndef MAX
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#endif
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+static inline uint64_t hex_get_cycles() {
+    uint64_t cycles = 0;
+    asm volatile(" %0 = c15:14\n" : "=r"(cycles));
+    return cycles;
+}
+
+static inline uint64_t hex_get_pktcnt() {
+    uint64_t pktcnt;
+    asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
+    return pktcnt;
+}
+
+static inline int32_t hex_is_aligned(void * addr, uint32_t align) {
+    return ((size_t) addr & (align - 1)) == 0;
+}
+
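+// true if [addr, addr + n) does not cross a chunk_size boundary (chunk_size must be a power of two)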
+static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
+    uint32_t left_off  = (size_t) addr & (chunk_size - 1);
+    uint32_t right_off = left_off + n;
+    return right_off <= chunk_size;
+}
+
+static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
+    return m * ((n + m - 1) / m);
+}
+
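+// L2 prefetch of a 2D block: 'width' bytes per row, 'height' rows, 'stride' bytes between row starts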
+static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
+    const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
+    Q6_l2fetch_AP((void *) p, control);
+}
+
+#endif /* HEX_UTILS_H */
diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h
index 4bd0ea7a..a707d982 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ctx.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h
@@ -1,7 +1,7 @@
 #ifndef HTP_CTX_H
 #define HTP_CTX_H
 
-#include "htp-dma.h"
+#include "hex-dma.h"
 #include "worker-pool.h"
 
 #include 
diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h
index 846d0617..52dcc36d 100644
--- a/ggml/src/ggml-hexagon/htp/htp-msg.h
+++ b/ggml/src/ggml-hexagon/htp/htp-msg.h
@@ -42,31 +42,37 @@ enum htp_data_type {
     HTP_TYPE_COUNT
 };
 
-// These values are manually translated over to HTP
-// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
+// Do not reorder the first 4 entries (they are used as an index)
 enum htp_op {
-    HTP_OP_MUL            = 0,
-    HTP_OP_ADD            = 1,
-    HTP_OP_SUB            = 2,
-    HTP_OP_DIV            = 3,
-    HTP_OP_MUL_MAT        = 4,
-    HTP_OP_MUL_MAT_ID     = 5,
-    HTP_OP_RMS_NORM       = 6,
-    HTP_OP_UNARY_SILU     = 7,
-    HTP_OP_UNARY_GELU     = 8,
-    HTP_OP_GLU_SWIGLU     = 9,
-    HTP_OP_GLU_SWIGLU_OAI = 10,
-    HTP_OP_SOFTMAX        = 11,
-    HTP_OP_ADD_ID         = 12,
-    HTP_OP_ROPE           = 13,
-    HTP_OP_FLASH_ATTN_EXT = 14,
-    HTP_OP_SET_ROWS       = 15,
-    HTP_OP_SCALE          = 16,
-    HTP_OP_GET_ROWS       = 17,
+    HTP_OP_MUL = 0,
+    HTP_OP_ADD = 1,
+    HTP_OP_SUB = 2,
+    HTP_OP_DIV = 3,
+    HTP_OP_MUL_MAT,
+    HTP_OP_MUL_MAT_ID,
+    HTP_OP_RMS_NORM,
+    HTP_OP_UNARY_SILU,
+    HTP_OP_UNARY_GELU,
+    HTP_OP_GLU_SWIGLU,
+    HTP_OP_GLU_SWIGLU_OAI,
+    HTP_OP_GLU_GEGLU,
+    HTP_OP_SOFTMAX,
+    HTP_OP_ADD_ID,
+    HTP_OP_ROPE,
+    HTP_OP_FLASH_ATTN_EXT,
+    HTP_OP_SET_ROWS,
+    HTP_OP_GET_ROWS,
+    HTP_OP_SCALE,
+    HTP_OP_CPY,
+    HTP_OP_ARGSORT,
+    HTP_OP_SQR,
+    HTP_OP_SQRT,
+    HTP_OP_SUM_ROWS,
+    HTP_OP_SSM_CONV,
     INVALID
 };
 
-static inline size_t htp_type_block_size(uint32_t t) {
+static inline size_t htp_t_block_size(uint32_t t) {
     switch (t) {
         case HTP_TYPE_F32:
             return 1;
@@ -102,22 +108,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
     return 0;
 }
 
-static const char * htp_type_name(uint32_t t) {
-    switch (t) {
-        case HTP_TYPE_F32:
-            return "fp32";
-        case HTP_TYPE_F16:
-            return "fp16";
-        case HTP_TYPE_Q4_0:
-            return "q4_0";
-        case HTP_TYPE_Q8_0:
-            return "q8_0";
-        case HTP_TYPE_MXFP4:
-            return "mxfp4";
-    }
-    return 0;
-}
-
 // Internal types
 #define QK_Q4_0x4x2  256  // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
 #define QK_Q8_0x4x2  256  // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h
index 7c828ae6..2ef20936 100644
--- a/ggml/src/ggml-hexagon/htp/htp-ops.h
+++ b/ggml/src/ggml-hexagon/htp/htp-ops.h
@@ -4,11 +4,12 @@
 #include "htp-ctx.h"
 #include "htp-msg.h"
 #include "worker-pool.h"
-#include "ops-utils.h"
 
 #include 
 #include 
 
+#include 
+
 // ggml-common.h must be included prior to this header
 
 struct htp_spad {
@@ -40,40 +41,6 @@ struct htp_ops_context {
     worker_pool_context_t * wpool;      // worker pool
     uint32_t                n_threads;  // num threads
 
-    uint32_t src0_nrows_per_thread;
-    uint32_t src1_nrows_per_thread;
-
-    struct fastdiv_values src0_div1;  // fastdiv values for ne1
-    struct fastdiv_values src0_div2;  // fastdiv values for ne2
-    struct fastdiv_values src0_div3;  // fastdiv values for ne3
-    struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
-
-    struct fastdiv_values src1_div1;  // fastdiv values for ne1
-    struct fastdiv_values src1_div2;  // fastdiv values for ne2
-    struct fastdiv_values src1_div3;  // fastdiv values for ne3
-    struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
-
-    struct fastdiv_values src3_div1;  // fastdiv values for ne1
-    struct fastdiv_values src3_div2;  // fastdiv values for ne2
-    struct fastdiv_values src3_div3;  // fastdiv values for ne3
-    struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
-
-    struct fastdiv_values broadcast_rk2;
-    struct fastdiv_values broadcast_rk3;
-    struct fastdiv_values broadcast_rv2;
-    struct fastdiv_values broadcast_rv3;
-
-    struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
-    struct fastdiv_values mm_div_ne1;      // fastdiv values for ne1
-    struct fastdiv_values mm_div_r2;       // fastdiv values for ne12 / ne02
-    struct fastdiv_values mm_div_r3;       // fastdiv values for ne13 / ne03
-
-    struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
-    struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
-
-    struct fastdiv_values get_rows_div_ne10;      // fastdiv values for ne10
-    struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
-
     uint32_t flags;
 };
 
@@ -81,6 +48,7 @@ int op_matmul(struct htp_ops_context * octx);
 int op_matmul_id(struct htp_ops_context * octx);
 int op_binary(struct htp_ops_context * octx);
 int op_unary(struct htp_ops_context * octx);
+int op_sum_rows(struct htp_ops_context * octx);
 int op_activations(struct htp_ops_context * octx);
 int op_softmax(struct htp_ops_context * octx);
 int op_add_id(struct htp_ops_context * octx);
@@ -88,5 +56,8 @@ int op_rope(struct htp_ops_context * octx);
 int op_flash_attn_ext(struct htp_ops_context * octx);
 int op_set_rows(struct htp_ops_context * octx);
 int op_get_rows(struct htp_ops_context * octx);
+int op_cpy(struct htp_ops_context * octx);
+int op_argsort(struct htp_ops_context * octx);
+int op_ssm_conv(struct htp_ops_context * octx);
 
 #endif /* HTP_OPS_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h
new file mode 100644
index 00000000..82e34169
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h
@@ -0,0 +1,443 @@
+#ifndef HVX_ARITH_H
+#define HVX_ARITH_H
+
+#include 
+#include 
+#include 
+#include 
+
+#include "hvx-base.h"
+#include "hex-utils.h"
+
+//
+// Binary operations (add, mul, sub)
+//
+
+#define UNUSED(x) (void)(x)
+
+#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \
+    do {                                                                       \
+        dst_type * restrict vdst  = (dst_type *) dst;                          \
+        src0_type * restrict vsrc0 = (src0_type *) src0;                       \
+        src1_type * restrict vsrc1 = (src1_type *) src1;                       \
+                                                                               \
+        const uint32_t epv  = 128 / (elem_size);                               \
+        const uint32_t nvec = n / epv;                                         \
+        const uint32_t nloe = n % epv;                                         \
+                                                                               \
+        uint32_t i = 0;                                                        \
+                                                                               \
+        _Pragma("unroll(4)")                                                   \
+        for (; i < nvec; i++) {                                                \
+            vdst[i] = vec_op(vsrc0[i], vsrc1[i]);                              \
+        }                                                                      \
+        if (nloe) {                                                            \
+            HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]);                         \
+            vec_store((void *) &vdst[i], nloe * (elem_size), v);               \
+        }                                                                      \
+    } while(0)
+
+#if __HVX_ARCH__ < 79
+
+#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
+#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+
+#else
+
+#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
+#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+
+#endif
+
+#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b)
+#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b)
+#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b)
+
+// Generic macro to define alignment permutations for an op
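+// (suffix letters give the alignment of dst/src0/src1 in order: 'a' = 128-byte aligned, 'u' = unaligned)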
+#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \
+static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src0 % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src0 % 128 == 0); \
+    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src0 % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src0 % 128 == 0); \
+    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src1 % 128 == 0); \
+    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
+
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float)
+
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16)
+
+// Dispatchers: pick the aligned/unaligned variant based on the 128-byte alignment of dst, src0 and src1
+#define HVX_BINARY_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
+    if (hex_is_aligned((void *) dst, 128)) { \
+        if (hex_is_aligned((void *) src0, 128)) { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_aau(dst, src0, src1, num_elems); \
+        } else { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_auu(dst, src0, src1, num_elems); \
+        } \
+    } else { \
+        if (hex_is_aligned((void *) src0, 128)) { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_uau(dst, src0, src1, num_elems); \
+        } else { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_uuu(dst, src0, src1, num_elems); \
+        } \
+    } \
+}
+
+HVX_BINARY_DISPATCHER(hvx_add_f32)
+HVX_BINARY_DISPATCHER(hvx_sub_f32)
+HVX_BINARY_DISPATCHER(hvx_mul_f32)
+
+HVX_BINARY_DISPATCHER(hvx_add_f16)
+HVX_BINARY_DISPATCHER(hvx_sub_f16)
+HVX_BINARY_DISPATCHER(hvx_mul_f16)
+
+// Mul-Mul Optimized
+static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src0 % 128 == 0);
+    assert((unsigned long) src1 % 128 == 0);
+    assert((unsigned long) src2 % 128 == 0);
+
+    HVX_Vector * restrict vdst  = (HVX_Vector *) dst;
+    HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0;
+    HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1;
+    HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2;
+
+    const uint32_t elem_size = sizeof(float);
+    const uint32_t epv  = 128 / elem_size;
+    const uint32_t nvec = num_elems / epv;
+    const uint32_t nloe = num_elems % epv;
+
+    uint32_t i = 0;
+
+    _Pragma("unroll(4)")
+    for (; i < nvec; i++) {
+        HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
+        vdst[i] = HVX_OP_MUL_F32(v1, vsrc2[i]);
+    }
+
+    if (nloe) {
+        HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
+        HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]);
+        hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
+    }
+}
+
+// Scalar Operations
+
+#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro)   \
+    do {                                                                       \
+        dst_type * restrict vdst = (dst_type *) dst;                           \
+        src_type * restrict vsrc = (src_type *) src;                           \
+                                                                               \
+        const uint32_t epv  = 128 / (elem_size);                               \
+        const uint32_t nvec = n / epv;                                         \
+        const uint32_t nloe = n % epv;                                         \
+                                                                               \
+        uint32_t i = 0;                                                        \
+                                                                               \
+        _Pragma("unroll(4)")                                                   \
+        for (; i < nvec; i++) {                                                \
+            HVX_Vector v = vsrc[i];                                            \
+            vdst[i] = scalar_op_macro(v);                                      \
+        }                                                                      \
+        if (nloe) {                                                            \
+            HVX_Vector v = vsrc[i];                                            \
+            v = scalar_op_macro(v);                                            \
+            vec_store((void *) &vdst[i], nloe * (elem_size), v);               \
+        }                                                                      \
+    } while(0)
+
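+// scalar add that leaves +inf lanes untouched (the qf32 add path may not propagate IEEE infinities)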
+#define HVX_OP_ADD_SCALAR_F32(v) \
+    ({ \
+        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
+        HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \
+        Q6_V_vmux_QVV(pred_inf, inf, out); \
+    })
+
+#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec)
+#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec)
+
+#define HVX_OP_ADD_SCALAR_F16(v) \
+    ({ \
+        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \
+        HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \
+        Q6_V_vmux_QVV(pred_inf, inf, out); \
+    })
+
+#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec)
+#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec)
+
+// Scalar Variants
+
+// Generic macro to define alignment permutations for an op
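+// (suffix letters give the alignment of dst/src in order: 'a' = 128-byte aligned, 'u' = unaligned)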
+#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \
+static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+    const HVX_Vector val_vec = SPLAT_MACRO(val); \
+    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src % 128 == 0); \
+    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+    const HVX_Vector val_vec = SPLAT_MACRO(val); \
+    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+    assert((uintptr_t) dst % 128 == 0); \
+    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+    const HVX_Vector val_vec = SPLAT_MACRO(val); \
+    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+    assert((uintptr_t) src % 128 == 0); \
+    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+    const HVX_Vector val_vec = SPLAT_MACRO(val); \
+    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
+
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float)
+
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16)
+
+// Dispatchers: pick the aligned/unaligned variant based on the 128-byte alignment of dst and src
+#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \
+    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
+        OP_NAME##_aa(dst, src, val, num_elems); \
+    } else if (hex_is_aligned((void *) dst, 128)) { \
+        OP_NAME##_au(dst, src, val, num_elems); \
+    } else if (hex_is_aligned((void *) src, 128)) { \
+        OP_NAME##_ua(dst, src, val, num_elems); \
+    } else { \
+        OP_NAME##_uu(dst, src, val, num_elems); \
+    } \
+}
+
+HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float)
+
+HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16)
+
+// MIN Scalar variants
+
+#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v)
+
+static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
+}
+
+static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+    assert((unsigned long) dst % 128 == 0);
+    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
+}
+
+static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+    assert((unsigned long) src % 128 == 0);
+    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
+}
+
+static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
+    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
+    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
+}
+
+static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
+    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
+        hvx_min_scalar_f32_aa(dst, src, val, num_elems);
+    } else if (hex_is_aligned((void *) dst, 128)) {
+        hvx_min_scalar_f32_au(dst, src, val, num_elems);
+    } else if (hex_is_aligned((void *) src, 128)) {
+        hvx_min_scalar_f32_ua(dst, src, val, num_elems);
+    } else {
+        hvx_min_scalar_f32_uu(dst, src, val, num_elems);
+    }
+}
+
+// CLAMP Scalar variants
+
+#define HVX_OP_CLAMP_SCALAR(v) \
+    ({ \
+        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \
+        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \
+        HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \
+        Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \
+    })
+
+static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
+    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
+    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
+}
+
+static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
+    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
+    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
+    assert((unsigned long) dst % 128 == 0);
+    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
+}
+
+static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
+    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
+    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
+    assert((unsigned long) src % 128 == 0);
+    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
+}
+
+static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
+    const HVX_Vector min_vec = hvx_vec_splat_f32(min);
+    const HVX_Vector max_vec = hvx_vec_splat_f32(max);
+    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
+}
+
+static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
+    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
+        hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems);
+    } else if (hex_is_aligned((void *) dst, 128)) {
+        hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems);
+    } else if (hex_is_aligned((void *) src, 128)) {
+        hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems);
+    } else {
+        hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems);
+    }
+}
+
+//
+// Square
+//
+
+#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store)           \
+    do {                                                                   \
+        dst_type * restrict vdst  = (dst_type *) dst;                      \
+        src_type * restrict vsrc = (src_type *) src;                       \
+                                                                           \
+        const uint32_t elem_size = sizeof(float);                          \
+        const uint32_t epv  = 128 / elem_size;                             \
+        const uint32_t nvec = n / epv;                                     \
+        const uint32_t nloe = n % epv;                                     \
+                                                                           \
+        uint32_t i = 0;                                                    \
+                                                                           \
+        _Pragma("unroll(4)")                                               \
+        for (; i < nvec; i++) {                                            \
+            vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]);                        \
+        }                                                                  \
+        if (nloe) {                                                        \
+            HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]);                   \
+            vec_store((void *) &vdst[i], nloe * elem_size, v);             \
+        }                                                                  \
+    } while(0)
+
+static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
+    if (hex_is_aligned((void *) dst, 128)) {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_aa(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_au(dst, src, num_elems);
+        }
+    } else {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_ua(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_uu(dst, src, num_elems);
+        }
+    }
+}
+
+#undef HVX_OP_ADD_F32
+#undef HVX_OP_SUB_F32
+#undef HVX_OP_MUL_F32
+#undef HVX_OP_ADD_F16
+#undef HVX_OP_SUB_F16
+#undef HVX_OP_MUL_F16
+#undef hvx_arith_loop_body
+#undef HVX_OP_ADD_SCALAR_F32
+#undef HVX_OP_SUB_SCALAR_F32
+#undef HVX_OP_MUL_SCALAR_F32
+#undef HVX_OP_ADD_SCALAR_F16
+#undef HVX_OP_SUB_SCALAR_F16
+#undef HVX_OP_MUL_SCALAR_F16
+#undef hvx_scalar_loop_body
+#undef HVX_OP_MIN_SCALAR
+#undef HVX_OP_CLAMP_SCALAR
+#undef DEFINE_HVX_BINARY_OP_VARIANTS
+#undef DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS
+#undef HVX_BINARY_DISPATCHER
+#undef HVX_BINARY_SCALAR_DISPATCHER
+#undef hvx_sqr_f32_loop_body
+#undef UNUSED
+
+#endif // HVX_ARITH_H
diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h
new file mode 100644
index 00000000..578ca288
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-base.h
@@ -0,0 +1,240 @@
+#ifndef HVX_BASE_H
+#define HVX_BASE_H
+
+#include 
+#include 
+
+#include "hex-utils.h"
+#include "hvx-types.h"
+
+static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) {
+    // Rotate as needed.
+    v = Q6_V_vlalign_VVR(v, v, (size_t) dst);
+
+    uint32_t left_off  = (size_t) dst & 127;
+    uint32_t right_off = left_off + n;
+
+    HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst);
+    HVX_VectorPred qr     = Q6_Q_vsetq2_R(right_off);
+
+    if (right_off > 128) {
+        Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v);
+        // all 1's
+        qr = Q6_Q_vcmp_eq_VbVb(v, v);
+    }
+
+    ql_not = Q6_Q_or_QQn(ql_not, qr);
+    Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v);
+}
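+
+// Worked example for the masked store above (values assumed for illustration):
+// with dst = base + 16 and n = 40 bytes, left_off = 16 and right_off = 56, so only
+// bytes 16..55 of the containing vector are written. When right_off > 128 the tail
+// spills into the next vector: that vector is written first under the qr predicate,
+// and qr is then forced to all-ones so the first vector is masked only on the left.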
+
+static inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) {
+    assert((unsigned long) dst % 128 == 0);
+    HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n));
+    Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v);
+}
+
+static inline HVX_Vector hvx_vec_splat_f32(float v) {
+    union { float  f; uint32_t i; } u = { .f = v };
+    return Q6_V_vsplat_R(u.i);
+}
+
+static inline HVX_Vector hvx_vec_splat_f16(_Float16 v) {
+    union { __fp16 f; uint16_t i; } u = { .f = v };
+    return Q6_Vh_vsplat_R(u.i);
+}
+
+static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
+    // vdelta control to replicate first 4 bytes across all elements
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+    };
+
+    HVX_Vector ctrl = *(HVX_Vector *) repl;
+    return Q6_V_vdelta_VV(v, ctrl);
+}
+
+static inline float hvx_vec_get_f32(HVX_Vector v) {
+    float __attribute__((aligned(128))) x;
+    hvx_vec_store_a(&x, 4, v);
+    return x;
+}
+
+static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
+    int32_t __attribute__((aligned(128))) x;
+    hvx_vec_store_a(&x, 4, v);
+    return x;
+}
+
+static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
+    // abs by clearing the fp16 sign bit
+    HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
+    return Q6_V_vand_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) {
+    // neg by setting the fp16 sign bit
+    HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
+    return Q6_V_vxor_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) {
+    // abs by clearing the fp32 sign bit
+    HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);
+    return Q6_V_vand_VV(v, mask);
+}
+
+static inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) {
+#if __HVX_ARCH__ > 75
+    return Q6_Vsf_vfneg_Vsf(v);
+#else
+    // neg by setting the fp32 sign bit
+    HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
+    return Q6_V_vxor_VV(v, mask);
+#endif  // __HVX_ARCH__ > 75
+}
+
+static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
+    const HVX_Vector vnan_exp  = Q6_Vh_vsplat_R(0x7C00);
+    const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF);
+
+    // predicate of NaN lanes, i.e. exponent bits all ones and fraction bits non-zero
+    HVX_VectorPred p_exp  = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp);
+    HVX_VectorPred p_frac = Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_frac), vnan_exp));
+    return Q6_Q_and_QQ(p_exp, p_frac);
+}
+
+static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
+    const HVX_Vector zero = Q6_V_vsplat_R(0);
+    HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
+    HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
+    HVX_Vector  v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)));
+
+#if __HVX_ARCH__ < 79
+    // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
+    const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
+    HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
+    v = Q6_V_vmux_QVV(nan, neg_inf, v);
+#endif
+
+    return v;
+}
+
+/* Q6_Vsf_equals_Vw is only available on v73+. */
+#if __HVX_ARCH__ < 73
+static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
+{
+    HVX_Vector const vzero = Q6_V_vzero();
+    HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
+    HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
+    HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
+    HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
+    HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
+    HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
+    return ret;
+}
+
+static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
+{
+    return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
+}
+#endif
+
+static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
+    // This looks complicated.
+    // Ideally this would just be Q6_Vh_equals_Vhf(vin),
+    // but that instruction does not do proper rounding.
+
+    // convert to qf32, multiplying by 1.0 in the process.
+    HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));
+
+    // 'in-range' values are +/-32752.
+    // add 192K to it, convert to sf
+    HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
+    HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
+    HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));
+
+    // for in-range cases, the result is in {163856 ... 229360}, so the exponent is always 144.
+    // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
+    // Start by <<10 to get the final 'sign' bit in bit 15...
+    vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
+    vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
+
+    // now round down to 16
+    return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
+}
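+
+// Scalar reference for the rounding trick above (a sketch for in-range inputs only,
+// not used by the HVX path): adding 192K = 1.5 * 2^17 pins the fp32 exponent, so the
+// rounded integer can be read straight out of the mantissa bits.
+//
+//   static inline int16_t i16_from_f16_rnd_ref(float x /* |x| <= 32752 */) {
+//       float biased = x + 196608.0f;                 // exponent is now always 144
+//       uint32_t bits; memcpy(&bits, &biased, sizeof(bits));
+//       int32_t t = (int32_t) (bits << 10);           // bit 21 becomes the sign bit
+//       return (int16_t) ((t + (1 << 15)) >> 16);     // round off the 6 fraction bits
+//   }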
+
+#if __HVX_ARCH__ < 79
+
+static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y)
+{
+    HVX_VectorPair m = Q6_Wqf32_vmpy_VhfVhf(x, y);
+    HVX_Vector a0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(m), Q6_V_lo_W(acc)));
+    HVX_Vector a1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(m), Q6_V_hi_W(acc)));
+    return Q6_W_vcombine_VV(a1, a0);
+}
+
+#else
+
+static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y)
+{
+    return Q6_Wsf_vmpyacc_WsfVhfVhf(acc, x, y);
+}
+
+#endif
+
+#if __HVX_ARCH__ < 79
+
+static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
+    const HVX_Vector one    = Q6_Vh_vsplat_R(0x3C00); //  1.0 in IEEE FP16
+    HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
+    HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
+    HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
+    HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
+    return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
+}
+
+static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
+    const HVX_Vector one    = Q6_Vh_vsplat_R(0x3C00); //  1.0 in IEEE FP16
+    HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
+    HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
+    HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
+    HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
+    return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
+}
+
+static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
+}
+
+#else
+
+static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    return Q6_Vhf_vadd_VhfVhf(a, b);
+}
+
+static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    return Q6_Vhf_vsub_VhfVhf(a, b);
+}
+
+static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    return Q6_Vhf_vmpy_VhfVhf(a, b);
+}
+
+#endif // __HVX_ARCH__ < 79
+
+#endif /* HVX_BASE_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h
new file mode 100644
index 00000000..851482e0
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h
@@ -0,0 +1,245 @@
+#ifndef HVX_COPY_H
+#define HVX_COPY_H
+
+#include 
+#include 
+#include 
+
+#include "hvx-base.h"
+
+#define hvx_splat_loop_body(dst_type, vec_store)                 \
+    do {                                                         \
+        dst_type * restrict vdst = (dst_type *) dst;             \
+                                                                 \
+        uint32_t nvec = n / (128 / elem_size);                   \
+        uint32_t nloe = n % (128 / elem_size);                   \
+                                                                 \
+        uint32_t i = 0;                                          \
+                                                                 \
+        _Pragma("unroll(4)")                                     \
+        for (; i < nvec; i++) {                                  \
+            vdst[i] = src;                                       \
+        }                                                        \
+        if (nloe) {                                              \
+            vec_store((void *) &vdst[i], nloe * elem_size, src); \
+        }                                                        \
+    } while(0)
+
+static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) {
+    hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) {
+    hvx_splat_a(dst,  hvx_vec_splat_f32(v), n, sizeof(float));
+}
+
+static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) {
+    hvx_splat_u(dst,  hvx_vec_splat_f32(v), n, sizeof(float));
+}
+
+static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) {
+    hvx_splat_a(dst,  hvx_vec_splat_f16(v), n, sizeof(__fp16));
+}
+
+static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) {
+    hvx_splat_u(dst,  hvx_vec_splat_f16(v), n, sizeof(__fp16));
+}
+
+#define hvx_copy_loop_body(dst_type, src_type, vec_store)            \
+    do {                                                             \
+        dst_type * restrict vdst = (dst_type *) dst;                 \
+        src_type * restrict vsrc = (src_type *) src;                 \
+                                                                     \
+        const uint32_t epv  = 128 / elem_size;                       \
+        const uint32_t nvec = n / epv;                               \
+        const uint32_t nloe = n % epv;                               \
+                                                                     \
+        uint32_t i = 0;                                              \
+                                                                     \
+        _Pragma("unroll(4)")                                         \
+        for (; i < nvec; i++) { vdst[i] = vsrc[i]; }                 \
+        if (nloe) {                                                  \
+            vec_store((void *) &vdst[i], nloe * elem_size, vsrc[i]); \
+        }                                                            \
+    } while(0)
+
+// Generic copy routines
+static inline void hvx_copy_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_copy_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_copy_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_copy_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_copy_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) {
+    hvx_copy_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
+static inline void hvx_copy_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_aa(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source is potentially unaligned, destination is aligned
+static inline void hvx_copy_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_au(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source is aligned, destination is potentially unaligned
+static inline void hvx_copy_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_ua(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp16 elements : source and destination are both potentially unaligned
+static inline void hvx_copy_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_uu(dst, src, n, sizeof(__fp16));
+}
+
+// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
+static inline void hvx_copy_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_aa(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_ua(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_au(dst, src, n, sizeof(float));
+}
+
+// copy n fp32 elements : source is unaligned, destination unaligned
+static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_uu(dst, src, n, sizeof(float));
+}
+
+//// fp32 -> fp16
+
+#define hvx_copy_f16_f32_loop_body(dst_type, src_type, vec_store)                   \
+    do {                                                                            \
+        dst_type * restrict vdst = (dst_type *) dst;                                \
+        src_type * restrict vsrc = (src_type *) src;                                \
+                                                                                    \
+        const uint32_t elem_size = sizeof(__fp16);                                  \
+        const uint32_t epv  = 128 / elem_size;                                      \
+        const uint32_t nvec = n / epv;                                              \
+        const uint32_t nloe = n % epv;                                              \
+                                                                                    \
+        uint32_t i = 0;                                                             \
+                                                                                    \
+        _Pragma("unroll(4)")                                                        \
+        for (; i < nvec; i++) {                                                     \
+            vdst[i] = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]);                 \
+        }                                                                           \
+        if (nloe) {                                                                 \
+            HVX_Vector v = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]);            \
+            vec_store((void *) &vdst[i], nloe * elem_size, v);                      \
+        }                                                                           \
+    } while(0)
+
+// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is aligned
+static inline void hvx_copy_f16_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f16_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f16_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
+static inline void hvx_copy_f16_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
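+
+// Note on the indexing above: one fp16 output vector holds 64 lanes, so each loop
+// iteration consumes two consecutive fp32 input vectors (vsrc[i*2+0], vsrc[i*2+1]);
+// semantically the routines compute dst[i] = (_Float16) src[i] for i in [0, n).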
+
+//// fp16 -> fp32
+
+#define hvx_copy_f32_f16_loop_body(dst_type, src_type, vec_store)                   \
+    do {                                                                            \
+        dst_type * restrict vdst = (dst_type *) dst;                                \
+        src_type * restrict vsrc = (src_type *) src;                                \
+                                                                                    \
+        const HVX_Vector one = hvx_vec_splat_f16(1.0);                              \
+                                                                                    \
+        const uint32_t elem_size = sizeof(__fp16);                                  \
+        const uint32_t epv  = 128 / elem_size;                                      \
+        const uint32_t nvec = n / epv;                                              \
+              uint32_t nloe = n % epv;                                              \
+                                                                                    \
+        uint32_t i = 0;                                                             \
+                                                                                    \
+        _Pragma("unroll(4)")                                                        \
+        for (i = 0; i < nvec; ++i) {                                                \
+            HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \
+            vdst[i*2]   = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p));                        \
+            vdst[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p));                        \
+        }                                                                           \
+                                                                                    \
+        if (nloe) {                                                                 \
+            HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \
+                                                                                    \
+            HVX_Vector vd = Q6_V_lo_W(p);                                           \
+            i = 2 * i;                                                              \
+                                                                                    \
+            if (nloe >= 32) {                                                       \
+                vdst[i] = Q6_Vsf_equals_Vqf32(vd);                                  \
+                nloe -= 32; ++i; vd = Q6_V_hi_W(p);                                 \
+            }                                                                       \
+                                                                                    \
+            if (nloe) {                                                             \
+                vd = Q6_Vsf_equals_Vqf32(vd);                                       \
+                hvx_vec_store_u(&vdst[i], nloe * sizeof(float), vd);                \
+            }                                                                       \
+        }                                                                           \
+    } while(0)
+
+// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is aligned
+static inline void hvx_copy_f32_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_f32_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_f32_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is unaligned
+static inline void hvx_copy_f32_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+#endif // HVX_COPY_H
diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h
new file mode 100644
index 00000000..05cefea0
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-div.h
@@ -0,0 +1,251 @@
+#ifndef HVX_DIV_H
+#define HVX_DIV_H
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "hvx-base.h"
+#include "hex-utils.h"
+#include "hvx-inverse.h"
+#include "hvx-arith.h"
+
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
+// Compute division of fp16 elements by a scalar: widen the fp16 inputs to fp32, multiply by the precomputed fp32 reciprocal, then convert the result back to fp16.
+static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) {
+#if __HVX_ARCH__ < 79
+    HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
+    HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32));
+    HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32));
+#else
+    HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
+    HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32);
+    HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32);
+#endif
+
+    HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const);
+    HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const);
+
+#if __HVX_ARCH__ < 79
+    HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
+#else
+    HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
+#endif
+    return res;
+}
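+
+// Per-lane semantics of the helper above (scalar sketch, illustrative only):
+//
+//   dst[i] = (_Float16) ((float) src[i] * recip);   // recip = 1.0f / (float) val
+//
+// i.e. scalar division is implemented as a widening multiply by a precomputed
+// fp32 reciprocal, which the callers below splat once per call.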
+
+#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store)                     \
+    do {                                                                                \
+        dst_type * restrict vdst = (dst_type *) dst;                                    \
+        src_type * restrict vsrc = (src_type *) src;                                    \
+        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                     \
+                                                                                        \
+        const uint32_t nvec = n / VLEN_FP16;                                            \
+        const uint32_t nloe = n % VLEN_FP16;                                            \
+                                                                                        \
+        uint32_t i = 0;                                                                 \
+                                                                                        \
+        _Pragma("unroll(4)")                                                            \
+        for (; i < nvec; i++) {                                                         \
+            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+            vdst[i] = res;                                                              \
+        }                                                                               \
+        if (nloe) {                                                                     \
+            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                      \
+        }                                                                               \
+    } while(0)
+
+static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    assert((uintptr_t) dst % 128 == 0);
+    assert((uintptr_t) src % 128 == 0);
+    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    assert((uintptr_t) dst % 128 == 0);
+    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    assert((uintptr_t) src % 128 == 0);
+    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+// Compute fp16 division via hvx_vec_inverse_f32_guard: widen both inputs to fp32, multiply the first by the guarded reciprocal of the second, then convert the result back to fp16.
+static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
+#if __HVX_ARCH__ < 79
+    // Convert first input to fp32
+    HVX_VectorPair vec1_to_f32   = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0);  // *1.0
+    HVX_Vector     vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32));
+    HVX_Vector     vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32));
+
+    // Convert second input to fp32
+    HVX_VectorPair vec2_to_f32   = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0);  // *1.0
+    HVX_Vector     vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32));
+    HVX_Vector     vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32));
+#else
+    // Convert first input to fp32
+    HVX_VectorPair vec1_to_f32   = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0);  // *1.0
+    HVX_Vector     vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32);
+    HVX_Vector     vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32);
+
+    // Convert second input to fp32
+    HVX_VectorPair vec2_to_f32   = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0);  // *1.0
+    HVX_Vector     vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32);
+    HVX_Vector     vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32);
+#endif
+
+    // Inverse second input in fp32
+    HVX_Vector     vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask);
+    HVX_Vector     vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask);
+
+    // Multiply first input by inverse of second, in fp32
+    HVX_Vector     div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0);
+    HVX_Vector     div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1);
+
+    // Convert back to fp16
+#if __HVX_ARCH__ < 79
+    HVX_Vector     recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
+#else
+    HVX_Vector     recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
+#endif
+
+    return recip;
+}
+
+#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store)                  \
+    do {                                                                                  \
+        dst_type * restrict vdst = (dst_type *) dst;                                      \
+        src0_type * restrict vsrc0 = (src0_type *) src0;                                  \
+        src1_type * restrict vsrc1 = (src1_type *) src1;                                  \
+                                                                                          \
+        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);                        \
+        const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                 \
+                                                                                          \
+        const uint32_t nvec = n / VLEN_FP16;                                              \
+        const uint32_t nloe = n % VLEN_FP16;                                              \
+                                                                                          \
+        uint32_t i = 0;                                                                   \
+                                                                                          \
+        _Pragma("unroll(4)")                                                              \
+        for (; i < nvec; i++) {                                                           \
+            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+            vdst[i] = res;                                                                \
+        }                                                                                 \
+        if (nloe) {                                                                       \
+            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                        \
+        }                                                                                 \
+    } while(0)
+
+#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store)             \
+    do {                                                                             \
+        dst_type * restrict vdst = (dst_type *) dst;                                 \
+        src0_type * restrict vsrc0 = (src0_type *) src0;                             \
+        src1_type * restrict vsrc1 = (src1_type *) src1;                             \
+                                                                                     \
+        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);                   \
+                                                                                     \
+        const uint32_t nvec = n / VLEN_FP32;                                         \
+        const uint32_t nloe = n % VLEN_FP32;                                         \
+                                                                                     \
+        uint32_t i = 0;                                                              \
+                                                                                     \
+        _Pragma("unroll(4)")                                                         \
+        for (; i < nvec; i++) {                                                      \
+            HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
+            HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1);                     \
+            vdst[i] = res;                                                           \
+        }                                                                            \
+        if (nloe) {                                                                  \
+            HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
+            HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1);                     \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res);                   \
+        }                                                                            \
+    } while(0)
+
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
+static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src0 % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src0 % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src0 % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src0 % 128 == 0); \
+    OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src1 % 128 == 0); \
+    OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \
+} \
+
+// Dispatcher logic
+#define HVX_DIV_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
+    if (hex_is_aligned((void *) dst, 128)) { \
+        if (hex_is_aligned((void *) src0, 128)) { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_aau(dst, src0, src1, num_elems); \
+        } else { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_auu(dst, src0, src1, num_elems); \
+        } \
+    } else { \
+        if (hex_is_aligned((void *) src0, 128)) { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_uau(dst, src0, src1, num_elems); \
+        } else { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_uuu(dst, src0, src1, num_elems); \
+        } \
+    } \
+}
+
+DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body)
+DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body)
+
+HVX_DIV_DISPATCHER(hvx_div_f32)
+HVX_DIV_DISPATCHER(hvx_div_f16)
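+
+// Generated entry points (usage sketch, pointers and count are hypothetical):
+//
+//   hvx_div_f32(dst, src0, src1, n);   // fp32: dst[i] = src0[i] / src1[i]
+//   hvx_div_f16(dst, src0, src1, n);   // fp16: widened to fp32 internally
+//
+// Division is realized as src0 * (1 / src1), with hvx_vec_inverse_f32_guard
+// zeroing lanes whose reciprocal comes out as NaN/Inf.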
+
+#undef HVX_OP_MUL_F32
+
+#endif // HVX_DIV_H
diff --git a/ggml/src/ggml-hexagon/htp/hvx-dump.h b/ggml/src/ggml-hexagon/htp/hvx-dump.h
new file mode 100644
index 00000000..85201fc3
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-dump.h
@@ -0,0 +1,129 @@
+#ifndef HVX_DUMP_H
+#define HVX_DUMP_H
+
+#include 
+
+#include 
+#include 
+
+#include "hex-utils.h"
+#include "hvx-types.h"
+
+static void hvx_vec_dump_f16_n(char * pref, HVX_Vector v, uint32_t n) {
+    HVX_VectorAlias u = { .v = v };
+
+    const uint32_t n0 = n / 16;
+    const uint32_t n1 = n % 16;
+    int            i  = 0;
+    for (; i < n0; i++) {
+        hex_dump_f16_line(pref, u.fp16 + (16 * i), 16);
+    }
+    if (n1) {
+        hex_dump_f16_line(pref, u.fp16 + (16 * i), n1);
+    }
+}
+
+static void hvx_vec_dump_f16(char * pref, HVX_Vector v) {
+    hvx_vec_dump_f16_n(pref, v, 64);
+}
+
+static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
+    HVX_VectorAlias u = { .v = v };
+
+    const uint32_t n0 = n / 16;
+    const uint32_t n1 = n % 16;
+    int            i  = 0;
+    for (; i < n0; i++) {
+        hex_dump_f32_line(pref, u.fp32 + (16 * i), 16);
+    }
+    if (n1) {
+        hex_dump_f32_line(pref, u.fp32 + (16 * i), n1);
+    }
+}
+
+static void hvx_vec_dump_f32_hmt(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        float      d[32];
+    } u = { .v = v };
+
+    FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ...  %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1],
+         u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
+}
+
+static void hvx_vec_dump_f32(char * pref, HVX_Vector v) {
+    hvx_vec_dump_f32_n(pref, v, 32);
+}
+
+static void hvx_vec_dump_int32(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        int32_t    d[32];
+    } u = { .v = v };
+
+    for (int i = 0; i < 32 / 16; i++) {
+        hex_dump_int32_line(pref, u.d + (16 * i), 16);
+    }
+}
+
+static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        int32_t    d[32];
+    } u = { .v = v };
+
+    FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],
+         u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
+}
+
+static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        int8_t     d[128];
+    } u = { .v = v };
+
+    FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],
+         u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);
+}
+
+static void hvx_vec_dump_int8(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        int8_t     d[128];
+    } u = { .v = v };
+
+    for (int i = 0; i < 128 / 16; i++) {
+        hex_dump_int8_line(pref, u.d + (16 * i), 16);
+    }
+}
+
+static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {
+    union {
+        HVX_Vector v;
+        uint8_t    d[128];
+    } u = { .v = v };
+
+    for (int i = 0; i < 128 / 16; i++) {
+        hex_dump_uint8_line(pref, u.d + (16 * i), 16);
+    }
+}
+
+static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {
+    typedef union {
+        HVX_Vector v;
+        int8_t     d[128];
+    } U;
+
+    U u0 = { .v = v0 };
+    U u1 = { .v = v1 };
+
+    for (int i = 0; i < n; i++) {
+        if (u0.d[i] != u1.d[i]) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+#endif /* HVX_DUMP_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.c b/ggml/src/ggml-hexagon/htp/hvx-exp.c
deleted file mode 100644
index 21bf46a5..00000000
--- a/ggml/src/ggml-hexagon/htp/hvx-exp.c
+++ /dev/null
@@ -1,94 +0,0 @@
-#pragma clang diagnostic ignored "-Wunused-variable"
-#pragma clang diagnostic ignored "-Wunused-function"
-#pragma clang diagnostic ignored "-Wunused-but-set-variable"
-
-#include 
-#include 
-#include 
-#include 
-
-#define GGML_COMMON_DECL_C
-#include "ggml-common.h"
-#include "htp-ctx.h"
-#include "htp-dma.h"
-#include "htp-msg.h"
-#include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
-
-static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
-    const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
-
-    HVX_Vector out = hvx_vec_exp_fp32(in_vec);
-
-    return Q6_V_vmux_QVV(pred0, inf, out);
-}
-
-void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
-    int left_over       = num_elems & (VLEN_FP32 - 1);
-    int num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-    // assert((0 == unaligned_addr) || (0 == num_elems_whole));
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    HVX_Vector vec_out = Q6_V_vzero();
-
-    static const float kInf    = INFINITY;
-    static const float kMaxExp = 88.02f;  // log(INF)
-
-    const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
-    const HVX_Vector inf     = hvx_vec_splat_fp32(kInf);
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
-        HVX_Vector * p_vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            if (true == negate) {
-                HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
-                *p_vec_out++          = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
-            } else {
-                *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf);
-            }
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-
-            if (true == negate) {
-                HVX_Vector neg_vec_in                    = hvx_vec_neg_fp32(in);
-                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
-            } else {
-                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf);
-            }
-        }
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (float *) src + num_elems_whole;
-        float *       dstf = (float *) dst + num_elems_whole;
-
-        HVX_Vector in = *(HVX_UVector *) srcf;
-
-        if (true == negate) {
-            HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
-
-            vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf);
-        } else {
-            vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf);
-        }
-
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
-    }
-}
diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h
new file mode 100644
index 00000000..44dfe232
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h
@@ -0,0 +1,215 @@
+#ifndef HVX_EXP_H
+#define HVX_EXP_H
+
+#include 
+#include 
+
+#include "hvx-base.h"
+#include "hvx-floor.h"
+
+#define EXP_COEFF_5 (0x39506967)  // 0.000198757 = 1/(7!)
+#define EXP_COEFF_4 (0x3AB743CE)  // 0.0013982   = 1/(6!)
+#define EXP_COEFF_3 (0x3C088908)  // 0.00833345  = 1/(5!)
+#define EXP_COEFF_2 (0x3D2AA9C1)  // 0.0416658   = 1/(4!)
+#define EXP_COEFF_1 (0x3E2AAAAA)  // 0.16666667  = 1/(3!)
+#define EXP_COEFF_0 (0x3F000000)  // 0.5         = 1/(2!)
+#define EXP_LOGN2   (0x3F317218)  // ln(2)   = 0.6931471805
+#define EXP_LOG2E   (0x3FB8AA3B)  // log2(e) = 1/ln(2) = 1.4426950408
+#define EXP_ONE     (0x3f800000)  // 1.0
+#define EXP_RANGE_R (0x41a00000)  // 20.0
+#define EXP_RANGE_L (0xc1a00000)  // -20.0
+
+static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
+    HVX_Vector z_qf32_v;
+    HVX_Vector x_v;
+    HVX_Vector x_qf32_v;
+    HVX_Vector y_v;
+    HVX_Vector k_v;
+    HVX_Vector f_v;
+    HVX_Vector epsilon_v;
+    HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
+    HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
+    HVX_Vector E_const;
+    HVX_Vector zero_v = Q6_V_vzero();
+
+    // exp(x) is approximated as follows:
+    //   f = floor(x/ln(2)) = floor(x*log2(e))
+    //   epsilon = x - f*ln(2)
+    //   exp(x) = exp(epsilon+f*ln(2))
+    //          = exp(epsilon)*exp(f*ln(2))
+    //          = exp(epsilon)*2^f
+    //
+    //   Since epsilon is close to zero, it can be approximated with its Taylor series:
+    //            exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
+    //   Keeping the first eight terms, we get:
+    //            exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
+    //                   =  1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
+
+    HVX_Vector temp_v = in_vec;
+
+    // Clamp inputs to (-20.0, 20.0)
+    HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
+    HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
+
+    in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
+    in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
+
+    epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
+    epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
+
+    //    f_v is the floating point result and k_v is the integer result
+    f_v = hvx_vec_floor_f32(epsilon_v);
+    k_v = hvx_vec_truncate_f32(f_v);
+
+    x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
+
+    //  x = x - f_v * logn2;
+    epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
+    x_qf32_v  = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
+    // normalize before every QFloat's vmpy
+    x_qf32_v  = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
+
+    // z = x * x;
+    z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
+    z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
+
+    x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
+
+    // y = E4 + E5 * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_5);
+    y_v     = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
+    E_const = Q6_V_vsplat_R(EXP_COEFF_4);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = E3 + y * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_3);
+    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = E2 + y * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_2);
+    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = E1 + y * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_1);
+    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = E0 + y * x;
+    E_const = Q6_V_vsplat_R(EXP_COEFF_0);
+    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
+    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = x + y * z;
+    y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
+    y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
+    y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
+
+    // y = y + 1.0;
+    y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
+
+    // insert exponents
+    //        y = ldexpf(y, k);
+    //    y_v += k_v; // qf32
+    // modify exponent
+
+    y_v = Q6_Vsf_equals_Vqf32(y_v);
+
+    // add k_v to the exponent of y_v
+    HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
+
+    y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
+    y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
+
+    // the adjusted exponent must not go negative; on underflow the result is flushed to zero
+    HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
+
+    y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
+
+    y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
+
+    return y_v;
+}
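+
+// Scalar reference of the range reduction above (a sketch, not used by the HVX path):
+//
+//   float exp_ref(float x) {
+//       float f   = floorf(x * 1.4426950408f);   // f = floor(x * log2(e))
+//       float eps = x - f * 0.6931471805f;       // eps = x - f * ln(2)
+//       float p   = 1.0f + eps + eps * eps * (0.5f + eps * (1.0f/6 + eps * (1.0f/24
+//                   + eps * (1.0f/120 + eps * (1.0f/720 + eps * (1.0f/5040))))));
+//       return ldexpf(p, (int) f);               // exp(x) = exp(eps) * 2^f
+//   }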
+
+static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) {
+    const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
+
+    HVX_Vector out = hvx_vec_exp_f32(in_vec);
+
+    return Q6_V_vmux_QVV(pred0, inf, out);
+}
+
+static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
+    int left_over       = num_elems & (VLEN_FP32 - 1);
+    int num_elems_whole = num_elems - left_over;
+
+    int unaligned_addr = 0;
+    int unaligned_loop = 0;
+    if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) {
+        unaligned_addr = 1;
+    }
+    // assert((0 == unaligned_addr) || (0 == num_elems_whole));
+    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
+        unaligned_loop = 1;
+    }
+
+    HVX_Vector vec_out = Q6_V_vzero();
+
+    static const float kInf    = INFINITY;
+    static const float kMaxExp = 88.02f;  // log(INF)
+
+    const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
+    const HVX_Vector inf     = hvx_vec_splat_f32(kInf);
+
+    if (0 == unaligned_loop) {
+        HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
+        HVX_Vector * p_vec_out = (HVX_Vector *) dst;
+
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            if (true == negate) {
+                HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++);
+                *p_vec_out++          = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
+            } else {
+                *p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf);
+            }
+        }
+    } else {
+        #pragma unroll(4)
+        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+
+            if (true == negate) {
+                HVX_Vector neg_vec_in                    = hvx_vec_neg_f32(in);
+                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
+            } else {
+                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf);
+            }
+        }
+    }
+
+    if (left_over > 0) {
+        const float * srcf = (float *) src + num_elems_whole;
+        float *       dstf = (float *) dst + num_elems_whole;
+
+        HVX_Vector in = *(HVX_UVector *) srcf;
+
+        if (true == negate) {
+            HVX_Vector neg_vec_in = hvx_vec_neg_f32(in);
+
+            vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf);
+        } else {
+            vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf);
+        }
+
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
+    }
+}
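+
+// Usage note: hvx_exp_f32(src, dst, n, negate) computes dst[i] = expf(src[i]) (or
+// expf(-src[i]) when negate is set) for n fp32 elements, with inputs above
+// log(FLT_MAX) ~= 88.02 producing +INF via hvx_vec_exp_f32_guard. Unaligned
+// pointers are handled, at the cost of a slower HVX_UVector loop.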
+
+#endif /* HVX_EXP_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-floor.h b/ggml/src/ggml-hexagon/htp/hvx-floor.h
new file mode 100644
index 00000000..6a1bfde5
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-floor.h
@@ -0,0 +1,100 @@
+#ifndef HVX_FLOOR_H
+#define HVX_FLOOR_H
+
+#include 
+#include 
+
+#include "hvx-base.h"
+
+#define IEEE_VSF_EXPLEN   (8)
+#define IEEE_VSF_EXPBIAS  (127)
+#define IEEE_VSF_EXPMASK  (0xFF)
+#define IEEE_VSF_MANTLEN  (23)
+#define IEEE_VSF_MANTMASK (0x7FFFFF)
+#define IEEE_VSF_MIMPMASK (0x800000)
+
+static inline HVX_Vector hvx_vec_truncate_f32(HVX_Vector in_vec) {
+    HVX_Vector mask_mant_v  = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
+    HVX_Vector mask_impl_v  = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
+    HVX_Vector const_zero_v = Q6_V_vzero();
+
+    HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
+
+    HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
+    expval_v &= IEEE_VSF_EXPMASK;
+    expval_v -= IEEE_VSF_EXPBIAS;
+
+    // negative exp == fractional value
+    HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
+
+    HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v;         // fractional bits - exp shift
+
+    HVX_Vector mant_v = in_vec & mask_mant_v;                  // obtain mantissa
+    HVX_Vector vout   = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v);  // add implicit 1.0
+
+    vout = Q6_Vw_vasr_VwVw(vout, rshift_v);                    // shift to obtain truncated integer
+    vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout);        // expval<0 -> 0
+
+    HVX_Vector neg_vout = -vout;
+
+    vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout);  // handle negatives
+
+    return (vout);
+}
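+
+// Scalar equivalent of the truncation above (illustrative): the result is the
+// integer part of x rounded toward zero, i.e. (int32_t) x, with |x| < 1 mapping
+// to 0 and the sign reapplied at the end.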
+
+static inline HVX_Vector hvx_vec_floor_f32(HVX_Vector in_vec) {
+    HVX_Vector mask_mant_v    = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
+    HVX_Vector mask_impl_v    = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
+    HVX_Vector const_mnlen_v  = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);
+    HVX_Vector const_zero_v   = Q6_V_vzero();
+    HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000);  // -1 IEEE vsf
+
+    HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
+
+    HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
+    expval_v &= IEEE_VSF_EXPMASK;
+    expval_v -= IEEE_VSF_EXPBIAS;
+
+    HVX_VectorPred q_negexp     = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
+    HVX_VectorPred q_expltmn    = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);
+    HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);
+    HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);
+
+    // if expval < 0 (q_negexp)          // |x| < 1, no integer bits
+    //    if vin > 0
+    //       floor = 0
+    //    if vin < 0
+    //       floor = -1
+    // if expval < mant_len (q_expltmn)  // integer part plus possible fraction bits
+    //    get sign (q_negative)
+    //    mask >>= expval                // fraction bits to mask off
+    //    if (qneg)                      // negative floor is one step lower
+    //      vout += (impl_mask >> expval)
+    //    vout &= ~mask                  // clear fraction bits
+    //    if ((vin & mask) == 0)         // no fraction bits set
+    //      vout = vin                   // already an integer, keep the input
+    // else                              // expval >= mant_len, already an integer
+    //    ;                              // no change
+
+    // compute floor
+    mask_mant_v >>= expval_v;
+    HVX_Vector neg_addin_v    = mask_impl_v >> expval_v;
+    HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);
+    HVX_Vector vout           = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);
+
+    HVX_Vector     mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v);  // chk if bits set
+    HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);
+
+    HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v);        // frac bits to clear
+    HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v);  // clear frac bits
+
+    vout = in_vec;
+    vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout);         // expval<mantlen -> fraction-cleared floor
+    vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout);    // expval<0 x>0 -> 0
+    vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout);  // expval<0 x<0 -> -1
+
+    return vout;
+}
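+
+// In scalar terms the logic above reduces to (documentation only):
+//   |x| < 1        : floor(x) = 0 for x >= 0, -1 for x < 0
+//   x integral     : floor(x) = x
+//   x > 0 non-int  : floor(x) = trunc(x)
+//   x < 0 non-int  : floor(x) = trunc(x) - 1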
+
+#endif /* HVX_FLOOR_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.c b/ggml/src/ggml-hexagon/htp/hvx-inverse.c
deleted file mode 100644
index 4d70634f..00000000
--- a/ggml/src/ggml-hexagon/htp/hvx-inverse.c
+++ /dev/null
@@ -1,72 +0,0 @@
-#pragma clang diagnostic ignored "-Wunused-variable"
-#pragma clang diagnostic ignored "-Wunused-function"
-#pragma clang diagnostic ignored "-Wunused-but-set-variable"
-
-#include 
-#include 
-#include 
-#include 
-
-#define GGML_COMMON_DECL_C
-#include "ggml-common.h"
-#include "htp-ctx.h"
-#include "htp-dma.h"
-#include "htp-msg.h"
-#include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
-
-static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
-    HVX_Vector out = hvx_vec_inverse_fp32(v_sf);
-
-    HVX_Vector           masked_out = Q6_V_vand_VV(out, nan_inf_mask);
-    const HVX_VectorPred pred       = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);
-
-    return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
-}
-
-void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
-    int left_over       = num_elems & (VLEN_FP32 - 1);
-    int num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-    // assert((0 == unaligned_addr) || (0 == num_elems_whole));
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    static const uint32_t kNanInfMask  = 0x7f800000;
-    const HVX_Vector      nan_inf_mask = Q6_V_vsplat_R(kNanInfMask);
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * p_vec_in  = (HVX_Vector *) src;
-        HVX_Vector * p_vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask);
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in                            = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
-        }
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (float *) src + num_elems_whole;
-        float *       dstf = (float *) dst + num_elems_whole;
-
-        HVX_Vector in  = *(HVX_UVector *) srcf;
-        HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask);
-
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
-    }
-}
diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.h b/ggml/src/ggml-hexagon/htp/hvx-inverse.h
new file mode 100644
index 00000000..f2054f45
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.h
@@ -0,0 +1,210 @@
+#ifndef HVX_INVERSE_H
+#define HVX_INVERSE_H
+
+#include <assert.h>
+#include <stdint.h>
+
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+
+#include "hvx-base.h"
+
+// ====================================================
+// FUNCTION: 1/(x+1)     y(0) = 1,  y(0.5) = 0.6667, y(1) = 0.5
+// Order:3; continuity: True; Ends forced: True
+// Mode: unsigned;   Result fractional bits: 14
+// Peak Error: 1.1295e-04  Rms Error: 2.8410e-05   Mean Error: 1.1370e-05
+//      32769  -32706   31252  -10589
+//      32590  -30635   22793   -4493
+//      32066  -27505   16481   -2348
+//      31205  -24054   11849   -1306
+
+static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {
+    // input is 0..0xffff representing 0.0  .. 1.0
+    HVX_Vector p;
+    p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);
+    p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);
+    p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);
+    p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);
+    return p;  // signed result, 14 fractional bits
+}
+
+// Find reciprocal of fp16.
+// (1) first, convert to fp32, multiplying by 1.0; this is done to
+//    handle denormals. Ignoring sign and zero, result should be at
+//    least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)
+//    (exponent in range [103,143])
+// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly
+// (3) put this, along with '253-exp' (exp from (1)) together to make a qf32
+// (4) convert that to fp16
+// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace
+//     the result with the max value.
+static inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) {
+    HVX_Vector     em_mask  = Q6_Vh_vsplat_R(0x7FFF);
+    HVX_Vector     avals    = Q6_V_vand_VV(vals, em_mask);
+    HVX_VectorPred is_neg   = Q6_Q_vcmp_gt_VhVh(avals, vals);
+    // is too small to 1/x ? for 'standard' fp16, this would be 0x101
+    HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);
+
+    HVX_VectorPair to_qf32  = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00));  // *1.0
+    HVX_Vector     to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));
+    HVX_Vector     to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));
+
+    // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector
+    HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));
+    // likewise extract the upper 16 from each, containing the exponents in range 103..142
+    HVX_Vector exp_u16  = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);
+    //Get exponent in IEEE 32-bit representation
+    exp_u16             = Q6_Vuh_vlsr_VuhR(exp_u16, 7);
+
+    // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane
+    // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)
+    // Use poly to transform to 1/x, with 14 fractional bits
+    //
+    HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);
+
+    HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm);  //count leading zeros
+
+    // Get mantissa for 16-bit representation
+    HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));
+
+    //Compute Reciprocal Exponent
+    HVX_Vector exp_recip =
+        Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));
+    //Convert it for 16-bit representation
+    exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));
+    exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);
+
+    //Merge exponent and mantissa for reciprocal
+    HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);
+    // map 'small' inputs to standard largest value 0x7bff
+    recip            = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);
+    // add sign back
+    recip            = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);
+    return recip;
+}
+
+static inline HVX_Vector hvx_vec_inverse_f32(HVX_Vector v_sf) {
+    HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);
+    HVX_Vector two_sf       = hvx_vec_splat_f32(2.0);
+
+    // First approximation
+    HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);
+
+    HVX_Vector r_qf;
+
+    // Refine
+    r_qf = Q6_Vqf32_vmpy_VsfVsf(
+        i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));
+    r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
+        r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
+    r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
+        r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
+
+    return Q6_Vsf_equals_Vqf32(r_qf);
+}
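+
+// Scalar sketch of the same approach (bit-trick seed plus Newton-Raphson
+// refinement for 1/x); documentation only, not part of the HVX build:
+//
+//   static float ref_inverse_f32(float x) {
+//       uint32_t bits;
+//       memcpy(&bits, &x, sizeof(bits));
+//       bits = 0x7EEEEBB3u - bits;          // crude initial guess for 1/x
+//       float y;
+//       memcpy(&y, &bits, sizeof(y));
+//       y = y * (2.0f - x * y);             // three Newton steps: y <- y * (2 - x*y)
+//       y = y * (2.0f - x * y);
+//       y = y * (2.0f - x * y);
+//       return y;
+//   }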
+
+static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
+    HVX_Vector out = hvx_vec_inverse_f32(v_sf);
+
+    HVX_Vector     masked_out = Q6_V_vand_VV(out, nan_inf_mask);
+    const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out);
+
+    return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
+}
+
+#define hvx_inverse_f32_loop_body(dst_type, src_type, vec_store)             \
+    do {                                                                     \
+        dst_type * restrict vdst = (dst_type *) dst;                         \
+        src_type * restrict vsrc = (src_type *) src;                         \
+                                                                             \
+        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);           \
+                                                                             \
+        const uint32_t nvec = n / VLEN_FP32;                                 \
+        const uint32_t nloe = n % VLEN_FP32;                                 \
+                                                                             \
+        uint32_t i = 0;                                                      \
+                                                                             \
+        _Pragma("unroll(4)")                                                 \
+        for (; i < nvec; i++) {                                              \
+             vdst[i] = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask);     \
+        }                                                                    \
+        if (nloe) {                                                          \
+            HVX_Vector v = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, v);             \
+        }                                                                    \
+    } while(0)
+
+static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
+    HVX_Vector out = hvx_vec_inverse_f16(v_sf);
+
+    HVX_Vector     masked_out = Q6_V_vand_VV(out, nan_inf_mask);
+    const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out);
+
+    return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
+}
+
+#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store)             \
+    do {                                                                     \
+        dst_type * restrict vdst = (dst_type *) dst;                         \
+        src_type * restrict vsrc = (src_type *) src;                         \
+                                                                             \
+        const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00);              \
+                                                                             \
+        const uint32_t nvec = n / VLEN_FP16;                                 \
+        const uint32_t nloe = n % VLEN_FP16;                                 \
+                                                                             \
+        uint32_t i = 0;                                                      \
+                                                                             \
+        _Pragma("unroll(4)")                                                 \
+        for (; i < nvec; i++) {                                              \
+             vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask);     \
+        }                                                                    \
+        if (nloe) {                                                          \
+            HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v);             \
+        }                                                                    \
+    } while(0)
+
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
+static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+    assert((uintptr_t) src % 128 == 0); \
+    OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+    OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \
+} \
+
+// Dispatcher logic
+#define HVX_INV_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \
+    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
+        OP_NAME##_aa(dst, src, num_elems); \
+    } else if (hex_is_aligned((void *) dst, 128)) { \
+        OP_NAME##_au(dst, src, num_elems); \
+    } else if (hex_is_aligned((void *) src, 128)) { \
+        OP_NAME##_ua(dst, src, num_elems); \
+    } else { \
+        OP_NAME##_uu(dst, src, num_elems); \
+    } \
+}
+
+DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body)
+DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body)
+
+HVX_INV_DISPATCHER(hvx_inverse_f32)
+HVX_INV_DISPATCHER(hvx_inverse_f16)
+
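+// Example (hypothetical caller): the dispatchers above pick the aligned or
+// unaligned variant at runtime from the 128-byte alignment of dst and src, e.g.
+//
+//   // float *out, *in; both hold n fp32 elements
+//   hvx_inverse_f32((uint8_t *) out, (const uint8_t *) in, n);
+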
+#endif // HVX_INVERSE_H
diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h
new file mode 100644
index 00000000..3c0073ef
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h
@@ -0,0 +1,296 @@
+#ifndef HVX_REDUCE_H
+#define HVX_REDUCE_H
+
+#include <assert.h>
+#include <stdint.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+
+#include "hex-utils.h"
+#include "hvx-base.h"
+#include "hvx-types.h"
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) {
+    unsigned int total = n * 4;  // total vec nbytes
+    unsigned int width = 4;      // int32
+
+    HVX_Vector sum = in, sum_t;
+    while (width < total) {
+        sum_t = Q6_V_vror_VR(sum, width);     // rotate right
+        sum   = Q6_Vw_vadd_VwVw(sum_t, sum);  // elementwise sum
+        width = width << 1;
+    }
+    return sum;
+}
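+
+// The loop above is a log2-step rotate-and-add tree: each iteration rotates the
+// accumulator by 'width' bytes and adds it to itself, doubling the number of
+// lanes folded into every lane. After log2(n) iterations each lane of the
+// covered span holds the sum of all n int32 lanes.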
+
+static inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) {
+    return hvx_vec_reduce_sum_n_i32(in, 32);
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) {
+    unsigned int total = n * 4;  // total vec nbytes
+    unsigned int width = 4;      // fp32 nbytes
+
+    HVX_Vector sum = in, sum_t;
+    while (width < total) {
+        sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width);  // rotate right
+        sum   = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t);             // elementwise sum
+        width = width << 1;
+    }
+    return sum;
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
+    return hvx_vec_reduce_sum_n_qf32(in, 32);
+}
+
+#if __HVX_ARCH__ > 75
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) {
+    HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4);
+    HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4);
+    HVX_Vector  sum_sf01  = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01));
+    HVX_Vector  sum_sf23  = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23));
+
+    HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(sum_sf23, sum_sf01, 8);
+    HVX_Vector  sum_sf       = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123));
+
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
+    return sum_sf;
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
+    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
+    HVX_Vector  sum_sf  = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
+
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
+    return sum_sf;
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
+    unsigned int total = n * 4;  // total vec nbytes
+    unsigned int width = 4;      // fp32 nbytes
+
+    HVX_Vector sum = in, sum_t;
+    while (width < total) {
+        sum_t = Q6_V_vror_VR(sum, width);       // rotate right
+        sum   = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
+        width = width << 1;
+    }
+    return sum;
+}
+
+#else
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) {
+    HVX_VectorPair sum_p01  = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4);
+    HVX_VectorPair sum_p23  = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4);
+    HVX_Vector     sum_qf01 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01));
+    HVX_Vector     sum_qf23 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23));
+
+    HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(sum_qf23), Q6_Vsf_equals_Vqf32(sum_qf01), 8);
+    HVX_Vector     sum_qf    = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123));
+
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
+    return Q6_Vsf_equals_Vqf32(sum_qf);
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
+    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
+    HVX_Vector  sum_qf  = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
+
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
+    return Q6_Vsf_equals_Vqf32(sum_qf);
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
+    unsigned int total = n * 4;  // total vec nbytes
+    unsigned int width = 4;      // fp32 nbytes
+
+    HVX_Vector sum = in, sum_t;
+    while (width < total) {
+        sum_t = Q6_V_vror_VR(sum, width);                               // rotate right
+        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t));  // elementwise sum
+        width = width << 1;
+    }
+    return sum;
+}
+
+#endif
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
+    return hvx_vec_reduce_sum_n_f32(in, 32);
+}
+
+static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) {
+    unsigned total = 128;  // total vec nbytes
+    unsigned width = 2;    // fp16 nbytes
+
+    HVX_Vector _max = in, _max_t;
+    while (width < total) {
+        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
+        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max
+        width  = width << 1;
+    }
+
+    return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) {
+    unsigned total = 128;  // total vec nbytes
+    unsigned width = 2;    // fp16 nbytes
+
+    HVX_Vector _max_t;
+
+    _max = Q6_Vhf_vmax_VhfVhf(in, _max);
+    while (width < total) {
+        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
+        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max
+        width  = width << 1;
+    }
+
+    return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) {
+    unsigned total = 128;  // total vec nbytes
+    unsigned width = 4;    // fp32 nbytes
+
+    HVX_Vector _max = in, _max_t;
+    while (width < total) {
+        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
+        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max
+        width  = width << 1;
+    }
+
+    return _max;
+}
+
+static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) {
+    unsigned total = 128;  // total vec nbytes
+    unsigned width = 4;    // fp32 nbytes
+
+    HVX_Vector _max_t;
+
+    _max = Q6_Vsf_vmax_VsfVsf(in, _max);
+    while (width < total) {
+        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
+        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max
+        width  = width << 1;
+    }
+
+    return _max;
+}
+
+#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \
+    do {                                                                                    \
+        src_type * restrict vsrc = (src_type *) src;                                        \
+        HVX_Vector acc = init_vec;                                                          \
+                                                                                            \
+        const uint32_t elem_size = sizeof(float);                                           \
+        const uint32_t epv  = 128 / elem_size;                                              \
+        const uint32_t nvec = num_elems / epv;                                              \
+        const uint32_t nloe = num_elems % epv;                                              \
+                                                                                            \
+        uint32_t i = 0;                                                                     \
+        _Pragma("unroll(4)")                                                                \
+        for (; i < nvec; i++) {                                                             \
+            acc = vec_op(acc, vsrc[i]);                                                     \
+        }                                                                                   \
+        if (nloe) {                                                                         \
+            const float * srcf = (const float *) src + i * epv;                             \
+            HVX_Vector in = *(HVX_UVector *) srcf;                                          \
+            HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size);               \
+            acc = vec_op(acc, temp);                                                        \
+        }                                                                                   \
+        HVX_Vector v = reduce_op(acc);                                                      \
+        return scalar_reduce(v);                                                            \
+    } while(0)
+
+#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val)
+#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val)
+#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val))
+#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v)
+#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v))
+
+// Max variants
+
+static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) {
+    HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
+    assert((unsigned long) src % 128 == 0);
+    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
+}
+
+static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) {
+    HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]);
+    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR);
+}
+
+static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) {
+    if (hex_is_aligned((void *) src, 128)) {
+        return hvx_reduce_max_f32_a(src, num_elems);
+    } else {
+        return hvx_reduce_max_f32_u(src, num_elems);
+    }
+}
+
+// Sum variants
+
+static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) {
+    HVX_Vector init_vec = Q6_V_vsplat_R(0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
+}
+
+static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) {
+    HVX_Vector init_vec = Q6_V_vsplat_R(0);
+    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
+}
+
+static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) {
+    if (hex_is_aligned((void *) src, 128)) {
+        return hvx_reduce_sum_f32_a(src, num_elems);
+    } else {
+        return hvx_reduce_sum_f32_u(src, num_elems);
+    }
+}
+
+// Sum of squares variants
+
+static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) {
+    HVX_Vector init_vec = Q6_V_vsplat_R(0);
+    assert((uintptr_t) src % 128 == 0);
+    hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
+}
+
+static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) {
+    HVX_Vector init_vec = Q6_V_vsplat_R(0);
+    hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR);
+}
+
+static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
+    if (hex_is_aligned((void *) src, 128)) {
+        return hvx_sum_of_squares_f32_a(src, num_elems);
+    } else {
+        return hvx_sum_of_squares_f32_u(src, num_elems);
+    }
+}
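+
+// Example (hypothetical caller): mean and RMS of n fp32 values using the
+// dispatchers above, which fall back to the unaligned path automatically:
+//
+//   float mean = hvx_reduce_sum_f32((const uint8_t *) x, n) / (float) n;
+//   float rms  = sqrtf(hvx_sum_of_squares_f32((const uint8_t *) x, n) / (float) n);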
+
+#undef hvx_reduce_loop_body
+#undef HVX_REDUCE_MAX_OP
+#undef HVX_REDUCE_SUM_OP
+#undef HVX_REDUCE_MAX_SCALAR
+#undef HVX_REDUCE_SUM_SCALAR
+#undef HVX_SUM_SQ_OP
+
+#endif /* HVX_REDUCE_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-scale.h b/ggml/src/ggml-hexagon/htp/hvx-scale.h
new file mode 100644
index 00000000..c65c9863
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-scale.h
@@ -0,0 +1,133 @@
+#ifndef HVX_SCALE_H
+#define HVX_SCALE_H
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+
+#define hvx_scale_f32_loop_body(dst_type, src_type, vec_store)                       \
+    do {                                                                             \
+        dst_type * restrict vdst = (dst_type *) dst;                                 \
+        src_type * restrict vsrc = (src_type *) src;                                 \
+                                                                                     \
+        HVX_Vector vs = hvx_vec_splat_f32(scale);                                    \
+                                                                                     \
+        const uint32_t elem_size = sizeof(float);                                    \
+        const uint32_t epv = 128 / elem_size;                                        \
+        const uint32_t nvec = n / epv;                                               \
+        const uint32_t nloe = n % epv;                                               \
+                                                                                     \
+        uint32_t i = 0;                                                              \
+                                                                                     \
+        _Pragma("unroll(4)")                                                         \
+        for (; i < nvec; ++i) {                                                      \
+            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);                        \
+            vdst[i]      = Q6_Vsf_equals_Vqf32(v);                                   \
+        }                                                                            \
+        if (nloe) {                                                                  \
+            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);                        \
+            vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v));  \
+        }                                                                            \
+    } while(0)
+
+static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+    assert((size_t) dst % 128 == 0);
+    assert((size_t) src % 128 == 0);
+    hvx_scale_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_scale_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+    assert((size_t) dst % 128 == 0);
+    hvx_scale_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_scale_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+    assert((size_t) src % 128 == 0);
+    hvx_scale_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+    hvx_scale_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+    if (((size_t) dst & 127) == 0) {
+        if (((size_t) src & 127) == 0) {
+            hvx_scale_f32_aa(dst, src, n, scale);
+        } else {
+            hvx_scale_f32_au(dst, src, n, scale);
+        }
+    } else {
+        if (((size_t) src & 127) == 0) {
+            hvx_scale_f32_ua(dst, src, n, scale);
+        } else {
+            hvx_scale_f32_uu(dst, src, n, scale);
+        }
+    }
+}
+
+#define hvx_scale_offset_f32_loop_body(dst_type, src_type, vec_store)                \
+    do {                                                                             \
+        dst_type * restrict vdst = (dst_type *) dst;                                 \
+        src_type * restrict vsrc = (src_type *) src;                                 \
+                                                                                     \
+        HVX_Vector vs = hvx_vec_splat_f32(scale);                                    \
+        HVX_Vector vo = hvx_vec_splat_f32(offset);                                   \
+                                                                                     \
+        const uint32_t elem_size = sizeof(float);                                    \
+        const uint32_t epv = 128 / elem_size;                                        \
+        const uint32_t nvec = n / epv;                                               \
+        const uint32_t nloe = n % epv;                                               \
+                                                                                     \
+        uint32_t i = 0;                                                              \
+                                                                                     \
+        _Pragma("unroll(4)")                                                         \
+        for (; i < nvec; ++i) {                                                      \
+            HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \
+            vdst[i] = Q6_Vsf_equals_Vqf32(v);                                        \
+        }                                                                            \
+        if (nloe) {                                                                  \
+            HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \
+            vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v));  \
+        }                                                                            \
+    } while(0)
+
+static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+    assert((size_t) dst % 128 == 0);
+    assert((size_t) src % 128 == 0);
+    hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_scale_offset_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+    assert((size_t) dst % 128 == 0);
+    hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_scale_offset_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+    assert((size_t) src % 128 == 0);
+    hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+    hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+    if (((size_t) dst & 127) == 0) {
+        if (((size_t) src & 127) == 0) {
+            hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
+        } else {
+            hvx_scale_offset_f32_au(dst, src, n, scale, offset);
+        }
+    } else {
+        if (((size_t) src & 127) == 0) {
+            hvx_scale_offset_f32_ua(dst, src, n, scale, offset);
+        } else {
+            hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
+        }
+    }
+}
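+
+// Example (hypothetical caller): both helpers operate on n fp32 elements;
+// hvx_scale_f32 computes dst[i] = src[i] * scale and hvx_scale_offset_f32
+// computes dst[i] = src[i] * scale + offset:
+//
+//   hvx_scale_f32((uint8_t *) dst, (const uint8_t *) src, n, 0.125f);
+//   hvx_scale_offset_f32((uint8_t *) dst, (const uint8_t *) src, n, 2.0f, 1.0f);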
+
+#endif // HVX_SCALE_H
diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c
deleted file mode 100644
index 15ac6469..00000000
--- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c
+++ /dev/null
@@ -1,49 +0,0 @@
-#pragma clang diagnostic ignored "-Wunused-variable"
-#pragma clang diagnostic ignored "-Wunused-function"
-#pragma clang diagnostic ignored "-Wunused-but-set-variable"
-
-#include 
-#include 
-#include 
-#include 
-
-#define GGML_COMMON_DECL_C
-#include "ggml-common.h"
-#include "htp-ctx.h"
-#include "htp-dma.h"
-#include "htp-msg.h"
-#include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
-
-#if 0
-// Reference algo used in hvx-utils
-static void fast_sigmoid_f32(const float*  restrict src, float* restrict dst, const int num_elems)
-{
-    const float c1 = 0.03138777;
-    const float c2 = 0.276281267;
-    const float c_log2f = 1.442695022;
-
-    int32_t store_ints[32];
-    float store_floats[3][32];
-
-    for (int i = 0; i < num_elems; i++)
-    {
-        float v = src0[i];
-
-        v *= c_log2f*0.5;
-        int intPart = (int)v;
-        float x = (v - intPart);
-        float xx = x * x;
-        float v1 = c_log2f + c2 * xx;
-        float v2 = x + xx * c1 * x;
-        float v3 = (v2 + v1);
-        *((int*)&v3) += intPart << 24;
-        float v4 = v2 - v1;
-        float v5 = v3 - v4;
-        float res = v3 / v5;
-
-        dst[i] = res;
-    }
-}
-#endif
diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h
new file mode 100644
index 00000000..09519327
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h
@@ -0,0 +1,141 @@
+#ifndef HVX_SIGMOID_H
+#define HVX_SIGMOID_H
+
+#include "hvx-base.h"
+
+#define FAST_SIGMOID_LOG2F (0x3fb8aa3b)  // 1.442695022
+#define FAST_SIGMOID_C1    (0x3d009076)  // 0.03138777
+#define FAST_SIGMOID_C2    (0x3e8d74bd)  // 0.276281267
+#define FAST_SIGMOID_C3    (0x3f000000)  // 0.5
+
+static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) {
+    v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
+    v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
+
+    HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v));
+    HVX_Vector x      = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
+    HVX_Vector xx     = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
+
+    HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
+    v1            = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
+
+    HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
+    v2            = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
+    v2            = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
+
+    HVX_Vector v3          = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
+    HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
+    v3_exponent            = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
+    v3_exponent            = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
+    v3                     = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
+
+    HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
+    HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
+
+    HVX_Vector res = hvx_vec_inverse_f32(v5);
+    res            = Q6_Vqf32_vmpy_VsfVsf(v3, res);
+
+    return Q6_Vsf_equals_Vqf32(res);
+}
+
+static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v,
+                                                         HVX_Vector one,
+                                                         HVX_Vector max_exp,
+                                                         HVX_Vector min_exp) {
+    const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
+    const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);
+
+    HVX_Vector out = hvx_vec_fast_sigmoid_f32(v);
+    out            = Q6_V_vmux_QVV(pred_max, out, one);
+    return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
+}
+
+static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
+    // tanh(x) = 2 * sigmoid(2x) - 1
+    HVX_Vector two = hvx_vec_splat_f32(2.0f);
+    HVX_Vector one = hvx_vec_splat_f32(1.0f);
+    HVX_Vector x2  = Q6_Vqf32_vmpy_VsfVsf(x, two);
+
+    HVX_Vector max_exp = hvx_vec_splat_f32(87.f);
+    HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);
+
+    HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
+
+    HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
+    res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
+    return Q6_Vsf_equals_Vqf32(res);
+}
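+
+// Identity used above: tanh(x) = 2*sigmoid(2x) - 1. The +/-87 clamp is a
+// conservative bound below the ~88.7 threshold where expf overflows in fp32,
+// so the guarded sigmoid saturates cleanly and tanh goes to +/-1.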
+
+#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store)    \
+    do {                                                        \
+        dst_type * restrict vdst = (dst_type *) dst;            \
+        src_type * restrict vsrc = (src_type *) src;            \
+                                                                \
+        const HVX_Vector one     = hvx_vec_splat_f32(1.f);      \
+        const HVX_Vector max_exp = hvx_vec_splat_f32(87.f);     \
+        const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f);    \
+                                                                \
+        const uint32_t epv  = 128 / sizeof(float);              \
+        const uint32_t nvec = n / epv;                          \
+        const uint32_t nloe = n % epv;                          \
+                                                                \
+        uint32_t i = 0;                                         \
+                                                                \
+        _Pragma("unroll(4)")                                    \
+        for (; i < nvec; i++) {                                 \
+             vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
+        }                                                       \
+        if (nloe) {                                             \
+             HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \
+             vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
+        }                                                       \
+    } while(0)
+
+#define hvx_tanh_loop_body(dst_type, src_type, vec_store)       \
+    do {                                                        \
+        dst_type * restrict vdst = (dst_type *) dst;            \
+        src_type * restrict vsrc = (src_type *) src;            \
+                                                                \
+        const uint32_t epv  = 128 / sizeof(float);              \
+        const uint32_t nvec = n / epv;                          \
+        const uint32_t nloe = n % epv;                          \
+                                                                \
+        uint32_t i = 0;                                         \
+                                                                \
+        _Pragma("unroll(4)")                                    \
+        for (; i < nvec; i++) {                                 \
+             vdst[i] = hvx_vec_tanh_f32(vsrc[i]);               \
+        }                                                       \
+        if (nloe) {                                             \
+             HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]);        \
+             vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
+        }                                                       \
+    } while(0)
+
+static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+#endif /* HVX_SIGMOID_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h
new file mode 100644
index 00000000..e31a1006
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h
@@ -0,0 +1,126 @@
+#ifndef HVX_SQRT_H
+#define HVX_SQRT_H
+
+#include <assert.h>
+#include <stdint.h>
+
+#include "hex-utils.h"
+
+#include "hvx-base.h"
+
+#define RSQRT_CONST        0x5f3759df  // Constant for fast inverse square root calculation
+#define RSQRT_ONE_HALF     0x3f000000  // 0.5
+#define RSQRT_THREE_HALVES 0x3fc00000  // 1.5
+
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
+static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
+    //Algorithm :
+    //  x2 = input*0.5
+    //  y  = * (long *) &input
+    //  y  = 0x5f3759df - (y>>1)
+    //  y  = y*(threehalfs - x2*y*y)
+
+    HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
+    HVX_Vector onehalf    = Q6_V_vsplat_R(RSQRT_ONE_HALF);
+    HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
+
+    HVX_Vector x2, y, ypower2, temp;
+
+    x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
+    x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
+
+    y = Q6_Vw_vasr_VwR(in_vec, 1);
+    y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
+
+    // 1st iteration
+    ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
+    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+    temp    = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
+
+    // 2nd iteration
+    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
+    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
+    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
+
+    // 3rd iteration
+    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
+    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
+    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
+    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
+    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
+
+    return Q6_Vsf_equals_Vqf32(temp);
+}
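+
+// Scalar sketch of the same scheme (the classic fast inverse square root seed
+// followed by three Newton-Raphson steps); documentation only, not compiled:
+//
+//   static float ref_rsqrt_f32(float x) {
+//       uint32_t bits;
+//       memcpy(&bits, &x, sizeof(bits));
+//       bits = RSQRT_CONST - (bits >> 1);           // initial guess
+//       float y;
+//       memcpy(&y, &bits, sizeof(y));
+//       for (int i = 0; i < 3; i++) {
+//           y = y * (1.5f - 0.5f * x * y * y);      // Newton-Raphson refinement
+//       }
+//       return y;
+//   }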
+
+// Compute sqrt(x) as x*inv_sqrt(x)
+#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store)                \
+    do {                                                                     \
+        dst_type * restrict vdst = (dst_type *) dst;                         \
+        src_type * restrict vsrc = (src_type *) src;                         \
+                                                                             \
+        const uint32_t nvec = n / VLEN_FP32;                                 \
+        const uint32_t nloe = n % VLEN_FP32;                                 \
+                                                                             \
+        uint32_t i = 0;                                                      \
+                                                                             \
+        _Pragma("unroll(4)")                                                 \
+        for (; i < nvec; i++) {                                              \
+            HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]);                \
+            HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]);             \
+            vdst[i] = sqrt_res;                                              \
+        }                                                                    \
+        if (nloe) {                                                          \
+            HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]);                \
+            HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]);             \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res);      \
+        }                                                                    \
+    } while(0)
+
+static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
+    if ((unsigned long) dst % 128 == 0) {
+        if ((unsigned long) src % 128 == 0) {
+            hvx_sqrt_f32_aa(dst, src, num_elems);
+        } else {
+            hvx_sqrt_f32_au(dst, src, num_elems);
+        }
+    } else {
+        if ((unsigned long) src % 128 == 0) {
+            hvx_sqrt_f32_ua(dst, src, num_elems);
+        } else {
+            hvx_sqrt_f32_uu(dst, src, num_elems);
+        }
+    }
+}
+
+#endif /* HVX_SQRT_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-types.h b/ggml/src/ggml-hexagon/htp/hvx-types.h
new file mode 100644
index 00000000..d495a59f
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-types.h
@@ -0,0 +1,36 @@
+#ifndef HVX_TYPES_H
+#define HVX_TYPES_H
+
+#include <stdint.h>
+
+#include <hexagon_types.h>
+
+#define SIZEOF_FP32 (4)
+#define SIZEOF_FP16 (2)
+#define VLEN        (128)
+#define VLEN_FP32   (VLEN / SIZEOF_FP32)
+#define VLEN_FP16   (VLEN / SIZEOF_FP16)
+
+typedef union {
+    HVX_Vector v;
+    uint8_t    b[VLEN];
+    uint16_t   h[VLEN_FP16];
+    uint32_t   w[VLEN_FP32];
+    __fp16     fp16[VLEN_FP16];
+    float      fp32[VLEN_FP32];
+} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
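+
+// Example (hypothetical): HVX_VectorAlias gives lane-level access to a vector
+// value without pointer casts, e.g. when dumping lanes for debugging:
+//
+//   HVX_VectorAlias u;
+//   u.v = some_vec;
+//   float lane0 = u.fp32[0];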
+
+typedef struct {
+    HVX_Vector v[2];
+} HVX_Vector_x2;
+
+typedef struct {
+    HVX_Vector v[4];
+} HVX_Vector_x4;
+
+typedef struct {
+    HVX_Vector v[8];
+} HVX_Vector_x8;
+
+#endif /* HVX_TYPES_H */
diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c
deleted file mode 100644
index 29d73b86..00000000
--- a/ggml/src/ggml-hexagon/htp/hvx-utils.c
+++ /dev/null
@@ -1,1020 +0,0 @@
-#pragma clang diagnostic ignored "-Wunused-variable"
-#pragma clang diagnostic ignored "-Wunused-function"
-#pragma clang diagnostic ignored "-Wunused-but-set-variable"
-
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
-
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-
-#define GGML_COMMON_DECL_C
-#include "ggml-common.h"
-#include "hvx-utils.h"
-
-#define htp_binary_ops_preamble                                                                                \
-    int step_of_4 = num_elems >> 7;                                                                            \
-    int step_of_2 = (num_elems - step_of_4 * VLEN_FP32 * 4) >> 6;                                              \
-    int step_of_1 = (num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2) >> 5;                  \
-    int remaining = num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; \
-                                                                                                               \
-    const uint8_t * restrict src0_curr = src0;                                                                 \
-    const uint8_t * restrict src1_curr = src1;                                                                 \
-    uint8_t * restrict dst_curr        = dst;
-
-void hvx_mul_f32(const uint8_t * restrict src0,
-                 const uint8_t * restrict src1,
-                 uint8_t * restrict dst,
-                 const int num_elems) {
-    int left_over       = num_elems & (VLEN_FP32 - 1);
-    int num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
-        (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_mul_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-
-    bool handled_leftover = false;
-    if (0 == unaligned_loop) {
-        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
-        HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
-        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, *vec_in2++);
-            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
-        }
-    } else {
-        int step_of_1 = num_elems_whole >> 5;  // divby 32, because 32 float = 128 bytes per HVX vector
-        int leftover_size = left_over * sizeof(float);
-
-
-        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
-        HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
-        HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
-
-        HVX_Vector slinep;
-        HVX_Vector slinec;
-        HVX_Vector sline;
-        HVX_Vector sline2p;
-        HVX_Vector sline2c;
-        HVX_Vector sline2;
-
-        slinep  = *vec_in1++;
-        sline2p = *vec_in2++;
-        #pragma unroll(4)
-        for (int i = step_of_1 - 1; i > 0; i--) {
-            slinec  = *vec_in1++;
-            sline2c = *vec_in2++;
-            sline   = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
-            sline2  = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
-
-            *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2));
-            slinep                         = slinec;
-            sline2p                        = sline2c;
-        }
-        if (step_of_1 > 1) {
-            slinec  = htp_is_aligned(vec_in1, VLEN) && left_over == 0 ? slinep : *vec_in1++;
-            sline2c = htp_is_aligned(vec_in2, VLEN) && left_over == 0 ? sline2p : *vec_in2++;
-
-            sline                          = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
-            sline2                         = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
-            *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2));
-            slinep                         = slinec;
-            sline2p                        = sline2c;
-        }
-        if (left_over > 0) {
-            slinec = (is_in_one_chunk(vec_in1, leftover_size, VLEN) ? slinep : *vec_in1++);
-
-            sline   = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
-            sline2c = (is_in_one_chunk(vec_in2, leftover_size, VLEN) ? sline2p : *vec_in2++);
-            sline2  = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
-
-            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(sline, sline2);
-            hvx_vec_store_u(vec_out, leftover_size, Q6_Vsf_equals_Vqf32(out));
-            handled_leftover = true;
-        }
-    }
-
-
-    if (left_over > 0 && !handled_leftover) {
-        const float * src0f = (const float *) src0 + num_elems_whole;
-        const float * src1f = (const float *) src1 + num_elems_whole;
-        float *       dstf  = (float *) dst + num_elems_whole;
-
-        HVX_Vector in1 = *(HVX_UVector *) src0f;
-        HVX_Vector in2 = *(HVX_UVector *) src1f;
-
-        HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
-    }
-}
-
-void hvx_mul_f32_opt(const uint8_t * restrict src0,
-                     const uint8_t * restrict src1,
-                     uint8_t * restrict dst,
-                     const int num_elems) {
-    htp_binary_ops_preamble;
-
-    for (int i = 0; i < step_of_4; i++) {
-        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
-
-        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
-
-        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
-
-        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
-
-        HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
-
-        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
-
-        HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
-
-        HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
-
-        src0_curr += 4 * VLEN;
-
-        HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(v3a, v3b);
-
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
-
-        HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
-
-        *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
-
-        HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v4a, v4b);
-
-        src1_curr += 4 * VLEN;
-
-        *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
-
-        dst_curr += 4 * VLEN;
-    }
-
-    for (int i = 0; i < step_of_2; i++) {
-        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
-
-        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
-
-        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
-
-        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
-
-        src0_curr += 2 * VLEN;
-
-        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
-
-        src1_curr += 2 * VLEN;
-
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
-
-        dst_curr += 2 * VLEN;
-    }
-
-    for (int i = 0; i < step_of_1; i++) {
-        HVX_Vector va = *(HVX_Vector *) src0_curr;
-
-        src0_curr += VLEN;
-
-        HVX_Vector vb = *(HVX_Vector *) src1_curr;
-
-        src1_curr += VLEN;
-
-        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(va, vb);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
-
-        dst_curr += VLEN;
-    }
-
-    if (remaining > 0) {
-        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
-        hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
-    }
-}
-
-void hvx_mul_mul_f32_opt(const uint8_t * restrict src0,
-                         const uint8_t * restrict src1,
-                         const uint8_t * restrict src2,
-                         uint8_t * restrict dst,
-                         const int num_elems) {
-    const uint8_t * restrict src0_curr = src0;
-    const uint8_t * restrict src1_curr = src1;
-    const uint8_t * restrict src2_curr = src2;
-    uint8_t * restrict dst_curr        = dst;
-
-    int step_of_2 = num_elems >> 6;
-    int step_of_1 = (num_elems - step_of_2 * VLEN_FP32 * 2) >> 5;
-    int remaining = num_elems - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32;
-
-    for (int i = 0; i < step_of_2; i++) {
-        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
-        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
-        HVX_Vector v1c = *(HVX_Vector *) src2_curr;
-
-        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v1_ = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
-        HVX_Vector v1  = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1_), v1c);
-
-        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
-
-        HVX_Vector v2c = *(HVX_Vector *) (src2_curr + VLEN);
-
-        src0_curr += 2 * VLEN;
-
-        HVX_Vector v2_ = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
-        HVX_Vector v2  = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2_), v2c);
-
-        src1_curr += 2 * VLEN;
-        src2_curr += 2 * VLEN;
-
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
-
-        dst_curr += 2 * VLEN;
-    }
-    for (int i = 0; i < step_of_1; i++) {
-        HVX_Vector va = *(HVX_Vector *) src0_curr;
-        src0_curr += VLEN;
-
-        HVX_Vector vb = *(HVX_Vector *) src1_curr;
-        src1_curr += VLEN;
-
-        HVX_Vector vc = *(HVX_Vector *) src2_curr;
-        src2_curr += VLEN;
-
-        HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(va, vb);
-        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), vc);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v2);
-        dst_curr += VLEN;
-    }
-    if (remaining > 0) {
-        HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
-        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), *(HVX_Vector *) src2_curr);
-        hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v2));
-    }
-}
-
-void hvx_add_f32(const uint8_t * restrict src0,
-                 const uint8_t * restrict src1,
-                 uint8_t * restrict dst,
-                 const int num_elems) {
-    int left_over       = num_elems & (VLEN_FP32 - 1);
-    int num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
-        (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_add_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_add_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
-        HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
-        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, *vec_in2++);
-            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
-            HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
-
-            HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2);
-
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
-        }
-    }
-
-    if (left_over > 0) {
-        const float * src0f = (const float *) src0 + num_elems_whole;
-        const float * src1f = (const float *) src1 + num_elems_whole;
-        float *       dstf  = (float *) dst + num_elems_whole;
-
-        HVX_Vector in1 = *(HVX_UVector *) src0f;
-        HVX_Vector in2 = *(HVX_UVector *) src1f;
-
-        HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2);
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
-    }
-}
-
-void hvx_add_f32_opt(const uint8_t * restrict src0,
-                     const uint8_t * restrict src1,
-                     uint8_t * restrict dst,
-                     const int num_elems) {
-    htp_binary_ops_preamble;
-
-    for (int i = 0; i < step_of_4; i++) {
-        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
-
-        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
-
-        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b);
-
-        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
-
-        HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
-
-        HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
-
-        HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
-
-        HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
-
-        src0_curr += 4 * VLEN;
-
-        HVX_Vector v3 = Q6_Vqf32_vadd_VsfVsf(v3a, v3b);
-
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
-
-        HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
-
-        *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
-
-        HVX_Vector v4 = Q6_Vqf32_vadd_VsfVsf(v4a, v4b);
-
-        src1_curr += 4 * VLEN;
-
-        *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
-
-        dst_curr += 4 * VLEN;
-    }
-    for (int i = 0; i < step_of_2; i++) {
-        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
-
-        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
-
-        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b);
-
-        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
-
-        src0_curr += 2 * VLEN;
-
-        HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b);
-
-        src1_curr += 2 * VLEN;
-
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
-
-        dst_curr += 2 * VLEN;
-    }
-    for (int i = 0; i < step_of_1; i++) {
-        HVX_Vector va = *(HVX_Vector *) src0_curr;
-
-        src0_curr += VLEN;
-
-        HVX_Vector vb = *(HVX_Vector *) src1_curr;
-
-        src1_curr += VLEN;
-
-        HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(va, vb);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
-
-        dst_curr += VLEN;
-    }
-    if (remaining > 0) {
-        HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
-        hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
-    }
-}
-
-void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
-    size_t left_over       = num_elems & (VLEN_FP32 - 1);
-    size_t num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_add_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    static const float kInf    = INFINITY;
-    const HVX_Vector   inf     = hvx_vec_splat_fp32(kInf);
-    HVX_Vector         val_vec = hvx_vec_splat_fp32(val);
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
-        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector           in       = *vec_in1++;
-            const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
-            HVX_Vector           v        = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
-            v                             = Q6_Vsf_equals_Vqf32(v);
-            v                             = Q6_V_vmux_QVV(pred_inf, inf, v);
-            *vec_out++                    = v;
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-
-            const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
-            HVX_Vector           out      = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
-            out                           = Q6_Vsf_equals_Vqf32(out);
-            out                           = Q6_V_vmux_QVV(pred_inf, inf, out);
-
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = out;
-        }
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (const float *) src + num_elems_whole;
-        float *       dstf = (float *) dst + num_elems_whole;
-
-        HVX_Vector in = *(HVX_UVector *) srcf;
-
-        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
-        HVX_Vector           out      = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
-        out                           = Q6_Vsf_equals_Vqf32(out);
-        out                           = Q6_V_vmux_QVV(pred_inf, inf, out);
-
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
-    }
-}
-
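The predicate/mux pair above is there so that +INF lanes pass through untouched instead of going through the qf32 add and back-conversion. A scalar sketch of the intended semantics (plain C; the vector code compares word bit patterns against a +INF splat, modeled here with isinf plus a sign check):

```c
#include <math.h>
#include <stdio.h>

/* Add val to every element, but leave +inf inputs exactly as they are. */
static void add_scalar_keep_inf(float * dst, const float * src, float val, int n) {
    for (int i = 0; i < n; i++) {
        dst[i] = (isinf(src[i]) && src[i] > 0.0f) ? src[i] : src[i] + val;
    }
}

int main(void) {
    float in[4] = { 1.0f, INFINITY, -2.0f, 0.5f }, out[4];
    add_scalar_keep_inf(out, in, 3.0f, 4);
    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  /* 4 inf 1 3.5 */
    return 0;
}
```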
-void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
-    size_t left_over       = num_elems & (VLEN_FP32 - 1);
-    size_t num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_mul_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_mul_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    HVX_Vector val_vec = hvx_vec_splat_fp32(val);
-    bool handled_leftover = false;
-    if (0 == unaligned_loop) {
-        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
-        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, val_vec);
-            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
-        }
-    } else {
-        int step_of_1 = num_elems >> 5;  // divide by 32: 32 floats = 128 bytes = one HVX vector
-        int leftover_size = left_over * sizeof(float);
-
-        HVX_Vector *  input_v_ptr  = (HVX_Vector *) src;
-        HVX_UVector * output_v_ptr = (HVX_UVector *) dst;
-
-        HVX_Vector slinep;
-        HVX_Vector slinec;
-        HVX_Vector sline;
-
-        slinep = *input_v_ptr++;
-
-        #pragma unroll(4)
-        for (int i = step_of_1 - 1; i > 0; i--) {
-            slinec                              = *input_v_ptr++;
-            sline                               = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
-            *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
-            /* Prepare slinep for next iteration */
-            slinep                              = slinec;
-        }
-
-        if (step_of_1 > 0) {
-            slinec = htp_is_aligned(input_v_ptr, VLEN) && left_over == 0 ? slinep : *input_v_ptr++;
-            sline  = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
-            *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
-
-            slinep = slinec;
-        }
-
-        if (leftover_size > 0) {
-            slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++);
-
-            sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
-
-            HVX_Vector sout = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
-            hvx_vec_store_u(output_v_ptr, leftover_size, sout);
-            handled_leftover = true;
-        }
-    }
-
-    if (left_over > 0 && !handled_leftover) {
-        const float * srcf = (const float *) src + num_elems_whole;
-        float *       dstf = (float *) dst + num_elems_whole;
-
-        HVX_Vector in = *(HVX_UVector *) srcf;
-
-        HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
-    }
-}
-
-void hvx_sub_f32(const uint8_t * restrict src0,
-                 const uint8_t * restrict src1,
-                 uint8_t * restrict dst,
-                 const int num_elems) {
-    size_t left_over       = num_elems & (VLEN_FP32 - 1);
-    size_t num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
-        (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_sub_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_sub_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
-        HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
-        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, *vec_in2++);
-            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
-            HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
-
-            HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2);
-
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
-        }
-    }
-
-    if (left_over > 0) {
-        const float * src0f = (const float *) src0 + num_elems_whole;
-        const float * src1f = (const float *) src1 + num_elems_whole;
-        float *       dstf  = (float *) dst + num_elems_whole;
-
-        HVX_Vector in1 = *(HVX_UVector *) src0f;
-        HVX_Vector in2 = *(HVX_UVector *) src1f;
-
-        HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2);
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
-    }
-}
-
-void hvx_sub_f32_opt(const uint8_t * restrict src0,
-                     const uint8_t * restrict src1,
-                     uint8_t * restrict dst,
-                     const int num_elems) {
-    htp_binary_ops_preamble;
-
-    for (int i = 0; i < step_of_4; i++) {
-        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
-
-        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
-
-        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b);
-
-        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
-
-        HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
-
-        HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
-
-        HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
-
-        HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
-
-        src0_curr += 4 * VLEN;
-
-        HVX_Vector v3 = Q6_Vqf32_vsub_VsfVsf(v3a, v3b);
-
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
-
-        HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
-
-        *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
-
-        HVX_Vector v4 = Q6_Vqf32_vsub_VsfVsf(v4a, v4b);
-
-        src1_curr += 4 * VLEN;
-
-        *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
-
-        dst_curr += 4 * VLEN;
-    }
-    for (int i = 0; i < step_of_2; i++) {
-        HVX_Vector v1a = *(HVX_Vector *) src0_curr;
-
-        HVX_Vector v1b = *(HVX_Vector *) src1_curr;
-
-        HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b);
-
-        HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
-
-        src0_curr += 2 * VLEN;
-
-        HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b);
-
-        src1_curr += 2 * VLEN;
-
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
-
-        dst_curr += 2 * VLEN;
-    }
-    for (int i = 0; i < step_of_1; i++) {
-        HVX_Vector va = *(HVX_Vector *) src0_curr;
-
-        src0_curr += VLEN;
-
-        HVX_Vector vb = *(HVX_Vector *) src1_curr;
-
-        src1_curr += VLEN;
-
-        HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(va, vb);
-
-        *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
-
-        dst_curr += VLEN;
-    }
-    if (remaining > 0) {
-        HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
-        hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
-    }
-}
-
-void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
-    size_t left_over       = num_elems & (VLEN_FP32 - 1);
-    size_t num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_sub_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_sub_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    HVX_Vector val_vec = hvx_vec_splat_fp32(val);
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
-        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, val_vec);
-            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-
-            HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec);
-
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
-        }
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (const float *) src + num_elems_whole;
-        float *       dstf = (float *) dst + num_elems_whole;
-
-        HVX_Vector in = *(HVX_UVector *) srcf;
-
-        HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec);
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
-    }
-}
-
-float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
-    int left_over       = num_elems & (VLEN_FP32 - 1);
-    int num_elems_whole = num_elems - left_over;
-
-    if (0 == htp_is_aligned((void *) src, VLEN)) {
-        FARF(HIGH, "hvx_sum_of_squares_f32: unaligned address in hvx op, possibly slower execution\n");
-    }
-
-    assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
-
-    HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
-
-    HVX_Vector sum_vec_acc = Q6_V_vsplat_R(0x00000000);
-    HVX_Vector zero_vec    = Q6_V_vsplat_R(0x00000000);
-
-    #pragma unroll(4)
-    for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1, *vec_in1);
-        sum_vec_acc  = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, v);
-        vec_in1++;
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (const float *) src + num_elems_whole;
-
-        HVX_Vector vec_left = *(HVX_UVector *) srcf;
-
-        HVX_Vector vec_left_sq = Q6_Vqf32_vmpy_VsfVsf(vec_left, vec_left);
-        HVX_Vector vec_tmp     = Q6_V_valign_VVR(vec_left_sq, zero_vec, left_over * SIZEOF_FP32);
-
-        sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, vec_tmp);
-    }
-
-    HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec_acc);
-    return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
-}
-
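For reference, the value hvx_sum_of_squares_f32 computes is simply the dot product of the buffer with itself; the vector version accumulates per lane in qf32, masks the tail by aligning it against a zero vector, and reduces across lanes at the end. A plain-C equivalent for checking results:

```c
#include <stdio.h>

static float sum_of_squares_f32_ref(const float * x, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; i++) {
        acc += x[i] * x[i];
    }
    return acc;
}

int main(void) {
    float x[5] = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f };
    printf("%g\n", sum_of_squares_f32_ref(x, 5));  /* 55 */
    return 0;
}
```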
-float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
-    int left_over       = num_elems & (VLEN_FP32 - 1);
-    int num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if (0 == htp_is_aligned((void *) src, VLEN)) {
-        FARF(HIGH, "hvx_self_sum_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_self_sum_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    HVX_Vector sum_vec  = Q6_V_vsplat_R(0x00000000);
-    HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000);
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * vec_in = (HVX_Vector *) src;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, *vec_in++);
-            sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), *vec_in++);
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-
-            sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), in);
-        }
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (const float *) src + num_elems_whole;
-
-        HVX_Vector vec_left = *(HVX_UVector *) srcf;
-        HVX_Vector vec_tmp  = Q6_V_valign_VVR(vec_left, zero_vec, left_over * SIZEOF_FP32);
-        // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, vec_tmp);
-        sum_vec             = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), vec_tmp);
-    }
-
-    HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec);
-    return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
-}
-
-float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
-    int left_over       = num_elems & (VLEN_FP32 - 1);
-    int num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if (0 == htp_is_aligned((void *) src, VLEN)) {
-        FARF(HIGH, "hvx_self_max_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_self_max_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    HVX_Vector vec_max   = hvx_vec_splat_fp32(((const float *) src)[0]);
-    HVX_Vector vec_first = hvx_vec_splat_fp32(((const float *) src)[0]);
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * restrict vec_in = (HVX_Vector *) src;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, *vec_in++);
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-
-            vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, in);
-        }
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (const float *) src + num_elems_whole;
-
-        HVX_Vector in = *(HVX_UVector *) srcf;
-
-        HVX_Vector temp = Q6_V_valign_VVR(in, vec_first, left_over * SIZEOF_FP32);
-        vec_max         = Q6_Vsf_vmax_VsfVsf(vec_max, temp);
-    }
-
-    HVX_Vector v = hvx_vec_reduce_max_fp32(vec_max);
-    return hvx_vec_get_fp32(v);
-}
-
-void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
-    size_t left_over       = num_elems & (VLEN_FP32 - 1);
-    size_t num_elems_whole = num_elems - left_over;
-    int unalign_address = 0;
-    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
-        unalign_address = 1;
-    }
-
-    const float * src_f = (const float *) src;
-
-    HVX_Vector vec_min = hvx_vec_splat_fp32(val);
-
-    if (unalign_address == 0) {
-        HVX_Vector * restrict vec_in  = (HVX_Vector *) src;
-        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
-            *vec_out++           = min_clamp;
-        }
-    } else {
-        HVX_UVector * restrict vec_in  = (HVX_UVector *) src;
-        HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
-            *vec_out++           = min_clamp;
-        }
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (const float *) src + num_elems_whole;
-        float *       dstf = (float *) dst + num_elems_whole;
-
-        HVX_UVector in = *(HVX_UVector *) srcf;
-
-        HVX_UVector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, in);
-
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, (min_clamp));
-    }
-}
-
-void hvx_clamp_scalar_f32(const uint8_t * restrict src,
-                          const float limit_left,
-                          const float limit_right,
-                          uint8_t * restrict dst,
-                          const int num_elems) {
-    size_t left_over       = num_elems & (VLEN_FP32 - 1);
-    size_t num_elems_whole = num_elems - left_over;
-
-    int unalign_address = 0;
-    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
-        unalign_address = 1;
-    }
-
-    HVX_Vector range_left  = hvx_vec_splat_fp32(limit_left);
-    HVX_Vector range_right = hvx_vec_splat_fp32(limit_right);
-
-    if (unalign_address == 0) {
-        HVX_Vector * restrict vec_in  = (HVX_Vector *) src;
-        HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in_vec = *vec_in++;
-            HVX_Vector temp_v = in_vec;
-
-            HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
-            HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
-
-            in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
-            in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
-
-            *vec_out++ = in_vec;
-        }
-    } else {
-        HVX_UVector * restrict vec_in  = (HVX_UVector *) src;
-        HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in_vec = *vec_in++;
-            HVX_Vector temp_v = in_vec;
-
-            HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
-            HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
-
-            in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
-            in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
-
-            *vec_out++ = in_vec;
-        }
-
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (const float *) src + num_elems_whole;
-        float *       dstf = (float *) dst + num_elems_whole;
-
-        HVX_Vector in_vec = *(HVX_UVector *) srcf;
-
-        HVX_Vector temp_v = in_vec;
-
-        HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
-        HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
-
-        in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
-        in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
-
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
-    }
-}
-
-
diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h
index 22876e6d..08343798 100644
--- a/ggml/src/ggml-hexagon/htp/hvx-utils.h
+++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h
@@ -1,1353 +1,26 @@
 #ifndef HVX_UTILS_H
 #define HVX_UTILS_H
 
-#include "ops-utils.h"
+#include "hex-utils.h"
 
-#include 
-#include 
+#include "hvx-types.h"
+#include "hvx-copy.h"
+#include "hvx-scale.h"
+#include "hvx-exp.h"
+#include "hvx-inverse.h"
+#include "hvx-reduce.h"
+#include "hvx-sigmoid.h"
+#include "hvx-sqrt.h"
+#include "hvx-arith.h"
+#include "hvx-div.h"
+#include "hvx-base.h"
 
-#define SIZEOF_FP32 (4)
-#define SIZEOF_FP16 (2)
-#define VLEN        (128)
-#define VLEN_FP32   (VLEN / SIZEOF_FP32)
-#define VLEN_FP16   (VLEN / SIZEOF_FP16)
-
-typedef union {
-    HVX_Vector v;
-    uint8_t    b[VLEN];
-    uint16_t   h[VLEN_FP16];
-    uint32_t   w[VLEN_FP32];
-    __fp16     fp16[VLEN_FP16];
-    float      fp32[VLEN_FP32];
-} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
-
-/* Q6_Vsf_equals_Vw is only available on v73+.*/
-#if __HVX_ARCH__ < 73
-static inline HVX_Vector int32_to_qfloat(HVX_Vector const in)
-{
-    HVX_Vector const vzero = Q6_V_vzero();
-    HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
-    HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
-    HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
-    HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
-    HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
-    HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
-    return ret;
-}
-
-static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
-{
-    return Q6_Vsf_equals_Vqf32(int32_to_qfloat(in));
-}
+#ifndef GATHER_TYPE
+#    if defined(__hexagon__)
+#        define GATHER_TYPE(_a) (intptr_t) _a
+#    else
+#        define GATHER_TYPE(_a) (HVX_Vector *) _a
+#    endif
 #endif
 
-static inline HVX_Vector hvx_vec_splat_fp32(float v) {
-    union {
-        float    f;
-        uint32_t i;
-    } fp32 = { .f = v };
-
-    return Q6_V_vsplat_R(fp32.i);
-}
-
-static inline HVX_Vector hvx_vec_splat_fp16(float v) {
-    union {
-        __fp16   f;
-        uint16_t i;
-    } fp16 = { .f = v };
-
-    return Q6_Vh_vsplat_R(fp16.i);
-}
-
-static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
-    // Rotate as needed.
-    v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
-
-    uint32_t left_off  = (size_t) addr & 127;
-    uint32_t right_off = left_off + n;
-
-    HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) addr);
-    HVX_VectorPred qr     = Q6_Q_vsetq2_R(right_off);
-
-    if (right_off > 128) {
-        Q6_vmem_QRIV(qr, (HVX_Vector *) addr + 1, v);
-        // all 1's
-        qr = Q6_Q_vcmp_eq_VbVb(v, v);
-    }
-
-    ql_not = Q6_Q_or_QQn(ql_not, qr);
-    Q6_vmem_QnRIV(ql_not, (HVX_Vector *) addr, v);
-}
-
-static inline void hvx_vec_store_a(void * ptr, size_t n, HVX_Vector v) {
-    assert((unsigned long) ptr % 128 == 0);
-
-    HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) ptr);
-    HVX_VectorPred qr     = Q6_Q_vsetq2_R(n);
-    ql_not                = Q6_Q_or_QQn(ql_not, qr);
-    Q6_vmem_QnRIV(ql_not, (HVX_Vector *) ptr, v);
-}
-
-static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
-    // vdelta control to replicate first 4 bytes across all elements
-    static const uint8_t __attribute__((aligned(128))) repl[128] = {
-        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-    };
-
-    HVX_Vector ctrl = *(HVX_Vector *) repl;
-    return Q6_V_vdelta_VV(v, ctrl);
-}
-
-// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
-static inline void hvx_copy_fp16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_Vector * restrict vdst = (HVX_Vector *) dst;
-    HVX_Vector * restrict vsrc = (HVX_Vector *) src;
-
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src % 128 == 0);
-
-    uint32_t nvec = n / 64;
-    uint32_t nloe = n % 64;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        HVX_Vector v = vsrc[i];
-        vdst[i]      = v;
-    }
-
-    if (nloe) {
-        HVX_Vector v = vsrc[i];
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
-    }
-}
-
-// copy n fp16 elements : source is aligned, destination is potentially unaligned
-static inline void hvx_copy_fp16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_UVector * restrict vdst = (HVX_UVector *) dst;
-    HVX_Vector * restrict vsrc  = (HVX_Vector *) src;
-
-    assert((unsigned long) src % 128 == 0);
-
-    uint32_t nvec = n / 64;
-    uint32_t nloe = n % 64;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        HVX_Vector v = vsrc[i];
-        vdst[i]      = v;
-    }
-
-    if (nloe) {
-        HVX_Vector v = vsrc[i];
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
-    }
-}
-
-// copy n fp16 elements : source is potentially unaligned, destination is aligned
-static inline void hvx_copy_fp16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_Vector * restrict vdst  = (HVX_Vector *) dst;
-    HVX_UVector * restrict vsrc = (HVX_UVector *) src;
-
-    assert((unsigned long) dst % 128 == 0);
-
-    uint32_t nvec = n / 64;
-    uint32_t nloe = n % 64;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        HVX_Vector v = vsrc[i];
-        vdst[i]      = v;
-    }
-
-    if (nloe) {
-        HVX_Vector v = vsrc[i];
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
-    }
-}
-
-// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
-static inline void hvx_copy_fp32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_Vector * restrict vdst = (HVX_Vector *) dst;
-    HVX_Vector * restrict vsrc = (HVX_Vector *) src;
-
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src % 128 == 0);
-
-    uint32_t nvec = n / 32;
-    uint32_t nloe = n % 32;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        HVX_Vector v = vsrc[i];
-        vdst[i]      = v;
-    }
-
-    if (nloe) {
-        HVX_Vector v = vsrc[i];
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
-    }
-}
-
-// copy n fp32 elements : source is aligned, destination is unaligned
-static inline void hvx_copy_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_UVector * restrict vdst = (HVX_UVector *) dst;
-    HVX_Vector * restrict vsrc  = (HVX_Vector *) src;
-
-    assert((unsigned long) src % 128 == 0);
-
-    uint32_t nvec = n / 32;
-    uint32_t nloe = n % 32;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        HVX_Vector v = vsrc[i];
-        vdst[i]      = v;
-    }
-
-    if (nloe) {
-        HVX_Vector v = vsrc[i];
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
-    }
-}
-
-// copy n fp32 elements : source is unaligned, destination is aligned
-static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_Vector * restrict vdst  = (HVX_Vector *) dst;
-    HVX_UVector * restrict vsrc = (HVX_UVector *) src;
-
-    assert((unsigned long) dst % 128 == 0);
-
-    uint32_t nvec = n / 32;
-    uint32_t nloe = n % 32;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        HVX_Vector v = vsrc[i];
-        vdst[i]      = v;
-    }
-
-    if (nloe) {
-        HVX_Vector v = vsrc[i];
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
-    }
-}
-
-// copy n fp32 elements : source is unaligned, destination is unaligned
-static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_UVector * restrict vdst = (HVX_UVector *) dst;
-    HVX_UVector * restrict vsrc = (HVX_UVector *) src;
-
-    uint32_t nvec = n / 32;
-    uint32_t nloe = n % 32;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        HVX_Vector v = vsrc[i];
-        vdst[i]      = v;
-    }
-
-    if (nloe) {
-        HVX_Vector v = vsrc[i];
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
-    }
-}
-
-// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
-static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
-    HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
-
-    const HVX_Vector zero = Q6_V_vsplat_R(0);
-
-    uint32_t nvec = n / 64;
-    uint32_t nloe = n % 64;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        // Load y (fp32) and convert into fp16
-        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
-        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
-        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
-        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
-    }
-
-    if (nloe) {
-        // Load y (fp32) and convert into fp16
-        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
-        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
-        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
-    }
-}
-
-// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
-static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
-    HVX_Vector  * restrict vsrc = (HVX_Vector *)  src; // fp32
-
-    const HVX_Vector zero = Q6_V_vsplat_R(0);
-
-    uint32_t nvec = n / 64;
-    uint32_t nloe = n % 64;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        // Load y (fp32) and convert into fp16
-        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
-        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
-        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
-        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
-    }
-
-    if (nloe) {
-        // Load y (fp32) and convert into fp16
-        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
-        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
-        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
-    }
-}
-
-// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
-static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    HVX_Vector  * restrict vdst = (HVX_Vector *)  dst; // fp16
-    HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
-
-    const HVX_Vector zero = Q6_V_vsplat_R(0);
-
-    uint32_t nvec = n / 64;
-    uint32_t nloe = n % 64;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        // Load y (fp32) and convert into fp16
-        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
-        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
-        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
-        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
-    }
-
-    if (nloe) {
-        // Load y (fp32) and convert into fp16
-        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
-        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
-        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
-    }
-}
-
-// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
-static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
-    HVX_Vector * restrict vdst = (HVX_Vector *) dst;
-
-    HVX_Vector velem = hvx_vec_splat_fp32(elem);
-
-    assert((unsigned long) dst % 128 == 0);
-
-    uint32_t nvec = n / 32;
-    uint32_t nloe = n % 32;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (; i < nvec; i++) {
-        vdst[i] = velem;
-    }
-
-    if (nloe) {
-        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), velem);
-    }
-}
-
-
-/* Return whether 'n' bytes starting at 'addr' fit within a single chunk of 'chunk_size'. */
-static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
-    uint32_t left_off  = (size_t) addr & (chunk_size - 1);
-    uint32_t right_off = left_off + n;
-    return right_off <= chunk_size;
-}
-
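is_in_one_chunk is what lets the tail handling above skip reading the next vector when the remaining bytes do not cross into the following 128-byte chunk. A small standalone illustration of the predicate (plain C, same 128-byte chunk size assumed):

```c
#include <stdint.h>
#include <stdio.h>

static int is_in_one_chunk_ref(const void * addr, uint32_t n, uint32_t chunk_size) {
    uint32_t left_off = (uintptr_t) addr & (chunk_size - 1);
    return left_off + n <= chunk_size;
}

int main(void) {
    /* 64 bytes starting at offset 0 fit in one 128-byte chunk ... */
    printf("%d\n", is_in_one_chunk_ref((void *) (uintptr_t) 0x1000, 64, 128));  /* 1 */
    /* ... but 64 bytes starting at offset 112 cross into the next chunk. */
    printf("%d\n", is_in_one_chunk_ref((void *) (uintptr_t) 0x10F0, 64, 128));  /* 0 */
    return 0;
}
```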
-static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
-    HVX_VectorAlias u = { .v = v };
-
-    const uint32_t n0 = n / 16;
-    const uint32_t n1 = n % 16;
-    int            i  = 0;
-    for (; i < n0; i++) {
-        htp_dump_fp16_line(pref, u.fp16 + (16 * i), 16);
-    }
-    if (n1) {
-        htp_dump_fp16_line(pref, u.fp16 + (16 * i), n1);
-    }
-}
-
-static void hvx_vec_dump_fp16(char * pref, HVX_Vector v) {
-    hvx_vec_dump_fp16_n(pref, v, 64);
-}
-
-static void hvx_vec_dump_fp32_n(char * pref, HVX_Vector v, uint32_t n) {
-    union {
-        HVX_Vector v;
-        float      d[32];
-    } u = { .v = v };
-
-    const uint32_t n0 = n / 16;
-    const uint32_t n1 = n % 16;
-    int            i  = 0;
-    for (; i < n0; i++) {
-        htp_dump_fp32_line(pref, u.d + (16 * i), 16);
-    }
-    if (n1) {
-        htp_dump_fp32_line(pref, u.d + (16 * i), n1);
-    }
-}
-
-static void hvx_vec_dump_fp32_hmt(char * pref, HVX_Vector v) {
-    union {
-        HVX_Vector v;
-        float      d[32];
-    } u = { .v = v };
-
-    FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ...  %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1],
-         u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
-}
-
-static void hvx_vec_dump_fp32(char * pref, HVX_Vector v) {
-    hvx_vec_dump_fp32_n(pref, v, 32);
-}
-
-static void hvx_vec_dump_int32(char * pref, HVX_Vector v) {
-    union {
-        HVX_Vector v;
-        int32_t    d[32];
-    } u = { .v = v };
-
-    for (int i = 0; i < 32 / 16; i++) {
-        htp_dump_int32_line(pref, u.d + (16 * i), 16);
-    }
-}
-
-static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {
-    union {
-        HVX_Vector v;
-        int32_t    d[32];
-    } u = { .v = v };
-
-    FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],
-         u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
-}
-
-static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {
-    union {
-        HVX_Vector v;
-        int8_t     d[128];
-    } u = { .v = v };
-
-    FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],
-         u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);
-}
-
-static void hvx_vec_dump_int8(char * pref, HVX_Vector v) {
-    union {
-        HVX_Vector v;
-        int8_t     d[128];
-    } u = { .v = v };
-
-    for (int i = 0; i < 128 / 16; i++) {
-        htp_dump_int8_line(pref, u.d + (16 * i), 16);
-    }
-}
-
-static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {
-    union {
-        HVX_Vector v;
-        uint8_t    d[128];
-    } u = { .v = v };
-
-    for (int i = 0; i < 128 / 16; i++) {
-        htp_dump_uint8_line(pref, u.d + (16 * i), 16);
-    }
-}
-
-static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {
-    typedef union {
-        HVX_Vector v;
-        int8_t     d[128];
-    } U;
-
-    U u0 = { .v = v0 };
-    U u1 = { .v = v1 };
-
-    for (int i = 0; i < n; i++) {
-        if (u0.d[i] != u1.d[i]) {
-            return false;
-        }
-    }
-
-    return true;
-}
-
-static inline float hvx_vec_get_fp32(HVX_Vector v) {
-    float __attribute__((aligned(128))) x;
-    hvx_vec_store_a(&x, 4, v);
-    return x;
-}
-
-static inline HVX_Vector hvx_vec_int32_reduce_sum_n(HVX_Vector in, unsigned int n) {
-    unsigned int total = n * 4;  // total vec nbytes
-    unsigned int width = 4;      // int32
-
-    HVX_Vector sum = in, sum_t;
-    while (width < total) {
-        sum_t = Q6_V_vror_VR(sum, width);     // rotate right
-        sum   = Q6_Vw_vadd_VwVw(sum_t, sum);  // elementwise sum
-        width = width << 1;
-    }
-    return sum;
-}
-
-static inline HVX_Vector hvx_vec_int32_reduce_sum(HVX_Vector in) {
-    return hvx_vec_int32_reduce_sum_n(in, 32);
-}
-
-static inline HVX_Vector hvx_vec_qf32_reduce_sum_n(HVX_Vector in, unsigned int n) {
-    unsigned int total = n * 4;  // total vec nbytes
-    unsigned int width = 4;      // fp32 nbytes
-
-    HVX_Vector sum = in, sum_t;
-    while (width < total) {
-        sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width);  // rotate right
-        sum   = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t);             // elementwise sum
-        width = width << 1;
-    }
-    return sum;
-}
-
-static inline HVX_Vector hvx_vec_qf32_reduce_sum(HVX_Vector in) {
-    return hvx_vec_qf32_reduce_sum_n(in, 32);
-}
-
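The reduction above runs in log2(lanes) steps: rotate the vector by `width` bytes, add it to itself, and double the width; after the last step every lane holds the total. A plain-C model over 32 float lanes (the real code rotates bytes with Q6_V_vror_VR and accumulates in qf32):

```c
#include <stdio.h>

#define LANES 32

static void rotate_reduce_sum(float v[LANES]) {
    for (int step = 1; step < LANES; step <<= 1) {
        float t[LANES];
        for (int i = 0; i < LANES; i++) t[i] = v[(i + step) % LANES];  /* rotate by `step` lanes */
        for (int i = 0; i < LANES; i++) v[i] += t[i];                  /* lane-wise add          */
    }
}

int main(void) {
    float v[LANES];
    for (int i = 0; i < LANES; i++) v[i] = (float) (i + 1);
    rotate_reduce_sum(v);
    printf("%g %g\n", v[0], v[31]);  /* both print 528 = 1 + 2 + ... + 32 */
    return 0;
}
```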
-static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n) {
-    unsigned int total = n * 4;  // total vec nbytes
-    unsigned int width = 4;      // fp32 nbytes
-
-    HVX_Vector sum = in, sum_t;
-    while (width < total) {
-        sum_t = Q6_V_vror_VR(sum, width);                               // rotate right
-        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t));  // elementwise sum
-        width = width << 1;
-    }
-    return sum;
-}
-
-static inline HVX_Vector hvx_vec_fp32_reduce_sum(HVX_Vector in) {
-    return hvx_vec_fp32_reduce_sum_n(in, 32);
-}
-
-static inline HVX_Vector hvx_vec_reduce_max_fp16(HVX_Vector in) {
-    unsigned total = 128;  // total vec nbytes
-    unsigned width = 2;    // fp16 nbytes
-
-    HVX_Vector _max = in, _max_t;
-    while (width < total) {
-        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
-        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max
-        width  = width << 1;
-    }
-
-    return _max;
-}
-
-static inline HVX_Vector hvx_vec_reduce_max2_fp16(HVX_Vector in, HVX_Vector _max) {
-    unsigned total = 128;  // total vec nbytes
-    unsigned width = 2;    // fp16 nbytes
-
-    HVX_Vector _max_t;
-
-    _max = Q6_Vhf_vmax_VhfVhf(in, _max);
-    while (width < total) {
-        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
-        _max   = Q6_Vhf_vmax_VhfVhf(_max_t, _max);  // elementwise max
-        width  = width << 1;
-    }
-
-    return _max;
-}
-
-static inline HVX_Vector hvx_vec_reduce_max_fp32(HVX_Vector in) {
-    unsigned total = 128;  // total vec nbytes
-    unsigned width = 4;    // fp32 nbytes
-
-    HVX_Vector _max = in, _max_t;
-    while (width < total) {
-        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
-        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max
-        width  = width << 1;
-    }
-
-    return _max;
-}
-
-static inline HVX_Vector hvx_vec_reduce_max2_fp32(HVX_Vector in, HVX_Vector _max) {
-    unsigned total = 128;  // total vec nbytes
-    unsigned width = 4;    // fp32 nbytes
-
-    HVX_Vector _max_t;
-
-    _max = Q6_Vsf_vmax_VsfVsf(in, _max);
-    while (width < total) {
-        _max_t = Q6_V_vror_VR(_max, width);         // rotate right
-        _max   = Q6_Vsf_vmax_VsfVsf(_max_t, _max);  // elementwise max
-        width  = width << 1;
-    }
-
-    return _max;
-}
-
-static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) {
-    // abs by clearing the fp16 sign bit
-    HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
-    return Q6_V_vand_VV(v, mask);
-}
-
-static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) {
-    // neg by setting the fp16 sign bit
-    HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
-    return Q6_V_vxor_VV(v, mask);
-}
-
-static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
-    // abs by clearing the fp32 sign bit
-    HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);
-    return Q6_V_vand_VV(v, mask);
-}
-
-static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
-#if __HVX_ARCH__ > 75
-    return Q6_Vsf_vfneg_Vsf(v);
-#else
-    // neg by setting the fp32 sign bit
-    HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
-    return Q6_V_vxor_VV(v, mask);
-#endif  // __HVX_ARCH__ > 75
-}
-
-// ====================================================
-// FUNCTION: 1/(x+1)     y(0) = 1,  y(0.5) = 0.6667, y(1) = 0.5
-// Order:3; continuity: True; Ends forced: True
-// Mode: unsigned;   Result fractional bits: 14
-// Peak Error: 1.1295e-04  Rms Error: 2.8410e-05   Mean Error: 1.1370e-05
-//      32769  -32706   31252  -10589
-//      32590  -30635   22793   -4493
-//      32066  -27505   16481   -2348
-//      31205  -24054   11849   -1306
-
-static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {
-    // input is 0..0xffff representing 0.0  .. 1.0
-    HVX_Vector p;
-    p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);
-    p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);
-    p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);
-    p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);
-    return p;  // signed result, 14 fractional bits
-}
-
-// Find reciprocal of fp16.
-// (1) first, convert to fp32, multiplying by 1.0; this is done to
-//    handle denormals. Ignoring sign and zero, result should be at
-//    least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)
-//    (exponent in range [103,143])
-// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly
-// (3) put this, along with '253-exp' (exp from (1)), together to make a qf32
-// (4) convert that to fp16
-// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace
-//     the result with the max value.
-static inline HVX_Vector hvx_vec_inverse_fp16(HVX_Vector vals) {
-    HVX_Vector     em_mask  = Q6_Vh_vsplat_R(0x7FFF);
-    HVX_Vector     avals    = Q6_V_vand_VV(vals, em_mask);
-    HVX_VectorPred is_neg   = Q6_Q_vcmp_gt_VhVh(avals, vals);
-    // too small to take 1/x? for 'standard' fp16, the threshold is 0x101
-    HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);
-
-    HVX_VectorPair to_qf32  = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00));  // *1.0
-    HVX_Vector     to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));
-    HVX_Vector     to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));
-
-    // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector
-    HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));
-    // likewise extract the upper 16 from each, containing the exponents in range 103..142
-    HVX_Vector exp_u16  = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);
-    //Get exponent in IEEE 32-bit representation
-    exp_u16             = Q6_Vuh_vlsr_VuhR(exp_u16, 7);
-
-    // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane
-    // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)
-    // Use poly to transform to 1/x, with 14 fractional bits
-    //
-    HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);
-
-    HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm);  //count leading zeros
-
-    // Get mantissa for 16-bit representation
-    HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));
-
-    //Compute Reciprocal Exponent
-    HVX_Vector exp_recip =
-        Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));
-    //Convert it for 16-bit representation
-    exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));
-    exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);
-
-    //Merge exponent and mantissa for reciprocal
-    HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);
-    // map 'small' inputs to standard largest value 0x7bff
-    recip            = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);
-    // add sign back
-    recip            = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);
-    return recip;
-}
-
-#define IEEE_VSF_EXPLEN   (8)
-#define IEEE_VSF_EXPBIAS  (127)
-#define IEEE_VSF_EXPMASK  (0xFF)
-#define IEEE_VSF_MANTLEN  (23)
-#define IEEE_VSF_MANTMASK (0x7FFFFF)
-#define IEEE_VSF_MIMPMASK (0x800000)
-
-static inline HVX_Vector hvx_vec_truncate_fp32(HVX_Vector in_vec) {
-    HVX_Vector mask_mant_v  = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
-    HVX_Vector mask_impl_v  = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
-    HVX_Vector const_zero_v = Q6_V_vzero();
-
-    HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
-
-    HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
-    expval_v &= IEEE_VSF_EXPMASK;
-    expval_v -= IEEE_VSF_EXPBIAS;
-
-    // negative exp == fractional value
-    HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
-
-    HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v;         // fractional bits - exp shift
-
-    HVX_Vector mant_v = in_vec & mask_mant_v;                  // obtain mantissa
-    HVX_Vector vout   = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v);  // add implicit 1.0
-
-    vout = Q6_Vw_vasr_VwVw(vout, rshift_v);                    // shift to obtain truncated integer
-    vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout);        // expval<0 -> 0
-
-    HVX_Vector neg_vout = -vout;
-
-    vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout);  // handle negatives
-
-    return (vout);
-}
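A scalar sketch of the same truncation, assuming IEEE binary32 and |f| < 2^23 so the shift below stays non-negative (names are illustrative, not part of this patch):

#include <stdint.h>
#include <string.h>

static int32_t trunc_fp32_bits(float f) {
    uint32_t u; memcpy(&u, &f, sizeof(u));
    int32_t exp = (int32_t)((u >> 23) & 0xFF) - 127;
    if (exp < 0) return 0;                          // |f| < 1.0 -> 0
    uint32_t mant = (u & 0x7FFFFF) | 0x800000;      // restore implicit 1.0
    int32_t  v    = (int32_t)(mant >> (23 - exp));  // drop fractional bits
    return (u >> 31) ? -v : v;                      // reapply sign
}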
-
-static inline HVX_Vector hvx_vec_floor_fp32(HVX_Vector in_vec) {
-    HVX_Vector mask_mant_v    = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
-    HVX_Vector mask_impl_v    = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
-    HVX_Vector const_mnlen_v  = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);
-    HVX_Vector const_zero_v   = Q6_V_vzero();
-    HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000);  // -1 IEEE vsf
-
-    HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
-
-    HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
-    expval_v &= IEEE_VSF_EXPMASK;
-    expval_v -= IEEE_VSF_EXPBIAS;
-
-    HVX_VectorPred q_negexp     = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
-    HVX_VectorPred q_expltmn    = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);
-    HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);
-    HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);
-
-    // if expval < 0 (q_negexp)         // <0, floor is 0
-    //    if vin > 0
-    //       floor = 0
-    //    if vin < 0
-    //       floor = -1
-    // if expval < mant_len (q_expltmn) // >0, but fraction may exist
-    //    get sign (q_negative)
-    //    mask >> expval                // fraction bits to mask off
-    //    vout = ~(mask)                // apply mask to remove fraction
-    //    if (qneg)                     // negative floor is one less (more, sign bit for neg)
-    //      vout += ((impl_mask) >> expval)
-    //    if (mask && vin)
-    //      vout = vin
-    // else                             // already an integer
-    //    ;                             // no change
-
-    // compute floor
-    mask_mant_v >>= expval_v;
-    HVX_Vector neg_addin_v    = mask_impl_v >> expval_v;
-    HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);
-    HVX_Vector vout           = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);
-
-    HVX_Vector     mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v);  // chk if bits set
-    HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);
-
-    HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v);        // frac bits to clear
-    HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v);  // clear frac bits
-
-    vout = in_vec;
-    vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout);         // expval < mant_len -> fraction cleared
-    vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout);  // expval<0 x<0 -> -1
-
-    return vout;
-}
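The same bit-mask floor in scalar form, mirroring the pseudocode comment above (an illustrative sketch; IEEE binary32 assumed):

#include <stdint.h>
#include <string.h>

static float floor_fp32_bits(float f) {
    uint32_t u; memcpy(&u, &f, sizeof(u));
    int32_t exp = (int32_t)((u >> 23) & 0xFF) - 127;
    if (exp < 0) return (f < 0.0f) ? -1.0f : 0.0f;  // |f| < 1.0
    if (exp >= 23) return f;                        // already an integer
    uint32_t frac = 0x7FFFFFu >> exp;               // bits below the binary point
    if ((u & frac) == 0) return f;                  // no fractional part
    if (f < 0.0f) u += 0x800000u >> exp;            // negative floor is one step lower
    u &= ~frac;                                     // clear the fraction
    memcpy(&f, &u, sizeof(f));
    return f;
}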
-
-static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
-    // This looks complicated.
-    // Ideally should just be Q6_Vh_equals_Vhf(vin)
-    // but that instruction does not do proper rounding.
-
-    // convert to qf32, multiplying by 1.0 in the process.
-    HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));
-
-    // 'in-range' values are +/32752.
-    // add 192K to it, convert to sf
-    HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
-    HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
-    HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));
-
-    // for in-range cases, result is {163858... 229360} so the exponent is always 144.
-    // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
-    // Start by <<10 to get the final 'sign' bit in bit 15...
-    vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
-    vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
-
-    // now round down to 16
-    return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
-}
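The rounding trick above, in scalar form: adding 1.5 * 2^23 to a small float forces the hardware to round it to an integer that lands in the low mantissa bits (an illustrative sketch, valid roughly for |f| < 2^22; the name is hypothetical):

#include <stdint.h>
#include <string.h>

static int32_t round_to_i32_magic(float f) {
    float    biased = f + 12582912.0f;          // 1.5 * 2^23, rounds f to nearest int
    uint32_t u; memcpy(&u, &biased, sizeof(u));
    return (int32_t)(u & 0x7FFFFF) - 0x400000;  // mantissa field now holds n + 2^22
}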
-
-static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
-    HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);
-    HVX_Vector two_sf       = hvx_vec_splat_fp32(2.0);
-
-    // First approximation
-    HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);
-
-    HVX_Vector r_qf;
-
-    // Refine
-    r_qf = Q6_Vqf32_vmpy_VsfVsf(
-        i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));
-    r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
-        r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
-    r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
-        r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
-
-    return Q6_Vsf_equals_Vqf32(r_qf);
-}
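The scalar equivalent of the Newton-Raphson reciprocal above (same magic constant, same r <- r*(2 - r*x) refinement; the function name is illustrative):

#include <stdint.h>
#include <string.h>

static float recip_newton(float x) {
    uint32_t u; memcpy(&u, &x, sizeof(u));
    u = 0x7EEEEBB3u - u;             // initial approximation from the bit pattern
    float r; memcpy(&r, &u, sizeof(r));
    r = r * (2.0f - r * x);          // three refinement steps, as in the vector code
    r = r * (2.0f - r * x);
    r = r * (2.0f - r * x);
    return r;
}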
-
-#define FAST_SIGMOID_LOG2F (0x3fb8aa3b)  // 1.442695022
-#define FAST_SIGMOID_C1    (0x3d009076)  // 0.03138777
-#define FAST_SIGMOID_C2    (0x3e8d74bd)  // 0.276281267
-#define FAST_SIGMOID_C3    (0x3f000000)  // 0.5
-
-static inline HVX_Vector hvx_vec_fast_sigmoid_fp32(HVX_Vector v) {
-    v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
-    v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
-
-    HVX_Vector in_int = hvx_vec_truncate_fp32(Q6_Vsf_equals_Vqf32(v));
-    HVX_Vector x      = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
-    HVX_Vector xx     = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
-
-    HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
-    v1            = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
-
-    HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
-    v2            = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
-    v2            = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
-
-    HVX_Vector v3          = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
-    HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
-    v3_exponent            = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
-    v3_exponent            = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
-    v3                     = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
-
-    HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
-    HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
-
-    HVX_Vector res = hvx_vec_inverse_fp32(v5);
-    res            = Q6_Vqf32_vmpy_VsfVsf(v3, res);
-
-    return Q6_Vsf_equals_Vqf32(res);
-}
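The function being approximated here is just the logistic function; a plain scalar reference (not the fast path) for comparison:

#include <math.h>

static float sigmoid_ref(float x) {
    return 1.0f / (1.0f + expf(-x));    // the HVX code approximates this via 2^x pieces
}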
-
-#define EXP_COEFF_5 (0x39506967)  // 0.000198757 = 1/(7!)
-#define EXP_COEFF_4 (0x3AB743CE)  // 0.0013982   = 1/(6!)
-#define EXP_COEFF_3 (0x3C088908)  // 0.00833345  = 1/(5!)
-#define EXP_COEFF_2 (0x3D2AA9C1)  // 0.0416658   = 1/(4!)
-#define EXP_COEFF_1 (0x3E2AAAAA)  // 0.16666667  = 1/(3!)
-#define EXP_COEFF_0 (0x3F000000)  // 0.5         = 1/(2!)
-#define EXP_LOGN2   (0x3F317218)  // ln(2)   = 0.6931471805
-#define EXP_LOG2E   (0x3FB8AA3B)  // log2(e) = 1/ln(2) = 1.4426950408
-#define EXP_ONE     (0x3f800000)  // 1.0
-#define EXP_RANGE_R (0x41a00000)  // 20.0
-#define EXP_RANGE_L (0xc1a00000)  // -20.0
-
-static inline HVX_Vector hvx_vec_exp_fp32(HVX_Vector in_vec) {
-    HVX_Vector z_qf32_v;
-    HVX_Vector x_v;
-    HVX_Vector x_qf32_v;
-    HVX_Vector y_v;
-    HVX_Vector k_v;
-    HVX_Vector f_v;
-    HVX_Vector epsilon_v;
-    HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
-    HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
-    HVX_Vector E_const;
-    HVX_Vector zero_v = Q6_V_vzero();
-
-    // exp(x) is approximated as follows:
-    //   f = floor(x/ln(2)) = floor(x*log2(e))
-    //   epsilon = x - f*ln(2)
-    //   exp(x) = exp(epsilon+f*ln(2))
-    //          = exp(epsilon)*exp(f*ln(2))
-    //          = exp(epsilon)*2^f
-    //
-    //   Since epsilon is close to zero, it can be approximated with its Taylor series:
-    //            exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
-    //   Preserving the first eight terms, we get:
-    //            exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
-    //                   =  1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
-
-    HVX_Vector temp_v = in_vec;
-
-    // Clamp inputs to (-20.0, 20.0)
-    HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
-    HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
-
-    in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
-    in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
-
-    epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
-    epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
-
-    //    f_v is the floating point result and k_v is the integer result
-    f_v = hvx_vec_floor_fp32(epsilon_v);
-    k_v = hvx_vec_truncate_fp32(f_v);
-
-    x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
-
-    //  x = x - f_v * logn2;
-    epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
-    x_qf32_v  = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
-    // normalize before every QFloat's vmpy
-    x_qf32_v  = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
-
-    // z = x * x;
-    z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
-    z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
-
-    x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
-
-    // y = E4 + E5 * x;
-    E_const = Q6_V_vsplat_R(EXP_COEFF_5);
-    y_v     = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
-    E_const = Q6_V_vsplat_R(EXP_COEFF_4);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
-
-    // y = E3 + y * x;
-    E_const = Q6_V_vsplat_R(EXP_COEFF_3);
-    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
-
-    // y = E2 + y * x;
-    E_const = Q6_V_vsplat_R(EXP_COEFF_2);
-    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
-
-    // y = E1 + y * x;
-    E_const = Q6_V_vsplat_R(EXP_COEFF_1);
-    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
-
-    // y = E0 + y * x;
-    E_const = Q6_V_vsplat_R(EXP_COEFF_0);
-    y_v     = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
-    y_v     = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
-
-    // y = x + y * z;
-    y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
-    y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
-    y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
-
-    // y = y + 1.0;
-    y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
-
-    // insert exponents
-    //        y = ldexpf(y, k);
-    //    y_v += k_v; // qf32
-    // modify exponent
-
-    y_v = Q6_Vsf_equals_Vqf32(y_v);
-
-    // add k_v to the exponent of y_v
-    HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
-
-    y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
-    y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
-
-    // exponent cannot be negative; if it underflows, the result is set to zero
-    HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
-
-    y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
-
-    y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
-
-    return y_v;
-}
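A scalar sketch of the scheme spelled out in the comment above: split x into f*ln(2) + eps, evaluate the Taylor polynomial for exp(eps), then scale by 2^f (illustrative only; the coefficients mirror the EXP_COEFF_* constants):

#include <math.h>

static float exp_poly(float x) {
    if (x >  20.0f) x =  20.0f;              // same clamp as above
    if (x < -20.0f) x = -20.0f;
    float f   = floorf(x * 1.4426950408f);   // floor(x * log2(e))
    float eps = x - f * 0.6931471805f;       // x - f*ln(2)
    float y   = 0.000198757f;                // 1/7!
    y = y * eps + 0.0013982f;                // 1/6!
    y = y * eps + 0.00833345f;               // 1/5!
    y = y * eps + 0.0416658f;                // 1/4!
    y = y * eps + 0.16666667f;               // 1/3!
    y = y * eps + 0.5f;                      // 1/2!
    y = y * eps * eps + eps + 1.0f;          // 1 + eps + eps^2 * (...)
    return ldexpf(y, (int) f);               // * 2^f
}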
-
-#define RSQRT_CONST        0x5f3759df  // Constant for fast inverse square root calculation
-#define RSQRT_ONE_HALF     0x3f000000  // 0.5
-#define RSQRT_THREE_HALVES 0x3fc00000  // 1.5
-
-static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
-    //Algorithm :
-    //  x2 = input*0.5
-    //  y  = * (long *) &input
-    //  y  = 0x5f3759df - (y>>1)
-    //  y  = y*(threehalfs - x2*y*y)
-
-    HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
-    HVX_Vector onehalf    = Q6_V_vsplat_R(RSQRT_ONE_HALF);
-    HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
-
-    HVX_Vector x2, y, ypower2, temp;
-
-    x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
-    x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
-
-    y = Q6_Vw_vasr_VwR(in_vec, 1);
-    y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
-
-    // 1st iteration
-    ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
-    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
-    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
-    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
-    temp    = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
-
-    // 2nd iteration
-    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
-    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
-    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
-    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
-    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
-    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
-
-    // 3rd iteration
-    y       = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
-    ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
-    ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
-    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
-    temp    = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
-    temp    = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
-
-    return Q6_Vsf_equals_Vqf32(temp);
-}
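The scalar original of this routine is the well-known fast inverse square root; shown here only as a readable reference for the vector code (names are illustrative):

#include <stdint.h>
#include <string.h>

static float rsqrt_fast(float x) {
    float    x2 = 0.5f * x;
    uint32_t u; memcpy(&u, &x, sizeof(u));
    u = 0x5f3759dfu - (u >> 1);           // initial guess from the bit pattern
    float y; memcpy(&y, &u, sizeof(y));
    y = y * (1.5f - x2 * y * y);          // three Newton steps, as in the vector code
    y = y * (1.5f - x2 * y * y);
    y = y * (1.5f - x2 * y * y);
    return y;
}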
-
-static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
-                                                         HVX_Vector one,
-                                                         HVX_Vector max_exp,
-                                                         HVX_Vector min_exp) {
-    const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v);
-    const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp);
-
-    HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v);
-    out            = Q6_V_vmux_QVV(pred_max, out, one);
-    return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
-}
-
-static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) {
-    // tanh(x) = 2 * sigmoid(2x) - 1
-    HVX_Vector two = hvx_vec_splat_fp32(2.0f);
-    HVX_Vector one = hvx_vec_splat_fp32(1.0f);
-    HVX_Vector x2  = Q6_Vqf32_vmpy_VsfVsf(x, two);
-
-    static const float kMinExp = -87.f;  // 0
-    static const float kMaxExp = 87.f;   // 1
-    HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
-    HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
-
-    HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
-
-    HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
-    res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
-    return Q6_Vsf_equals_Vqf32(res);
-}
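The identity used above, restated in scalar form (illustrative sketch):

#include <math.h>

static float tanh_via_sigmoid(float x) {
    float s = 1.0f / (1.0f + expf(-2.0f * x));  // sigmoid(2x)
    return 2.0f * s - 1.0f;                     // tanh(x) = 2*sigmoid(2x) - 1
}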
-
-static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
-    int step_of_1 = num_elems >> 5;
-    int remaining = num_elems - step_of_1 * VLEN_FP32;
-
-    const HVX_Vector * restrict v_src = (HVX_Vector *) src;
-    HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;
-
-    static const float kMinExp = -87.f;  // 0
-    static const float kMaxExp = 87.f;   // 1
-
-    const HVX_Vector one     = hvx_vec_splat_fp32(1.f);
-    const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
-    const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
-
-    #pragma unroll(4)
-    for (int i = 0; i < step_of_1; i++) {
-        v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp);
-    }
-
-    if (remaining > 0) {
-        const float * srcf = ((const float *) src) + step_of_1* VLEN_FP32;
-        float *       dstf = (float *) dst + step_of_1*VLEN_FP32;
-
-        HVX_Vector in  = *(HVX_UVector *) srcf;
-        HVX_Vector out = hvx_vec_fast_sigmoid_fp32_guard(in, one, max_exp, min_exp);
-        hvx_vec_store_u((void *) dstf, remaining * SIZEOF_FP32, out);
-    }
-}
-
-static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems){
-    int step_of_1 = num_elems >> 5;  // divide by 32: 32 floats = 128 bytes per HVX vector
-    int leftover = num_elems - (step_of_1 * VLEN_FP32);
-
-    int32_t leftover_size = leftover * sizeof(float);
-
-    static const float kMinExp = -87.f;  // 0
-    static const float kMaxExp = 87.f;   // 1
-
-    const HVX_Vector one     = hvx_vec_splat_fp32(1.f);
-    const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
-    const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
-
-    const float *input = (float *)src;
-    float *output = (float *)dst;
-
-    HVX_Vector *  input_v_ptr  = (HVX_Vector *) input;
-    HVX_UVector * output_v_ptr = (HVX_UVector *) output;
-
-    HVX_Vector slinep;
-    HVX_Vector slinec;
-    HVX_Vector sline;
-
-    slinep = *input_v_ptr++;
-    #pragma unroll(4)
-    for (int i = step_of_1 - 1; i > 0; i--) {
-        slinec                              = *input_v_ptr++;
-        sline                               = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
-        *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
-        /* Prepare slinep for next iteration */
-        slinep                              = slinec;
-    }
-
-    if (step_of_1 > 0) {
-        slinec = htp_is_aligned(input_v_ptr, 128) && leftover == 0 ? slinep : *input_v_ptr++;
-        sline  = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
-        *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
-
-        slinep = slinec;
-    }
-    if (leftover > 0) {
-        slinec = (is_in_one_chunk(input_v_ptr, leftover_size, 128) ? slinep : *input_v_ptr++);
-
-        sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
-
-        HVX_Vector sout = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
-        hvx_vec_store_u(output_v_ptr, leftover_size, sout);
-    }
-}
-
-static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
-    int nvec = n / VLEN_FP32;
-    int nloe = n % VLEN_FP32;
-
-    HVX_Vector vs = hvx_vec_splat_fp32(scale);
-
-    HVX_Vector * vsrc = (HVX_Vector *) src;
-    HVX_Vector * vdst = (HVX_Vector *) dst;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (i = 0; i < nvec; ++i) {
-        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
-        vdst[i]      = Q6_Vsf_equals_Vqf32(v);
-    }
-
-    if (nloe) {
-        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
-        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
-    }
-}
-
-static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
-    int nvec = n / VLEN_FP32;
-    int nloe = n % VLEN_FP32;
-
-    HVX_Vector vs = hvx_vec_splat_fp32(scale);
-
-    HVX_UVector * vsrc = (HVX_UVector *) src;
-    HVX_UVector * vdst = (HVX_UVector *) dst;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (i = 0; i < nvec; ++i) {
-        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
-        vdst[i]      = Q6_Vsf_equals_Vqf32(v);
-    }
-
-    if (nloe) {
-        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
-        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
-    }
-}
-
-static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
-    if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
-        hvx_scale_f32_aa(dst, src, n, scale);
-    } else {
-        hvx_scale_f32_uu(dst, src, n, scale);
-    }
-}
-
-static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
-    int nvec = n / VLEN_FP32;
-    int nloe = n % VLEN_FP32;
-
-    HVX_Vector vs = hvx_vec_splat_fp32(scale);
-    HVX_Vector vo = hvx_vec_splat_fp32(offset);
-
-    HVX_Vector * vsrc = (HVX_Vector *) src;
-    HVX_Vector * vdst = (HVX_Vector *) dst;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (i = 0; i < nvec; ++i) {
-        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
-        vdst[i] = Q6_Vsf_equals_Vqf32(v);
-    }
-
-    if (nloe) {
-        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
-        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
-    }
-}
-
-static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
-    int nvec = n / VLEN_FP32;
-    int nloe = n % VLEN_FP32;
-
-    HVX_Vector vs = hvx_vec_splat_fp32(scale);
-    HVX_Vector vo = hvx_vec_splat_fp32(offset);
-
-    HVX_UVector * vsrc = (HVX_UVector *) src;
-    HVX_UVector * vdst = (HVX_UVector *) dst;
-
-    uint32_t i = 0;
-
-    #pragma unroll(4)
-    for (i = 0; i < nvec; ++i) {
-        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
-        vdst[i] = Q6_Vsf_equals_Vqf32(v);
-    }
-
-    if (nloe) {
-        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
-        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
-    }
-}
-
-static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
-    if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
-        hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
-    } else {
-        hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
-    }
-}
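For reference, the per-element semantics shared by the _aa and _uu variants above (they differ only in aligned vs. unaligned vector loads); the scalar name is illustrative:

static void scale_offset_ref(float * dst, const float * src, int n, float scale, float offset) {
    for (int i = 0; i < n; i++) {
        dst[i] = src[i] * scale + offset;    // hvx_scale_f32 is the offset == 0 case
    }
}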
-
-float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
-void  hvx_mul_f32(const uint8_t * restrict src0,
-                  const uint8_t * restrict src1,
-                  uint8_t * restrict dst,
-                  const int num_elems);
-void  hvx_mul_f32_opt(const uint8_t * restrict src0,
-                      const uint8_t * restrict src1,
-                      uint8_t * restrict dst,
-                      const int num_elems);
-void  hvx_mul_mul_f32_opt(const uint8_t * restrict src0,
-                          const uint8_t * restrict src1,
-                          const uint8_t * restrict src2,
-                          uint8_t * restrict dst,
-                          const int num_elems);
-void  hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
-void  hvx_add_f32(const uint8_t * restrict src0,
-                  const uint8_t * restrict src1,
-                  uint8_t * restrict dst,
-                  const int num_elems);
-void  hvx_add_f32_opt(const uint8_t * restrict src0,
-                      const uint8_t * restrict src1,
-                      uint8_t * restrict dst,
-                      const int num_elems);
-void  hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
-void  hvx_sub_f32(const uint8_t * restrict src0,
-                  const uint8_t * restrict src1,
-                  uint8_t * restrict dst,
-                  const int num_elems);
-void  hvx_sub_f32_opt(const uint8_t * restrict src0,
-                      const uint8_t * restrict src1,
-                      uint8_t * restrict dst,
-                      const int num_elems);
-void  hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
-void  hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
-void  hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
-void  hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
-float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems);
-float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems);
-void  hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
-void  hvx_clamp_scalar_f32(const uint8_t * restrict src,
-                           const float limit_left,
-                           const float limit_right,
-                           uint8_t * restrict dst,
-                           const int num_elems);
-
 #endif /* HVX_UTILS_H */
diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c
index 24b3e90e..3f99dbb3 100644
--- a/ggml/src/ggml-hexagon/htp/main.c
+++ b/ggml/src/ggml-hexagon/htp/main.c
@@ -1,17 +1,13 @@
 #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
 #pragma clang diagnostic ignored "-Wunused-function"
 
-#define FARF_ERROR  1
-#define FARF_HIGH   1
-#define FARF_MEDIUM 0
-#define FARF_LOW    0
+#include 
+#include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -19,13 +15,14 @@
 #include 
 #include 
 
+#include "hex-dma.h"
+#include "hex-utils.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
-#include "htp-dma.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "ops-utils.h"
 #include "worker-pool.h"
 
 AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
@@ -192,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
     // otherwise we'll release it once we're done with the current Op.
 
     if (ctx->vtcm_inuse) {
-        ctx->vtcm_needs_release = false;
+        ctx->vtcm_needs_release = true;
         return 0;
     }
 
@@ -362,14 +359,14 @@ struct profile_data {
 
 static inline void profile_start(struct profile_data * d) {
     d->usecs  = HAP_perf_get_qtimer_count();
-    d->cycles = htp_get_cycles();
-    d->pkts   = htp_get_pktcnt();
+    d->cycles = hex_get_cycles();
+    d->pkts   = hex_get_pktcnt();
 }
 
 static inline void profile_stop(struct profile_data * d) {
     d->usecs  = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
-    d->cycles = htp_get_cycles() - d->cycles;
-    d->pkts   = htp_get_pktcnt() - d->pkts;
+    d->cycles = hex_get_cycles() - d->cycles;
+    d->pkts   = hex_get_pktcnt() - d->pkts;
 }
 
 static int send_htp_rsp(struct htp_context *     c,
@@ -443,6 +440,82 @@ static void proc_matmul_req(struct htp_context *     ctx,
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We've written to the output buffer, so we also need to flush it
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_argsort(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We've written to the output buffer, so we also need to flush it
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_cpy(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
     struct dspqueue_buffer rsp_bufs[1];
 
@@ -645,6 +718,86 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+    // We've written to the output buffer, so we also need to flush it
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_sum_rows(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+    // We've written to the output buffer, so we also need to flush it
+    rsp_bufs[0].fd     = bufs[2].fd;
+    rsp_bufs[0].ptr    = bufs[2].ptr;
+    rsp_bufs[0].offset = bufs[2].offset;
+    rsp_bufs[0].size   = bufs[2].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup OP context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.dst.data  = (uint32_t) bufs[2].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_ssm_conv(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_activations_req(struct htp_context *     ctx,
                                  struct htp_general_req * req,
                                  struct dspqueue_buffer * bufs,
@@ -917,6 +1070,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
             case HTP_OP_MUL:
             case HTP_OP_ADD:
             case HTP_OP_SUB:
+            case HTP_OP_DIV:
                 if (n_bufs != 3) {
                     FARF(ERROR, "Bad binary-req buffer list");
                     continue;
@@ -934,6 +1088,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_unary_req(ctx, &req, bufs);
                 break;
 
+            case HTP_OP_SQR:
+            case HTP_OP_SQRT:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad unary-req buffer list");
+                    continue;
+                }
+
+                proc_unary_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_SUM_ROWS:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad unary-req buffer list");
+                    continue;
+                }
+
+                proc_sum_rows_req(ctx, &req, bufs);
+                break;
+
             case HTP_OP_UNARY_SILU:
             case HTP_OP_UNARY_GELU:
                 if (n_bufs != 2) {
@@ -946,6 +1119,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
             case HTP_OP_GLU_SWIGLU:
             case HTP_OP_GLU_SWIGLU_OAI:
             case HTP_OP_SOFTMAX:
+            case HTP_OP_GLU_GEGLU:
                 if ((n_bufs != 2) && (n_bufs != 3)) {
                     FARF(ERROR, "Bad act-req buffer list");
                     continue;
@@ -993,6 +1167,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_get_rows_req(ctx, &req, bufs);
                 break;
 
+            case HTP_OP_CPY:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad cpy-req buffer list");
+                    continue;
+                }
+                proc_cpy_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_ARGSORT:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad argsort-req buffer list");
+                    continue;
+                }
+                proc_argsort_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_SSM_CONV:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad ssm-conv-req buffer list");
+                    continue;
+                }
+                proc_ssm_conv_req(ctx, &req, bufs);
+                break;
+
             default:
                 FARF(ERROR, "Unknown Op %u", req.op);
                 break;
diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c
index 9bb39db9..73aaba79 100644
--- a/ggml/src/ggml-hexagon/htp/matmul-ops.c
+++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c
@@ -3,105 +3,50 @@
 #pragma clang diagnostic ignored "-Wunused-variable"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
-
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
+
 #include 
-#include 
 #include 
 
+#include "hex-dma.h"
+#include "hvx-utils.h"
+#include "hvx-dump.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
-#include "htp-dma.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
 
 #define MM_SPAD_SRC0_NROWS 16
 #define MM_SPAD_SRC1_NROWS 16
 #define MM_SPAD_DST_NROWS  2
 
-struct htp_matmul_type {
+struct htp_matmul_context {
     const char * type;
-    void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-    void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
-};
+    struct htp_ops_context * octx;
 
-typedef struct {
-    HVX_Vector v[2];
-} HVX_Vector_x2;
+    void (*vec_dot_1x1)(const int n, float * restrict s0,
+         const void * restrict vx0,
+         const void * restrict vy0);
 
-typedef struct {
-    HVX_Vector v[4];
-} HVX_Vector_x4;
+    void (*vec_dot_2x1)(const int n, float * restrict s0,
+         const void * restrict vx0, const void * restrict vx1,
+         const void * restrict vy0);
 
-typedef struct {
-    HVX_Vector v[8];
-} HVX_Vector_x8;
+    void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1,
+         const void * restrict vx0, const void * restrict vx1,
+         const void * restrict vy0, const void * restrict vy1);
 
-// vdelta control to replicate first 4x fp32 values across lanes
-static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = {
-    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
-    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
-    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
-    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
-    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
-    0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
-    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
-};
+    // Precomputed values
+    uint32_t src0_nrows_per_thread;
+    uint32_t src1_nrows_per_thread;
 
-// vdelta control to replicate and interleave first 8x fp32 values across lanes
-static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = {
-    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
-    0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
-    0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
-    0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
-    0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
-    0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
-    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
-};
-
-// vdelta control to replicate first fp32 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = {
-    0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
-    0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
-    0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
-    0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
-    0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
-    0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
-    0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-};
-
-// vdelta control to replicate first fp16 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
-    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
-    0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
-    0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
-    0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
-    0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
-    0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
-    0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-};
-
-// vdelta control to replicate first fp16 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = {
-    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-    0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+    struct fastdiv_values mm_div_ne12_ne1;
+    struct fastdiv_values mm_div_ne1;
+    struct fastdiv_values mm_div_r2;
+    struct fastdiv_values mm_div_r3;
 };
 
 // vdelta control to expand first 32 e8m0 values into 32 uint32 elements
@@ -129,10 +74,10 @@ static inline size_t q8x4x2_row_size(uint32_t ne) {
     // ensures perfect alignment of quants and full row
     const uint32_t qk = QK_Q8_0x4x2;
     const uint32_t nb = (ne + qk - 1) / qk;
-    return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128);
+    return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128);
 }
 
-static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
+static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) {
     const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
 
     HVX_Vector v0_1 = vptr[0];  // first 256 elements (128 bytes)
@@ -141,10 +86,11 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
     HVX_Vector v6_7 = vptr[3];  // ...
 
     const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+    const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
 
-    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
-    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
-    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F
+    HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F : first  128 elements
+    HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4   : second 128 elements
+    HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4);  // & 0x0F ...
     HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4);    // >> 4
     HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4);  // & 0x0F
     HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4);    // >> 4
@@ -152,21 +98,54 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) {
     HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4
 
     // Convert uint4 to int4 (i.e. x - 8)
-    const HVX_Vector i8 = Q6_Vb_vsplat_R(8);
-    v0                  = Q6_Vb_vsub_VbVb(v0, i8);
-    v1                  = Q6_Vb_vsub_VbVb(v1, i8);
-    v2                  = Q6_Vb_vsub_VbVb(v2, i8);
-    v3                  = Q6_Vb_vsub_VbVb(v3, i8);
-    v4                  = Q6_Vb_vsub_VbVb(v4, i8);
-    v5                  = Q6_Vb_vsub_VbVb(v5, i8);
-    v6                  = Q6_Vb_vsub_VbVb(v6, i8);
-    v7                  = Q6_Vb_vsub_VbVb(v7, i8);
+    v0 = Q6_Vb_vsub_VbVb(v0, i8);
+    v1 = Q6_Vb_vsub_VbVb(v1, i8);
+    v2 = Q6_Vb_vsub_VbVb(v2, i8);
+    v3 = Q6_Vb_vsub_VbVb(v3, i8);
+    v4 = Q6_Vb_vsub_VbVb(v4, i8);
+    v5 = Q6_Vb_vsub_VbVb(v5, i8);
+    v6 = Q6_Vb_vsub_VbVb(v6, i8);
+    v7 = Q6_Vb_vsub_VbVb(v7, i8);
 
     HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
     return r;
 }
 
-static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) {
+static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
+    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+    const uint32_t qk   = QK_Q4_0x4x2; // 256
+    const uint32_t nb   = n / qk;
+    const uint32_t nloe = n % qk;
+
+    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+    const HVX_Vector i8      = Q6_Vb_vsplat_R(8);
+
+    HVX_Vector_x8 r;
+    uint32_t i = 0;
+
+    #pragma unroll(2)
+    for (i=0; i < nb; i++) {
+        HVX_Vector v = vptr[i];                    // 256 elements (128 bytes)
+        HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4);  // & 0x0F : first  128 elements
+        HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4);    // >> 4   : second 128 elements
+        r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8);
+        r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8);
+    }
+
+    if (nloe) {
+        HVX_Vector v = vptr[i];                    // 256 elements (128 bytes)
+        HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4);  // & 0x0F : even 128 elements
+        HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4);    // >> 4   : odd  128 elements
+        HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
+        r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8);
+        r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8);
+    }
+
+    return r;
+}
+
+static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) {
     const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
 
     HVX_Vector v0_1 = vptr[0];  // first 256 elements (128 bytes)
@@ -175,6 +154,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
     HVX_Vector v6_7 = vptr[3];  // ...
 
     const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+    const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
 
     HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4);  // & 0x0F
     HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4);    // >> 4
@@ -185,21 +165,54 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr)
     HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4);  // & 0x0F
     HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4);    // >> 4
 
-    HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut;
-    v0             = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
-    v1             = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
-    v2             = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
-    v3             = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
-    v4             = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
-    v5             = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
-    v6             = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
-    v7             = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
+    v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
+    v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
+    v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
+    v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
+    v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
+    v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
+    v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
+    v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
 
     HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
     return r;
 }
 
-static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
+static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
+    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
+
+    const uint32_t qk   = QK_Q4_0x4x2; // 256
+    const uint32_t nb   = n / qk;
+    const uint32_t nloe = n % qk;
+
+    const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
+    const HVX_Vector lut     = *(const HVX_Vector *) kvalues_mxfp4_lut;
+
+    HVX_Vector_x8 r;
+    uint32_t i = 0;
+
+    #pragma unroll(2)
+    for (i=0; i < nb; i++) {
+        HVX_Vector v = vptr[i];                    // 256 elements (128 bytes)
+        HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4);  // & 0x0F : first  128 elements
+        HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4);    // >> 4   : second 128 elements
+        r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
+        r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
+    }
+
+    if (nloe) {
+        HVX_Vector v = vptr[i];                    // 256 elements (128 bytes)
+        HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4);  // & 0x0F : even 128 elements
+        HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4);    // >> 4   : odd  128 elements
+        HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
+        r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
+        r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
+    }
+
+    return r;
+}
+
+static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) {
     const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
 
     HVX_Vector v0 = vptr[0];  // first  128 vals
@@ -215,44 +228,8 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) {
     return r;
 }
 
-static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) {
-    const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
-
-    HVX_Vector v0 = vptr[0];  // first  64 vals
-    HVX_Vector v1 = vptr[1];  // second 64 vals
-    HVX_Vector v2 = vptr[2];  // third  64 vals
-    HVX_Vector v3 = vptr[3];  // fourth 64 vals
-
-    HVX_Vector_x4 r = { v0, v1, v2, v3 };
-    return r;
-}
-
-static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) {
-    const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr;
-
-    HVX_VectorPair v0 = vptr[0];  // first  64 vals
-    HVX_VectorPair v1 = vptr[1];  // second 64 vals
-    HVX_VectorPair v2 = vptr[2];  // third  64 vals
-    HVX_VectorPair v3 = vptr[3];  // fourth 64 vals
-
-    HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero());
-    HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero());
-    HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero());
-    HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero());
-    HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero());
-    HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero());
-    HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero());
-    HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero());
-
-    HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo));
-    HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo));
-    HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo));
-    HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo));
-
-    // vcombine does a shuffle, use vdeal to undo
-
-    HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) };
-    return r;
+static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) {
+    return hvx_vec_load_q8x4x8_full(ptr);
 }
 
 // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors).
@@ -262,14 +239,14 @@ static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict
 // if() checks are optimized out at compile time -- make sure to pass N as a constexpr.
 
 static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
-    HVX_Vector r0 = Q6_V_vsplat_R(0);
-    HVX_Vector r1 = Q6_V_vsplat_R(0);
-    HVX_Vector r2 = Q6_V_vsplat_R(0);
-    HVX_Vector r3 = Q6_V_vsplat_R(0);
-    HVX_Vector r4 = Q6_V_vsplat_R(0);
-    HVX_Vector r5 = Q6_V_vsplat_R(0);
-    HVX_Vector r6 = Q6_V_vsplat_R(0);
-    HVX_Vector r7 = Q6_V_vsplat_R(0);
+    HVX_Vector r0 = Q6_V_vzero();
+    HVX_Vector r1 = Q6_V_vzero();
+    HVX_Vector r2 = Q6_V_vzero();
+    HVX_Vector r3 = Q6_V_vzero();
+    HVX_Vector r4 = Q6_V_vzero();
+    HVX_Vector r5 = Q6_V_vzero();
+    HVX_Vector r6 = Q6_V_vzero();
+    HVX_Vector r7 = Q6_V_vzero();
 
     HVX_VectorPair p3;
     HVX_VectorPair p2;
@@ -308,40 +285,67 @@ static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, uns
 }
 
 static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) {
-    return hvx_vec_rmpy_x8_n(x, y, 1024);
+    HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]);
+    HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]);
+    HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]);
+    HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]);
+    HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]);
+    HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]);
+    HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]);
+    HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]);
+
+    HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4);
+    HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4);
+    HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4);
+    HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4);
+
+    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
+    r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
+    r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2));
+    r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3));
+
+    p0 = Q6_W_vdeal_VVR(r1, r0, -4);
+    p1 = Q6_W_vdeal_VVR(r3, r2, -4);
+
+    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
+    r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1));
+
+    p0 = Q6_W_vdeal_VVR(r1, r0, -4);
+    r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0));
+
+    return r0;
 }
 
-// Handle most common cases of tensors not multiple of 1024.
-static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
-    if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); };
-    if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); };
-    if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); };
-    return hvx_vec_rmpy_x8_n(x, y, 1024);
+static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) {
+    if (n >= 512)
+        return hvx_vec_rmpy_x8_full(x, y);
+
+    return hvx_vec_rmpy_x8_partial(x, y, 512);
 }
 
-static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
-    const uint32_t x_qblk_size = qk / 2;                                     // int4
-    const uint32_t x_qrow_size = n / 2;                                      // int4 (not padded)
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;                                      // int4
+    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
 
-    const uint32_t y_dblk_size = 8 * 4 * 2;                                  // 32x __fp16
-    const uint32_t y_qblk_size = qk;                                         // int8
-    const uint32_t y_qrow_size = n;                                          // int8 (not padded)
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size);  // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 
-    // Row sum (qf32)
-    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    // Row sum (sf)
+    HVX_Vector r0_sum = Q6_V_vzero();
 
     // Multiply and accumulate into int32.
     // Compute combined scale (fp32).
@@ -352,79 +356,77 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
 
     uint32_t i = 0;
     for (; i < nb; i++) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q    + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
 
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
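+        // add the qf32 product to the sf running sum, then normalize back to sf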
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
-    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+    // Process leftovers
     if (nloe) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q    + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 
-        // Zero out unused scales
+        // Zero out unused elements
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
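+        // vsetq keeps the first nloe/8 bytes: one 4-byte fp32 lane per 32-element sub-block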
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
-    // Reduce and convert into fp32
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
-    hvx_vec_store_u(&s[0], 4, r0_sum);
+    hvx_vec_store_u(s0, 4, r0_sum);
 }
 
-static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
-                                      float * restrict s,
-                                      const void * restrict vx,
-                                      uint32_t vx_row_size,
-                                      const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t x_qblk_size = qk / 2;                                                           // int4
-    const uint32_t x_qrow_size = n / 2;                                                            // int4 (not padded)
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;                                      // int4
+    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
 
-    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t y_qblk_size = qk;                                                               // int8
-    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 
-    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
-
-    // Row sum (qf32)
-    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
-    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+    // Row sum (sf)
+    HVX_Vector r0_sum = Q6_V_vzero();
+    HVX_Vector r1_sum = Q6_V_vzero();
 
     // Multiply and accumulate into int32.
     // Compute combined scale (fp32).
@@ -435,14 +437,14 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
 
     uint32_t i = 0;
     for (; i < nb; i++) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
-        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q    + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
 
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
         HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
         HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 
@@ -452,50 +454,178 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
-    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+    // Process leftovers
     if (nloe) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
-        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q    + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
         HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
         HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
 
-        // Zero out unused scales
+        // Zero out unused elements
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
         r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
+        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
-    // Convert into fp32 and reduce
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
-    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
-    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
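+    // reduce both rows in one pass; the first two fp32 lanes hold the row 0 and row 1 sums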
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+    hvx_vec_store_u(s0, 8, rsum);
 }
 
-static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+                                        const void * restrict vx0, const void * restrict vx1,
+                                        const void * restrict vy0, const void * restrict vy1) {
+    assert(n % 32 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
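+    // elements consumed per loop iteration (one full 8-vector block)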
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;                                      // int4
+    const uint32_t x_qrow_size = n / 2;                                       // int4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
+
+    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
+    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
+    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
+    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
+
+    // Row sums (sf) - 4 accumulators for 2×2 tile
+    HVX_Vector r0_c0_sum = Q6_V_vzero();
+    HVX_Vector r0_c1_sum = Q6_V_vzero();
+    HVX_Vector r1_c0_sum = Q6_V_vzero();
+    HVX_Vector r1_c1_sum = Q6_V_vzero();
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    const uint32_t nloe = n % qk;  // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        // Load src1 columns (reused across both src0 rows)
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
+
+        // Load src0 rows (reused across both src1 columns)
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size);
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
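+        // (vrmpy yields int32 sums; Q6_Vsf_equals_Vw converts them to fp32 before the block scales are applied)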
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+        // Load scales
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d   + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d   + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        // Compute combined scales
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Apply scales and accumulate
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q   + i * y_qblk_size, nloe);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q   + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q  = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+        HVX_Vector_x8 r1_q  = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
+
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
+
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d   + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d   + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Zero out unused scales
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Reduce and store results
+    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(s0, 8, r0_r1_c0_sum);  // row0,col0 row1,col0
+    hvx_vec_store_u(s1, 8, r0_r1_c1_sum);  // row0,col1 row1,col1
+}
+
+static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
@@ -507,14 +637,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
     const uint32_t y_qblk_size = qk;                                         // int8
     const uint32_t y_qrow_size = n;                                          // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);           // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);              // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);    // then scales
 
-    // Row sum (qf32)
-    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    // Row sum (sf)
+    HVX_Vector r0_sum = Q6_V_vzero();
 
     // Multiply and accumulate into int32.
     // Compute combined scale (fp32).
@@ -525,79 +655,77 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
 
     uint32_t i = 0;
     for (; i < nb; i++) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q    + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
 
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
-    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+    // Process leftovers
     if (nloe) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q    + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
 
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
 
-        // Zero out unused scales
+        // Zero out unused elements
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
-    // Reduce and convert into fp32
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
-    hvx_vec_store_u(&s[0], 4, r0_sum);
+    hvx_vec_store_u(s0, 4, r0_sum);
 }
 
-static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
-                                      float * restrict s,
-                                      const void * restrict vx,
-                                      uint32_t vx_row_size,
-                                      const void * restrict vy) {
+static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_Q4_0x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t x_qblk_size = qk;                                                               // int8
-    const uint32_t x_qrow_size = n;                                                                // int8 (not padded)
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk;                                          // int8
+    const uint32_t x_qrow_size = n;                                           // int8 (not padded)
 
-    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t y_qblk_size = qk;                                                               // int8
-    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 
-    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
-
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);     // then scales
 
-    // Row sum (qf32)
+    // Row sum (sf)
-    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
-    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r0_sum = Q6_V_vzero();
+    HVX_Vector r1_sum = Q6_V_vzero();
 
     // Multiply and accumulate into int32.
     // Compute combined scale (fp32).
@@ -608,14 +736,14 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
 
     uint32_t i = 0;
     for (; i < nb; i++) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
-        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q    + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
 
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
         HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
-        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
+        HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d    + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
         HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
 
@@ -625,18 +753,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
-    // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
+    // Process leftovers
     if (nloe) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
-        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q    + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe));
-        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
+        HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
 
         HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
         HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
@@ -645,33 +773,158 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
         HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
 
-        // Zero out unused scales
+        // Zero out unused elements
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
         r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
+        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
-    // Convert into fp32 and reduce
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
-    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
-    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+    hvx_vec_store_u(s0, 8, rsum);
 }
 
-static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
-                                     float * restrict s,
-                                     const void * restrict vx,
-                                     const void * restrict vy) {
+static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+                                        const void * restrict vx0, const void * restrict vx1,
+                                        const void * restrict vy0, const void * restrict vy1) {
+    assert(n % 32 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
+    const uint32_t qk = QK_Q8_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t x_qblk_size = qk;                                          // int8
+    const uint32_t x_qrow_size = n;                                           // int8 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
+
+    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
+    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
+    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
+    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
+
+    // Row sums (sf) - 4 accumulators for 2×2 tile
+    HVX_Vector r0_c0_sum = Q6_V_vzero();
+    HVX_Vector r0_c1_sum = Q6_V_vzero();
+    HVX_Vector r1_c0_sum = Q6_V_vzero();
+    HVX_Vector r1_c1_sum = Q6_V_vzero();
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    const uint32_t nloe = n % qk;  // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        // Load src1 columns (reused across both src0 rows)
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
+
+        // Load src0 rows (reused across both src1 columns)
+        HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size);
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+        // Load scales
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d   + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d   + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        // Compute combined scales
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Apply scales and accumulate
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q   + i * y_qblk_size, nloe);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q   + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q  = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+        HVX_Vector_x8 r1_q  = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
+
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
+
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d   + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d   + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Zero out unused elements
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Reduce and store results
+    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
+    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
+}
+
+static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_MXFP4x4x2 * 4;
 
@@ -683,14 +936,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
     const uint32_t y_qblk_size = qk;                                         // int8
     const uint32_t y_qrow_size = n;                                          // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0);           // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0);              // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size);    // then scales
 
-    // Row sum (qf32)
-    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
+    // Row sum (sf)
+    HVX_Vector r0_sum = Q6_V_vzero();
 
     // Multiply and accumulate into int32.
     // Compute combined scale (fp32).
@@ -701,8 +954,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
 
     uint32_t i = 0;
     for (; i < nb; i++) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(   y_q    + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
 
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
 
@@ -728,17 +981,17 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
     // Process leftovers
     if (nloe) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(   y_q    + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
 
-        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
+        HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
 
-        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d    + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
 
         // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
@@ -761,62 +1014,60 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
         // Zero-out unused scales
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
-    // Reduce and convert into fp32
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
-    hvx_vec_store_u(&s[0], 4, r0_sum);
+    hvx_vec_store_u(s0, 4, r0_sum);
 }
 
-static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
-                                         float * restrict s,
-                                         const void * restrict vx,
-                                         uint32_t vx_row_size,
-                                         const void * restrict vy) {
+static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0,
+                                      const void * restrict vx0, const void * restrict vx1,
+                                      const void * restrict vy0) {
     assert(n % 32 == 0);  // min sub-block size
-    assert((unsigned long) vx % 128 == 0);
-    assert((unsigned long) vy % 128 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
 
     const uint32_t qk = QK_MXFP4x4x2 * 4;
 
-    const uint32_t x_dblk_size = 8 * 4 * 1;                                                        // 32x e8m0
-    const uint32_t x_qblk_size = qk / 2;                                                           // fp4
-    const uint32_t x_qrow_size = n / 2;                                                            // fp4 (not padded)
+    const uint32_t x_dblk_size = 8 * 4 * 1;                                   // 32x e8m0
+    const uint32_t x_qblk_size = qk / 2;                                      // fp4
+    const uint32_t x_qrow_size = n / 2;                                       // fp4 (not padded)
 
-    const uint32_t y_dblk_size = 8 * 4 * 2;                                                        // 32x __fp16
-    const uint32_t y_qblk_size = qk;                                                               // int8
-    const uint32_t y_qrow_size = n;                                                                // int8 (not padded)
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
 
-    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
 
-    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);            // quants first
-    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size);  // then scales
+    const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0;               // quants first
+    const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size;     // then scales
 
-    const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
-    const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
-
-    // Row sum (qf32)
-    HVX_Vector r0_sum = Q6_V_vsplat_R(0);
-    HVX_Vector r1_sum = Q6_V_vsplat_R(0);
+    // Row sum (sf)
+    HVX_Vector r0_sum = Q6_V_vzero();
+    HVX_Vector r1_sum = Q6_V_vzero();
 
     // Multiply and accumulate into int32.
     // Compute combined scale (fp32).
-    // Apply scale to acc and accumulate into the row sum (qf32).
+    // Apply scale to acc and accumulate into the row sum (f32).
 
     const uint32_t nb   = n / qk;  // num full blocks
     int32_t        nloe = n % qk;  // num leftover elements (must be signed)
 
     uint32_t i = 0;
     for (; i < nb; i++) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
-        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(   y_q    + i * y_qblk_size);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
 
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
         HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
@@ -849,20 +1100,20 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
     // Process leftovers
     if (nloe) {
-        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size);
-        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
-        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
+        HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(   y_q    + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
 
         HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
         HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
 
-        HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size);
+        HVX_Vector vy_d = *(const HVX_UVector *) (y_d    + i * y_dblk_size);
         HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
         HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
 
@@ -887,111 +1138,326 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
         HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
 
-        // Zero-out unused scales
+        // Zero-out unused values
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
         r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
+        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
-    // Convert into fp32 and reduce
-    r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum));
-    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
-    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+    hvx_vec_store_u(s0, 8, rsum);
 }
 
-static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1,
+                                        const void * restrict vx0, const void * restrict vx1,
+                                        const void * restrict vy0, const void * restrict vy1) {
+    assert(n % 32 == 0);
+    assert((unsigned long) vx0 % 128 == 0);
+    assert((unsigned long) vx1 % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
+    const uint32_t qk = QK_MXFP4x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 1;                                   // 32x e8m0
+    const uint32_t x_qblk_size = qk / 2;                                      // fp4
+    const uint32_t x_qrow_size = n / 2;                                       // fp4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2;                                   // 32x __fp16
+    const uint32_t y_qblk_size = qk;                                          // int8
+    const uint32_t y_qrow_size = n;                                           // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;            // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;  // then scales
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;            // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;  // then scales
+
+    const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;              // quants first
+    const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;    // then scales
+    const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;              // quants first
+    const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;    // then scales
+
+    // Row sums (sf) - 4 accumulators for 2×2 tile
+    HVX_Vector r0_c0_sum = Q6_V_vzero();
+    HVX_Vector r0_c1_sum = Q6_V_vzero();
+    HVX_Vector r1_c0_sum = Q6_V_vzero();
+    HVX_Vector r1_c1_sum = Q6_V_vzero();
+
+    const uint32_t nb   = n / qk;  // num full blocks
+    const uint32_t nloe = n % qk;  // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        // Load src1 columns (reused across both src0 rows)
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
+
+        // Load src0 rows (reused across both src1 columns)
+        HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size);
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+        // Load scales
+        HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d   + i * y_dblk_size);
+        HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d   + i * y_dblk_size);
+        HVX_Vector r0_d  = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+        HVX_Vector r1_d  = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy0_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
+        vy0_d           = Q6_Vsf_equals_Vqf32(vy0_d);
+        vy1_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
+        vy1_d           = Q6_Vsf_equals_Vqf32(vy1_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift with zero fill to create FP32
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
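+        // (an e8m0 byte E shifted into bits 30..23 becomes the fp32 value 2^(E-127) with a zero mantissa)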
+        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
+        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
+        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
+        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
+
+        // Compute combined scales
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
+
+        // Apply scales and accumulate
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(   y0_q   + i * y_qblk_size, nloe);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(   y1_q   + i * y_qblk_size, nloe);
+        HVX_Vector_x8 r0_q  = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe);
+        HVX_Vector_x8 r1_q  = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe);
+
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
+
+        HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d   + i * y_dblk_size);
+        HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d   + i * y_dblk_size);
+        HVX_Vector r0_d  = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
+        HVX_Vector r1_d  = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
+
+        // Convert vy_d from fp16 to fp32 while applying the 0.5 scale used for e8m0 halving
+        HVX_Vector half = Q6_Vh_vsplat_R(0x3800);  // 0.5 in fp16
+        vy0_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
+        vy0_d           = Q6_Vsf_equals_Vqf32(vy0_d);
+        vy1_d           = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
+        vy1_d           = Q6_Vsf_equals_Vqf32(vy1_d);
+
+        // Convert rX_d scales from e8m0 to fp32
+        // Expand and zero-pad 32x uint8 e8m0 values to uint32s: 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
+        // Left shift by 23 with zero fill to place the exponent into the FP32 exponent field
+        // FIXME: might need to handle zero as a special case (see ggml-cpu code)
+        HVX_Vector expand    = *(const HVX_Vector *) expand_x32_e8m0;
+        HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
+        r0_d                 = Q6_V_vdelta_VV(r0_d, expand);
+        r0_d                 = Q6_V_vand_VV(r0_d, e8m0_mask);
+        r0_d                 = Q6_Vw_vasl_VwR(r0_d, 23);
+        r1_d                 = Q6_V_vdelta_VV(r1_d, expand);
+        r1_d                 = Q6_V_vand_VV(r1_d, e8m0_mask);
+        r1_d                 = Q6_Vw_vasl_VwR(r1_d, 23);
+
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
+
+        // Zero out the unused scale and dot-product lanes (beyond the leftover elements)
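+        // (Q6_Q_vsetq_R(nloe / 8) keeps the first nloe/32 fp32 lanes, i.e. one scale per 32-element group)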
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Reduce and store results
+    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
+    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
+}
+
+static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const HVX_Vector * restrict x = (const HVX_Vector *) vx;
     const HVX_Vector * restrict y = (const HVX_Vector *) vy;
 
     uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
     uint32_t nloe = n % VLEN_FP16; // leftover elements
 
-    HVX_Vector rsum = Q6_V_vsplat_R(0);
+    HVX_VectorPair rsum_p = Q6_W_vzero();
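+    // rsum_p holds two partial fp32 sums (lo/hi halves of the pair); hvx_vec_mpyacc_f32_f16 accumulates into both, and they are combined after the loop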
 
     uint32_t i = 0;
 
     #pragma unroll(4)
     for (i = 0; i < nvec; i++) {
-        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+        rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
     }
 
     if (nloe) {
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
         HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
         HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
-
-        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+        rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
     }
 
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
-    hvx_vec_store_u(&s[0], 4, rsum);
+    HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
+    hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
 }
 
-static void vec_dot_f16_f16_aa_rx2(const int n,
-                                float * restrict s,
-                                const void * restrict vx,
-                                uint32_t vx_row_size,
-                                const void * restrict vy) {
-    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx;
-    const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
-    const HVX_Vector * restrict y  = (const HVX_Vector *) vy;
+static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0,
+                                const void * restrict vx0, const void * restrict vx1,
+                                const void * restrict vy0) {
+    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
+    const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
+    const HVX_Vector * restrict y  = (const HVX_Vector *) vy0;
 
     uint32_t nvec = n / VLEN_FP16;
     uint32_t nloe = n % VLEN_FP16;
 
-    HVX_Vector rsum0 = Q6_V_vsplat_R(0);
-    HVX_Vector rsum1 = Q6_V_vsplat_R(0);
+    HVX_VectorPair rsum0_p = Q6_W_vzero();
+    HVX_VectorPair rsum1_p = Q6_W_vzero();
 
     uint32_t i = 0;
 
     #pragma unroll(2)
     for (i = 0; i < nvec; i++) {
         HVX_Vector y_hf = y[i];
-        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
-        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
-
-        rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
-        rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+        rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
+        rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
     }
 
     if (nloe) {
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        HVX_Vector y_hf  = Q6_V_vand_QV(bmask, y[i]);
         HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
         HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
-        HVX_Vector y_hf  = Q6_V_vand_QV(bmask, y[i]);
-
-        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
-        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
-
-        rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
-        rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+        rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
+        rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
     }
 
-    rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0));
-    rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1));
-    HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
-
-    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+    HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
+    HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
+    HVX_Vector rsum  = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
+    hvx_vec_store_u(s0, 8, rsum);
 }
 
-static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1,
+                                const void * restrict vx0, const void * restrict vx1,
+                                const void * restrict vy0, const void * restrict vy1) {
+    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
+    const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
+    const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
+    const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
+
+    uint32_t nvec = n / VLEN_FP16;
+    uint32_t nloe = n % VLEN_FP16;
+
+    // Row sums (sf) - 4 accumulator pairs for the 2×2 tile
+    HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
+    HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
+    HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
+    HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
+
+    uint32_t i = 0;
+
+    #pragma unroll(2)
+    for (i = 0; i < nvec; i++) {
+        HVX_Vector r0_hf = x0[i];
+        HVX_Vector r1_hf = x1[i];
+        HVX_Vector c0_hf = y0[i];
+        HVX_Vector c1_hf = y1[i];
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+        r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
+        r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
+        r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
+        r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
+    }
+
+    if (nloe) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
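+        // predicate covers nloe * 2 bytes (2 bytes per fp16 element)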
+
+        HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
+        HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
+        HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
+        HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
+
+        r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
+        r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
+        r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
+        r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
+    }
+
+    HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
+    HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
+    HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
+    HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
+
+    // Reduce and store results
+    HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+    HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
+
+    hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);  // row0,col0 row1,col0
+    hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);  // row0,col1 row1,col1
+}
+
+static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const HVX_UVector * restrict x = (const HVX_UVector *) vx;
     const HVX_UVector * restrict y = (const HVX_UVector *) vy;
 
     uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
     uint32_t nloe = n % VLEN_FP16; // leftover elements
 
-    HVX_Vector rsum = Q6_V_vsplat_R(0);
+    HVX_Vector rsum = Q6_V_vzero();
 
     uint32_t i = 0;
 
@@ -1010,20 +1476,20 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res
         rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
 
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
     hvx_vec_store_u(&s[0], 4, rsum);
 }
 
-static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
     const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
     const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
 
     uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
     uint32_t nloe = n % VLEN_FP16; // leftover elements
 
-    const HVX_Vector zero = Q6_V_vsplat_R(0);
+    const HVX_Vector zero = Q6_V_vzero();
 
-    HVX_Vector       rsum = Q6_V_vsplat_R(0);
+    HVX_Vector       rsum = Q6_V_vzero();
 
     uint32_t i = 0;
 
@@ -1062,7 +1528,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
         rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
 
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+    // Convert into fp32 and reduce
+    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
     hvx_vec_store_u(&s[0], 4, rsum);
 }
 
@@ -1110,14 +1577,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
     const uint32_t nb2 = dst->nb[2];   \
     const uint32_t nb3 = dst->nb[3];
 
-#define htp_matmul_preamble            \
-    htp_matmul_tensors_preamble;       \
-    dma_queue *dma_queue           = octx->ctx->dma[ith];         \
-    uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+#define htp_matmul_preamble                                     \
+    struct htp_matmul_context * mmctx = data;                   \
+    struct htp_ops_context * octx  = mmctx->octx;               \
+    htp_matmul_tensors_preamble;                                \
+    dma_queue *dma_queue           = octx->ctx->dma[ith];       \
+    uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread;
 
 // *** matmul with support for 4d tensors and full broadcasting
 
-static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_4d(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     uint64_t t1, t2;
@@ -1163,13 +1632,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
         for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
             for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
-                const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1);
-                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1);
+                const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1);
+                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1);
                 const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
 
                 // broadcast src0 into src1
-                const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3);
-                const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2);
+                const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3);
+                const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2);
 
                 const uint32_t i1 = i11;
                 const uint32_t i2 = i12;
@@ -1182,7 +1651,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
                 const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
                 for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
                     const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
-                    mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col);
+                    mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col);
                 }
             }
         }
@@ -1197,7 +1666,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx
 }
 
 // src1 tensor is already in VTCM spad
-static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_2d(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
@@ -1222,7 +1691,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     // Per-thread VTCM scratchpads for all tensors
     // Note that the entire src1 tensor is already in VTCM
     // For other tensors we allocate N rows per thread, padded to HVX vector size
-    uint8_t * restrict spad_dst  = dst_spad->data + dst_spad->size_per_thread * ith;
+    uint8_t * restrict spad_dst  = dst_spad->data  + dst_spad->size_per_thread  * ith;
     uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith;
     uint8_t * restrict src1_data = src1_spad->data;
 
@@ -1246,11 +1715,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
 
-        #pragma unroll(2)
-        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
+        // Process src1 columns in pairs (2×2 tiling)
+        uint32_t ir1 = 0;
+        for (; ir1 + 1 < src1_nrows; ir1 += 2) {
+            const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
+            const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
+            float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
+            float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
+            mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1);
+        }
+
+        // Handle remaining src1 rows (fallback to 2×1)
+        for (; ir1 < src1_nrows; ++ir1) {
             const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
             float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
-            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
+            mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col);
         }
 
         // Prefetch next (n + spad_nrows) row
@@ -1274,20 +1753,20 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
         for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
             const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
             float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
-            mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+            mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
         }
     }
 
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
+    FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
          src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
          src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
 // q8x4x2 src1 tensor is already in VTCM spad
-static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matvec_2d(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     const uint32_t src0_nrows = ne01;
@@ -1338,7 +1817,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     // Process src0 rows
     for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-        mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col);
+        mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col);
 
         // Prefetch next (n + spad_nrows) row
         const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -1356,14 +1835,14 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
         dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
                        src0_stride, src0_row_size, 1);
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-        mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
+        mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
     }
 
-    hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
+    hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row);
 
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth,
+    FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth,
          src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
          src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
          (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
@@ -1377,7 +1856,7 @@ struct mmid_row_mapping {
 };
 
 // src1 tensor is already in VTCM spad
-static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matmul_id(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     struct htp_tensor * restrict     ids = &octx->src2;
@@ -1411,7 +1890,7 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
     const size_t src0_row_size = nb01;
     const size_t src1_row_size = q8x4x2_row_size(ne10);
 
-    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
 
     // Per-thread VTCM scratchpads for all tensors
     // Note that the entire src1 tensor is already in VTCM
@@ -1450,11 +1929,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
                 const int               rm2         = row_mapping.i2;  // token idx
 
                 const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx
-                const uint8_t * restrict src1_col =
-                    (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+                const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
                 float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
 
-                mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+                mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
             }
 
             // Prefetch next (n + spad_nrows) row
@@ -1480,25 +1958,24 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
                 const int               rm2         = row_mapping.i2;  // token idx
 
                 const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1;        // src1 row idx
-                const uint8_t * restrict src1_col =
-                    (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
+                const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size);
                 float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0));
 
-                mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+                mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
             }
         }
     }
 
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
+    FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
          ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
          src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1],
          dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
 // src1 tensor is already in VTCM spad
-static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+static void matvec_id(unsigned int nth, unsigned int ith, void * data) {
     htp_matmul_preamble;
 
     struct htp_tensor * restrict     ids = &octx->src2;
@@ -1524,7 +2001,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
     const size_t src0_row_size = nb01;
     const size_t src1_row_size = q8x4x2_row_size(ne10);
 
-    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
 
     const uint32_t n_aids = src2->ne[0];  // num activated experts
     const uint32_t n_ids  = ne02;         // num experts
@@ -1558,7 +2035,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
         // Process src0 rows
         for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
             const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+            mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col);
 
             // Prefetch next (n + spad_nrows) row
             const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
@@ -1576,13 +2053,13 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
             dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
                            src0_row_size_padded, src0_row_size, 1);
             const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-            mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
+            mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col);
         }
     }
 
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type,
+    FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type,
          ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0],
          src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0],
          dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
@@ -1590,18 +2067,18 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx
 
 // *** dynamic quant
 
-static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
     assert((unsigned long) x % 128 == 0);
     assert((unsigned long) y_q % 128 == 0);
 
     HVX_Vector * vx = (HVX_Vector *) x;
-    HVX_Vector zero   = Q6_V_vsplat_R(0);
+    HVX_Vector zero   = Q6_V_vzero();
 
     // Use reduce max fp32 to find max(abs(e)) first
-    HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0]));
-    HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1]));
-    HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2]));
-    HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3]));
+    HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0]));
+    HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1]));
+    HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2]));
+    HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3]));
     // Load and convert into QF32
     HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements
     HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements
@@ -1609,10 +2086,10 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
     HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero);  // 32 elements
 
     // Convert to QF32
-    HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
-    HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
-    HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
-    HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
+    HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
+    HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
+    HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
+    HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
 
     // Combine and convert to fp16
     HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
@@ -1622,11 +2099,6 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
     HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
     HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
 
-    // Replicate first fp16 scale across all lanes
-    HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
-    vmax01_hf         = Q6_V_vdelta_VV(vmax01_hf, ctrl);
-    vmax23_hf         = Q6_V_vdelta_VV(vmax23_hf, ctrl);
-
     HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
     HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
     HVX_Vector vd01_hf   = Q6_Vhf_equals_Vqf16(vd01_qf16);
@@ -1641,8 +2113,8 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
     hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
 
     // Divide input by the scale
-    HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
-    HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
+    HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
+    HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
     vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
     vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
 
@@ -1654,14 +2126,14 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri
     *(HVX_Vector *) y_q = vx_i8;
 }
 
-static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
     assert((unsigned long) x % 128 == 0);
     assert((unsigned long) y_q % 128 == 0);
 
     HVX_Vector * vx = (HVX_Vector *) x;
 
     // Load and convert into QF32
-    HVX_Vector zero   = Q6_V_vsplat_R(0);
+    HVX_Vector zero   = Q6_V_vzero();
     HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements
     HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements
     HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements
@@ -1672,13 +2144,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
     HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
 
     // Compute max and scale
-    HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
-    HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf));
-
-    // Replicate first fp16 scale across all lanes
-    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
-    vmax01_hf         = Q6_V_vdelta_VV(vmax01_hf, ctrl);
-    vmax23_hf         = Q6_V_vdelta_VV(vmax23_hf, ctrl);
+    HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
+    HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
 
     HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
     HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
@@ -1689,8 +2156,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
     hvx_vec_store_u(y_d + 4, 4, vd23_hf);
 
     // Divide input by the scale
-    HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
-    HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
+    HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf);
+    HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf);
     vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
     vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
 
@@ -1702,14 +2169,14 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri
     *(HVX_Vector *) y_q = vx_i8;
 }
 
-static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
     assert((unsigned long) x % 128 == 0);
     assert((unsigned long) y_q % 128 == 0);
 
     HVX_Vector * vx = (HVX_Vector *) x;
 
     // Load and convert into QF32
-    HVX_Vector zero   = Q6_V_vsplat_R(0);
+    HVX_Vector zero   = Q6_V_vzero();
     HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero);  // 32 elements
     HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero);  // 32 elements
     HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero);  // 32 elements
@@ -1720,12 +2187,8 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
     HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
 
     // Compute max and scale
-    HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
-    vmax_hf            = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf);
-
-    // Replicate first fp16 scale across all lanes
-    HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
-    vmax_hf         = Q6_V_vdelta_VV(vmax_hf, ctrl);
+    HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
+    vmax_hf            = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
 
     HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008));  // 1.0 / 127.0
     HVX_Vector vd_hf   = Q6_Vhf_equals_Vqf16(vd_qf16);
@@ -1733,7 +2196,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
     *(HVX_UVector *) y_d = vd_hf;
 
     // Divide input by the scale
-    HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf);
+    HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf);
     vx01_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf));
     vx23_hf              = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf));
 
@@ -1746,7 +2209,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
 }
 
 // Overrides input x
-static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
+static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) {
     assert(k % 32 == 0);
     const uint32_t qk = QK_Q8_0x4x2;
     const uint32_t nb = (k + qk - 1) / qk;
@@ -1764,29 +2227,31 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u
 
     for (uint32_t i = 0; i < nb; i++) {
 #if FP32_QUANTIZE_GROUP_SIZE == 32
-        quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
-        quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
+        quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+        quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
 #elif FP32_QUANTIZE_GROUP_SIZE == 64
-        quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
-        quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
+        quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+        quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
 #elif FP32_QUANTIZE_GROUP_SIZE == 128
-        quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
-        quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
+        quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+        quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
 #else
 #error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
 #endif
     }
 
     // now copy the scales into final location
-    hvx_copy_fp16_ua(y_d, t_d, nb * 8);
+    hvx_copy_f16_ua(y_d, t_d, nb * 8);
 }
 
-static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
-                                 uint8_t * restrict dst,
-                                 struct htp_spad * spad,
-                                 uint32_t          nth,
-                                 uint32_t          ith,
-                                 uint32_t          nrows_per_thread) {
+static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_matmul_context * mmctx = data;
+    struct htp_ops_context * octx = mmctx->octx;
+
+    const struct htp_tensor * src = &octx->src1;
+    uint8_t * restrict dst = octx->src1_spad.data;
+    struct htp_spad * spad = &octx->src0_spad;
+    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
 
     uint64_t t1 = HAP_perf_get_qtimer_count();
 
@@ -1807,27 +2272,33 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
     uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first);
     uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith);
 
-    const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
+    const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float));
     memset(tmp_data, 0, src_row_size_padded);  // zero-out temp row data for padding
 
     for (uint32_t i = ir_first; i < ir_last; ++i) {
-        htp_l2fetch(src_data, 2, src_row_size, src_row_size);
-        hvx_copy_fp32_aa(tmp_data, src_data, ne0);
+        hex_l2fetch(src_data, src_row_size, src_row_size, 2);
+        hvx_copy_f32_aa(tmp_data, src_data, ne0);
 
         // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
-        quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0);
+        quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0);
         dst_data += dst_row_size;
         src_data += src_row_size;
     }
 
     uint64_t t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
+    FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first,
          ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
-                              uint32_t nrows_per_thread, uint32_t dst_stride) {
+static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_matmul_context * mmctx = data;
+    struct htp_ops_context * octx = mmctx->octx;
+
+    const struct htp_tensor * src = &octx->src1;
+    uint8_t * restrict dst = octx->src1_spad.data;
+    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+    uint32_t dst_stride = octx->src1_spad.stride;
 
     uint64_t t1 = HAP_perf_get_qtimer_count();
 
@@ -1848,8 +2319,8 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict
     uint8_t * restrict dst_data = (uint8_t *) dst       + (dst_stride * ir_first);
 
     for (uint32_t i = ir_first; i < ir_last; ++i) {
-        htp_l2fetch(src_data, 2, src_row_size, src_stride);
-        hvx_copy_fp16_fp32_au(dst_data, src_data, ne0);
+        hex_l2fetch(src_data, src_row_size, src_stride, 2);
+        hvx_copy_f16_f32_au(dst_data, src_data, ne0);
 
         dst_data += dst_stride;
         src_data += src_stride;
@@ -1857,13 +2328,19 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict
 
     uint64_t t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+    FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
         ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
 // TODO just a plain copy that should be done via the DMA during the Op setup
-static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
-                              uint32_t nrows_per_thread, uint32_t dst_stride) {
+static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_matmul_context * mmctx = data;
+    struct htp_ops_context * octx = mmctx->octx;
+
+    const struct htp_tensor * src = &octx->src1;
+    uint8_t * restrict dst = octx->src1_spad.data;
+    uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread;
+    uint32_t dst_stride = octx->src1_spad.stride;
 
     uint64_t t1 = HAP_perf_get_qtimer_count();
 
@@ -1884,8 +2361,8 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict
     uint8_t * restrict dst_data = (uint8_t *) dst       + (dst_stride * ir_first);
 
     for (uint32_t i = ir_first; i < ir_last; ++i) {
-        htp_l2fetch(src_data, 2, src_row_size, src_stride);
-        hvx_copy_fp16_au(dst_data, src_data, ne0);
+        hex_l2fetch(src_data, src_row_size, src_stride, 2);
+        hvx_copy_f16_au(dst_data, src_data, ne0);
 
         dst_data += dst_stride;
         src_data += src_stride;
@@ -1893,400 +2370,177 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict
 
     uint64_t t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+    FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
         ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-    quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
-}
-
-static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-    quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
-}
-
-static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-    quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
-}
-
-// ** matmul/matvec callbacks for worker_pool
-
-static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
-    matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
-    matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q8x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
-    matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q8x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
-    matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "mxfp4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
-    matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "mxfp4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
-    matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "f16-f16";
-    mt.vec_dot     = vec_dot_f16_f16_aa;
-    mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
-
-    matvec_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "f16-f16";
-    mt.vec_dot     = vec_dot_f16_f16_aa;
-    mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
-
-    matmul_2d(&mt, octx, n, i);
-}
-
-static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "f16-f32";
-    mt.vec_dot     = vec_dot_f16_f32_uu;
-
-    matmul_4d(&mt, octx, n, i);
-}
-
-static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "f16-f16";
-    mt.vec_dot     = vec_dot_f16_f16_uu;
-
-    matmul_4d(&mt, octx, n, i);
-}
-
-// ** matmul-id callbacks for worker_pool
-
-static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
-    matvec_id(&mt, octx, n, i);
-}
-
-static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
-
-    matmul_id(&mt, octx, n, i);
-}
-
-static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q8x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
-    matvec_id(&mt, octx, n, i);
-}
-
-static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "q8x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
-
-    matmul_id(&mt, octx, n, i);
-}
-
-static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "mxfp4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
-    matvec_id(&mt, octx, n, i);
-}
-
-static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = data;
-
-    struct htp_matmul_type mt;
-    mt.type        = "mxfp4x4x2-q8x4x2";
-    mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
-
-    matmul_id(&mt, octx, n, i);
-}
-
-// ** main matmul entry point
 
 static inline bool htp_is_permuted(const struct htp_tensor * t) {
     return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
 }
 
+static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) {
+    switch (type) {
+        case HTP_TYPE_Q4_0:
+            mmctx->type        = "q4x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2;
+            return 0;
+        case HTP_TYPE_Q8_0:
+            mmctx->type        = "q8x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
+            return 0;
+        case HTP_TYPE_MXFP4:
+            mmctx->type        = "mxfp4x4x2-f32";
+            mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2;
+            return 0;
+        default:
+            return -1;
+    }
+}
+
+static void htp_mminit_spad(struct htp_ops_context * octx,
+                                 size_t dst_row_size,
+                                 size_t src0_row_size_padded,
+                                 size_t src1_row_size,
+                                 uint32_t src1_nrows,
+                                 size_t src2_spad_size_per_thread) {
+    octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+    octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+    octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
+
+    if (src2_spad_size_per_thread > 0) {
+        octx->src2_spad.size_per_thread = src2_spad_size_per_thread;
+        octx->src2_spad.size            = octx->src2_spad.size_per_thread;
+    }
+
+    // src0 spad is also used in dynamic quantizer to store padded src1 rows
+    size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
+    if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
+        octx->src0_spad.size_per_thread = src1_row_size_padded;
+    }
+
+    octx->src1_spad.size = octx->src1_spad.size_per_thread;
+    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+}
+
 int op_matmul(struct htp_ops_context * octx) {
     htp_matmul_tensors_preamble;
 
-    const char * op_type;
+    struct htp_matmul_context mmctx_struct = {0};
+    struct htp_matmul_context * mmctx = &mmctx_struct;
+    mmctx->octx = octx;
 
     const uint32_t src0_nrows = ne01 * ne02 * ne03;
     const uint32_t src1_nrows = ne11 * ne12 * ne13;
 
+    // Compute src0_nrows_per_thread
+    mmctx->src0_nrows_per_thread  = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
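+    // keep each thread's share even so the 2-row (2x1/2x2) kernels do not split a src0 row pair across threads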
+
     const size_t src0_row_size = nb01;
     const size_t dst_row_size  = nb1;
     size_t       src1_row_size = nb11;
 
-    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
     size_t       src1_row_size_padded;
 
     worker_callback_t quant_job_func;
-    worker_callback_t matmul_job_func;
+    worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d;
 
     bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE);
 
-    switch (src0->type) {
-        case HTP_TYPE_Q4_0:
-            op_type        = "q4x4x2-fp32";
-            quant_job_func = htp_quantize_fp32_q8x4x2;
-            if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
-            } else {
-                matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
-            }
+    if (src0->type == HTP_TYPE_F16) {
+        // Try optimized f16-f16 path first (src1 in VTCM)
+        const size_t f16_src1_row_size  = hex_round_up(ne10 * 2, 128);
+        const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256);
+        const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
+        const size_t f16_dst_spad_size  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
 
-            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
+        const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
 
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
+        // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
+        // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
+        const bool is_batched  = (ne02 > 1) || (ne03 > 1);
+        const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
 
-            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+        if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
+            // Optimized path
+            quant_job_func     = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16;
+            mmctx->type        = "f16-f16";
+            mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1;
+            mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1;
+            mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2;
 
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
+            src1_row_size = f16_src1_row_size;  // row size post quantization
+
+            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256);
 
             octx->src1_spad.size = octx->src1_spad.size_per_thread;
             octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
             octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
-
-        case HTP_TYPE_Q8_0:
-            op_type        = "q8x4x2-fp32";
-            quant_job_func = htp_quantize_fp32_q8x4x2;
-            if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
+        } else {
+            // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
+            quant_job_func = NULL;
+            if (src1->type == HTP_TYPE_F32) {
+                mmctx->type        = "f16-f32";
+                mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1;
+                matmul_job_func    = matmul_4d;
             } else {
-                matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
+                mmctx->type        = "f16-f16";
+                mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1;
+                matmul_job_func    = matmul_4d;
             }
 
-            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
+            src1_row_size = nb11;  // original row size in DDR
 
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
+            octx->dst_spad.size_per_thread  = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
+            octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
 
-            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
             octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
             octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
 
-        case HTP_TYPE_MXFP4:
-            op_type        = "mxfp4x4x2-f32";
-            quant_job_func = htp_quantize_fp32_q8x4x2;
-            if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
-            } else {
-                matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2;
-            }
+            // Init fastdiv for matmul_4d (supports broadcasting)
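+            // (fastdiv precomputes mp and l so that n/d becomes (mulhi(n, mp) + n) >> l at runtime)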
+            mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
+            mmctx->mm_div_ne1      = init_fastdiv_values(dst->ne[1]);
+            mmctx->mm_div_r2       = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
+            mmctx->mm_div_r3       = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
 
-            src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
-
-            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
-
-        case HTP_TYPE_F16:
-            {
-                // Try optimized f16-f16 path first (src1 in VTCM)
-                const size_t f16_src1_row_size  = htp_round_up(ne10 * 2, 128);
-                const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256);
-                const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
-                const size_t f16_dst_spad_size  = htp_round_up(MM_SPAD_DST_NROWS  * dst_row_size, 256) * octx->n_threads;
-
-                const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
-
-                // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
-                // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
-                const bool is_batched  = (ne02 > 1) || (ne03 > 1);
-                const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
-
-                if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
-                    // Optimized path
-                    op_type        = "f16-f16";
-                    quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16;
-                    if (src1_nrows > 1) {
-                        matmul_job_func = htp_matmul_2d_f16_f16;
-                    } else {
-                        matmul_job_func = htp_matvec_2d_f16_f16;
-                    }
-
-                    src1_row_size = f16_src1_row_size; // row size post quantization
-
-                    octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-                    octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-                    octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
-
-                    octx->src1_spad.size = octx->src1_spad.size_per_thread;
-                    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-                    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-                } else {
-                    // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
-                    quant_job_func  = NULL;
-                    if (src1->type == HTP_TYPE_F32) {
-                        op_type         = "f16-f32";
-                        matmul_job_func = htp_matmul_4d_f16_f32;
-                    } else {
-                        op_type         = "f16-f16";
-                        matmul_job_func = htp_matmul_4d_f16_f16;
-                    }
-
-                    src1_row_size = nb11; // original row size in DDR
-
-                    octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-                    octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
-                    octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
-
-                    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-                    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
-                    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-
-                    // Init fastdiv for matmul_4d (supports broadcasting)
-                    octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
-                    octx->mm_div_ne1      = init_fastdiv_values(dst->ne[1]);
-                    octx->mm_div_r2       = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
-                    octx->mm_div_r3       = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
-
-                    need_quant = false;
-                }
-            }
-            break;
-
-        default:
+            need_quant = false;
+        }
+    } else {
+        if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
             return HTP_STATUS_NO_SUPPORT;
+        }
+
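+        // all quantized src0 types (q4_0, q8_0, mxfp4) share one activation format: src1 is dynamically quantized to q8x4x2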
+        quant_job_func = quantize_f32_q8x4x2;
+        src1_row_size  = q8x4x2_row_size(ne10);
+        htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0);
     }
 
     // VTCM scratchpads for all tensors
     size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
 
-    FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type,
+    FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
          octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size);
 
-    FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0],
+    FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0],
          src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
          dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data);
 
     // Make sure the reserved vtcm size is sufficient
     if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
+        FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type,
              octx->ctx->vtcm_size, spad_size);
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
@@ -2295,48 +2549,47 @@ int op_matmul(struct htp_ops_context * octx) {
     octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
     octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
-    octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
-    octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1);  // round up to even
-
     octx->src0_spad.stride = src0_row_size_padded;
     octx->src1_spad.stride = src1_row_size;
 
     if (need_quant) {
-        // Run quant jobs
-        const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
-        octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
+        const uint32_t n_quant_jobs  = MIN(src1_nrows, octx->n_threads);
+        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
     }
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        // Run matmul jobs
         const uint32_t n_matmul_jobs = octx->n_threads;
-        worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs);
     }
 
     return HTP_STATUS_OK;
 }
 
-// ** main matmul-id entry point
-
 int op_matmul_id(struct htp_ops_context * octx) {
     htp_matmul_tensors_preamble;
 
+    struct htp_matmul_context mmctx_struct = {0};
+    struct htp_matmul_context * mmctx = &mmctx_struct;
+    mmctx->octx = octx;
+
     struct htp_tensor * restrict ids = &octx->src2;
 
-    const char * op_type;
-
-    worker_callback_t quant_job_func;
-    worker_callback_t matmul_id_job_func;
-
     const size_t src0_row_size = nb01;
     const size_t dst_row_size  = nb1;
 
-    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+    const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128);
 
     const uint32_t src0_nrows = ne01;  // per expert
     const uint32_t src1_nrows = ne11 * ne12 * ne13;
 
+    worker_callback_t quant_job_func;
+    worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id;
+
+    // Compute src0_nrows_per_thread
+    mmctx->src0_nrows_per_thread  = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
+    mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even
+
     size_t src1_row_size;
     size_t src1_row_size_padded;
 
@@ -2347,112 +2600,29 @@ int op_matmul_id(struct htp_ops_context * octx) {
     size_t matrix_row_counts_size = n_as * sizeof(uint32_t);
     size_t matrix_row_map_size    = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping);
 
-    switch (src0->type) {
-        case HTP_TYPE_Q4_0:
-            op_type        = "q4x2x2-f32";
-            quant_job_func = htp_quantize_fp32_q8x4x2;
-            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
-            if (src1_nrows > 1) {
-                matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2;
-            } else {
-                matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2;
-            }
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
-            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
-            octx->src2_spad.size = octx->src2_spad.size_per_thread;
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
-
-        case HTP_TYPE_Q8_0:
-            op_type        = "q8x2x2-f32";
-            quant_job_func = htp_quantize_fp32_q8x4x2;
-            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
-            if (src1_nrows > 1) {
-                matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2;
-            } else {
-                matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2;
-            }
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
-            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
-            octx->src2_spad.size = octx->src2_spad.size_per_thread;
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
-
-        case HTP_TYPE_MXFP4:
-            op_type        = "mxfp4x2x2-f32";
-            quant_job_func = htp_quantize_fp32_q8x4x2;
-            src1_row_size  = q8x4x2_row_size(ne10);  // row size post quantization
-            if (src1_nrows > 1) {
-                matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2;
-            } else {
-                matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2;
-            }
-
-            // Entire src1 tensor is placed into the VTCM
-            // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
-            octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
-            octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
-
-            // src0 spad is also used in dynamic quantizer to store padded src1 rows
-            src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float));
-            if (octx->src0_spad.size_per_thread < src1_row_size_padded) {
-                octx->src0_spad.size_per_thread = src1_row_size_padded;
-            }
-
-            octx->src2_spad.size = octx->src2_spad.size_per_thread;
-            octx->src1_spad.size = octx->src1_spad.size_per_thread;
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-            break;
-
-        default:
-            return HTP_STATUS_NO_SUPPORT;
+    if (htp_mminit_vec_dot(mmctx, src0->type) != 0) {
+        return HTP_STATUS_NO_SUPPORT;
     }
 
+    quant_job_func = quantize_f32_q8x4x2;
+    src1_row_size  = q8x4x2_row_size(ne10);
+
+    const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
+    htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread);
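+    // src2 scratchpad holds the per-expert row counts followed by the row-id mapping; like src1 it is sized once and shared by all threads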
+
     size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size;
 
-    FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type,
+    FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type,
          octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size);
 
-    FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type,
+    FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type,
          src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
          ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data,
          src1->data, dst->data);
 
     // Make sure the reserved vtcm size is sufficient
     if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
-             octx->ctx->vtcm_size, spad_size);
+        FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size);
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
@@ -2461,8 +2631,8 @@ int op_matmul_id(struct htp_ops_context * octx) {
     octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
     octx->dst_spad.data  = octx->src2_spad.data + octx->src2_spad.size;
 
-    octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
-    octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1);  // round up to even
+    octx->src0_spad.stride = src0_row_size_padded;
+    octx->src1_spad.stride = src1_row_size;
 
     if (src1_nrows > 1) {
         // initialize matrix_row_counts and map
@@ -2474,8 +2644,7 @@ int op_matmul_id(struct htp_ops_context * octx) {
         // group rows by src0 matrix
         for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {  // token idx
             for (uint32_t id = 0; id < n_ids; ++id) {         // expert idx
-                const uint32_t i02 =
-                    *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+                const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
 
                 assert(i02 >= 0 && i02 < n_as);
 
@@ -2487,16 +2656,14 @@ int op_matmul_id(struct htp_ops_context * octx) {
 
     // Setup worker pool callbacks
     if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) {
-        // Run quant jobs
         const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
-        octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs);
+        mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
+        worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
     }
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        // Run matmul-id jobs
         const uint32_t n_matmul_jobs = octx->n_threads;
-        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs);
     }
 
     return HTP_STATUS_OK;
diff --git a/ggml/src/ggml-hexagon/htp/ops-utils.h b/ggml/src/ggml-hexagon/htp/ops-utils.h
deleted file mode 100644
index af9c3305..00000000
--- a/ggml/src/ggml-hexagon/htp/ops-utils.h
+++ /dev/null
@@ -1,149 +0,0 @@
-#ifndef OPS_UTILS_H
-#define OPS_UTILS_H
-
-#include "htp-msg.h"
-
-#ifndef MAX
-#    define MAX(a, b) ((a) > (b) ? (a) : (b))
-#endif
-
-#ifndef MIN
-#    define MIN(a, b) ((a) < (b) ? (a) : (b))
-#endif
-
-static inline uint64_t htp_get_cycles() {
-    uint64_t cycles = 0;
-    asm volatile(" %0 = c15:14\n" : "=r"(cycles));
-    return cycles;
-}
-
-static inline uint64_t htp_get_pktcnt() {
-    uint64_t pktcnt;
-    asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
-    return pktcnt;
-}
-
-static inline int32_t htp_is_aligned(void * addr, uint32_t align) {
-    return ((size_t) addr & (align - 1)) == 0;
-}
-
-static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
-    return m * ((n + m - 1) / m);
-}
-
-// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
-// Precompute mp (m' in the paper) and L such that division
-// can be computed using a multiply (high 32b of 64b result)
-// and a shift:
-//
-// n/d = (mulhi(n, mp) + n) >> L;
-struct fastdiv_values {
-    uint32_t mp;
-    uint32_t l;
-};
-
-static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
-    struct fastdiv_values result = { 0, 0 };
-    // compute L = ceil(log2(d));
-    while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
-        ++(result.l);
-    }
-
-    result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
-    return result;
-}
-
-static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
-    // Compute high 32 bits of n * mp
-    const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32);  // mulhi(n, mp)
-    // add n, apply bit shift
-    return (hi + n) >> vals->l;
-}
-
-static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
-    return n - fastdiv(n, vals) * d;
-}
-
-static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
-    const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
-    asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
-}
-
-static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
-    uint32_t left_off  = (size_t) addr & (chunk_size - 1);
-    uint32_t right_off = left_off + n;
-    return right_off <= chunk_size;
-}
-
-static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) {
-    char str[1024], *p = str, *p_end = str + sizeof(str);
-    p += snprintf(p, p_end - p, "%s: ", pref);
-    for (int i = 0; i < n && p < p_end; i++) {
-        p += snprintf(p, p_end - p, "%d, ", x[i]);
-    }
-    FARF(HIGH, "%s\n", str);
-}
-
-static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
-    char str[1024], *p = str, *p_end = str + sizeof(str);
-    p += snprintf(p, p_end - p, "%s: ", pref);
-    for (int i = 0; i < n && p < p_end; i++) {
-        p += snprintf(p, p_end - p, "%d, ", x[i]);
-    }
-    FARF(HIGH, "%s\n", str);
-}
-
-static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
-    char str[1024], *p = str, *p_end = str + sizeof(str);
-    p += snprintf(p, p_end - p, "%s: ", pref);
-    for (int i = 0; i < n; i++) {
-        p += snprintf(p, p_end - p, "%d, ", (int) x[i]);
-    }
-    FARF(HIGH, "%s\n", str);
-}
-
-static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) {
-    char str[1024], *p = str, *p_end = str + sizeof(str);
-    p += snprintf(p, p_end - p, "%s: ", pref);
-    for (int i = 0; i < n; i++) {
-        p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]);
-    }
-    FARF(HIGH, "%s\n", str);
-}
-
-static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) {
-    char str[1024], *p = str, *p_end = str + sizeof(str);
-    p += snprintf(p, p_end - p, "%s: ", pref);
-    for (int i = 0; i < n; i++) {
-        p += snprintf(p, p_end - p, "%.6f, ", x[i]);
-    }
-    FARF(HIGH, "%s\n", str);
-}
-
-static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) {
-    uint32_t n0 = n / 16;
-    uint32_t n1 = n % 16;
-
-    uint32_t i = 0;
-    for (; i < n0; i++) {
-        htp_dump_fp32_line(pref, x + (16 * i), 16);
-    }
-    if (n1) {
-        htp_dump_fp32_line(pref, x + (16 * i), n1);
-    }
-}
-
-static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
-    uint32_t n0 = n / 16;
-    uint32_t n1 = n % 16;
-
-    uint32_t i = 0;
-    for (; i < n0; i++) {
-        htp_dump_fp16_line(pref, x + (16 * i), 16);
-    }
-    if (n1) {
-        htp_dump_fp16_line(pref, x + (16 * i), n1);
-    }
-}
-
-#endif /* OPS_UTILS_H */
diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c
index a4399704..be946953 100644
--- a/ggml/src/ggml-hexagon/htp/rope-ops.c
+++ b/ggml/src/ggml-hexagon/htp/rope-ops.c
@@ -2,32 +2,29 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
+
 #include 
-#include 
 #include 
 
+#include "hex-dma.h"
+#include "hvx-utils.h"
+#include "hex-fastdiv.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
-#include "htp-dma.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
 
-// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h
+// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h
 #define HTP_ROPE_TYPE_NORMAL 0
 #define HTP_ROPE_TYPE_NEOX   2
 
+#define HTP_ROPE_SPAD_NROWS  16
+#define HTP_ROPE_SPAD_BLOCK  (HTP_ROPE_SPAD_NROWS/2)
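+// The per-thread scratchpad holds HTP_ROPE_SPAD_NROWS rows, processed in half-size blocks so that
+// the DMA prefetch of the next block overlaps with compute on the current one.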
+
 #define htp_rope_preamble              \
     const uint32_t ne00 = src0->ne[0]; \
     const uint32_t ne01 = src0->ne[1]; \
@@ -49,7 +46,7 @@
     const uint32_t nb2 = dst->nb[2];   \
     const uint32_t nb3 = dst->nb[3];
 
-struct rope_th_ctx {
+struct htp_rope_context {
     int32_t n_dims;
     int32_t mode;
     int32_t n_ctx_orig;
@@ -64,7 +61,19 @@ struct rope_th_ctx {
     float theta_scale;
     float corr_dims[2];
 
+    uint32_t src0_nrows_per_thread;
+    size_t spad_stride;
+
     struct htp_ops_context * octx;
+
+    size_t src0_row_size;
+    size_t dst_row_size;
+    size_t src0_row_size_aligned;
+    size_t dst_row_size_aligned;
+    size_t theta_cache_offset;
+    uint32_t src0_nrows;
+
+    uint64_t t_start;
 };
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -124,64 +133,23 @@ static void rope_corr_dims(int     n_dims,
     dims[1]     = MIN(n_dims - 1, end);
 }
 
-static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
-    memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
+static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
+    const HVX_Vector * restrict vsrc   = (const HVX_Vector *) src0;
+    const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
+    HVX_Vector       * restrict vdst   = (HVX_Vector *) dst;
 
-    const int32_t * op_params = &octx->op_params[0];
+    uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
 
-    rope_ctx->n_dims     = ((const int32_t *) op_params)[1];
-    rope_ctx->mode       = ((const int32_t *) op_params)[2];
-    rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
+    uint32_t he = ne / 2;         // half_dims offset in elements
+    uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
 
-    memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
-    memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
-    memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
-    memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
-    memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
-    memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
-    memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
+    #pragma unroll(2)
+    for (uint32_t i = 0; i < nvec; i += 2) {
+        HVX_Vector v0 = vsrc[i/2+0];
+        HVX_Vector v1 = vsrc[i/2+hv];
 
-    rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
-
-    rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
-                   rope_ctx->beta_slow, rope_ctx->corr_dims);
-
-    rope_ctx->octx = octx;
-    FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
-         rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
-}
-
-static void hvx_calc_rope_neox_f32(const float * restrict src0,
-                                   float * restrict dst,
-                                   const int num_elems,
-                                   const float * restrict theta_cache) {
-    // for (int i = 0; i < num_elems; i += 2) {
-    //const float cos_theta = theta_cache[i + 0];
-    //const float sin_theta = theta_cache[i + 1];
-
-    //const float x0 = src[0];
-    //const float x1 = src[num_elems/2];
-
-    //dst[0] = x0*cos_theta - x1*sin_theta;
-    //dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
-
-    //src += 1;
-    //dst += 1;
-    // }
-
-    const uint8_t * restrict src0_curr  = (const uint8_t *) src0;
-    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
-    uint8_t * restrict dst_curr         = (uint8_t *) dst;
-
-    int step_of_1 = num_elems >> 6;  // 6 because we process two vectors at once
-    int half_size = (sizeof(float) * (num_elems / 2));
-
-    for (int i = 0; i < step_of_1; i++) {
-        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
-        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
-
-        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
-        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+        HVX_Vector v2 = vtheta[i+0];
+        HVX_Vector v3 = vtheta[i+1];
 
         HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
 
@@ -193,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
         HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
         HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
 
-        *(HVX_Vector *) dst_curr               = Q6_Vsf_equals_Vqf32(v4);
-        *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
+        vdst[i/2+0]  = Q6_Vsf_equals_Vqf32(v4);
+        vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
+    }
 
-        src0_curr += VLEN;
-        theta_curr += 2 * VLEN;
-        dst_curr += VLEN;
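+    // scalar tail for the remaining elements when ne is not a multiple of 2 * VLEN_FP32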
+    for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
+        const float cos_theta = theta_cache[i+0];
+        const float sin_theta = theta_cache[i+1];
+        float x0 = src0[i/2];
+        float x1 = src0[i/2 + he];
+        dst[i/2]      = x0 * cos_theta - x1 * sin_theta;
+        dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
     }
 }
 
-static void hvx_calc_rope_f32(const float * restrict src0,
-                              float * restrict dst,
-                              const int num_elems,
-                              const float * restrict theta_cache) {
-    // for (int i = 0; i < num_elems; i += 2) {
-    //const float cos_theta = theta_cache[i + 0];
-    //const float sin_theta = theta_cache[i + 1];
+static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
+    const HVX_Vector * restrict vsrc   = (const HVX_Vector *) src0;
+    const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
+    HVX_Vector       * restrict vdst   = (HVX_Vector *) dst;
 
-    //const float x0 = src[0];
-    //const float x1 = src[1];
+    uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
 
-    //dst[0] = x0*cos_theta - x1*sin_theta;
-    //dst[1] = x0*sin_theta + x1*cos_theta;
+    #pragma unroll(2)
+    for (uint32_t i = 0; i < nvec; i+=2) {
+        HVX_Vector v0 = vsrc[i+0];
+        HVX_Vector v1 = vsrc[i+1];
 
-    //src += 2;
-    //dst += 2;
-    // }
-
-    const uint8_t * restrict src0_curr  = (const uint8_t *) src0;
-    const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
-    uint8_t * restrict dst_curr         = (uint8_t *) dst;
-
-    int step_of_1 = num_elems >> 6;  // 6 because we process two vectors at once
-
-    for (int i = 0; i < step_of_1; i++) {
-        HVX_Vector v0 = *(HVX_Vector *) src0_curr;
-        HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
-
-        HVX_Vector v2 = *(HVX_Vector *) theta_curr;
-        HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+        HVX_Vector v2 = vtheta[i+0];
+        HVX_Vector v3 = vtheta[i+1];
 
         HVX_VectorPair vx0_x1   = Q6_W_vdeal_VVR(v1, v0, -4);  // vx0_x1[0] = x0, vx0_x1[1] = x1
         HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4);  // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
@@ -246,116 +203,65 @@ static void hvx_calc_rope_f32(const float * restrict src0,
 
         HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
 
-        *(HVX_Vector *) dst_curr          = Q6_V_lo_W(vstore);
-        *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
+        vdst[i+0] = Q6_V_lo_W(vstore);
+        vdst[i+1] = Q6_V_hi_W(vstore);
+    }
 
-        src0_curr += 2 * VLEN;
-        theta_curr += 2 * VLEN;
-        dst_curr += 2 * VLEN;
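+    // scalar tail for the remaining elements when ne is not a multiple of 2 * VLEN_FP32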
+    for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
+        const float cos_theta = theta_cache[i+0];
+        const float sin_theta = theta_cache[i+1];
+        float x0 = src0[i+0];
+        float x1 = src0[i+1];
+        dst[i+0] = x0 * cos_theta - x1 * sin_theta;
+        dst[i+1] = x0 * sin_theta + x1 * cos_theta;
     }
 }
 
-static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
-                         const uint32_t       ir0,
-                         const uint32_t       ir1,
-                         int                  nth,
-                         int                  ith,
-                         const int            opt_path) {
-    struct htp_ops_context * octx = rope_ctx->octx;
+static inline void rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
+                   uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
+    #pragma unroll(4)
+    for (uint32_t i = 0; i < nr; i++) {
+        float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
+        float * s = (float *) (src + i * rctx->src0_row_size_aligned);
+
+        hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);
+
+        // fill the remaining channels with data from the src tensor
+        if (rctx->n_dims < ne0) {
+            hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
+        }
+    }
+}
+
+static inline void rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
+                   uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
+    #pragma unroll(4)
+    for (uint32_t i = 0; i < nr; i++) {
+        float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
+        float * s = (float *) (src + i * rctx->src0_row_size_aligned);
+
+        hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);
+
+        // fill the remaining channels with data from the src tensor
+        if (rctx->n_dims < ne0) {
+            hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
+        }
+    }
+}
+
+static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_rope_context * rctx = (struct htp_rope_context *) data;
+    struct htp_ops_context * octx = rctx->octx;
 
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
     const struct htp_tensor * src2 = &octx->src2;
     struct htp_tensor *       dst  = &octx->dst;
 
-    const int32_t mode    = rope_ctx->mode;
-    const bool    is_neox = mode & HTP_ROPE_TYPE_NEOX;
-
     htp_rope_preamble;
 
-    const int32_t * pos = (const int32_t *) src1->data;
-
-    float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
-
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        freq_factors = (const float *) src2->data;
-    }
-
-    const uint32_t i1_end       = MIN(ir1, ne1);
-    const int32_t  half_dims    = rope_ctx->n_dims / 2;
-    const size_t   remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
-    for (uint32_t i3 = 0; i3 < ne3; i3++) {      // batch
-        for (uint32_t i2 = 0; i2 < ne2; i2++) {  // seq-len
-            const int32_t p = pos[i2];
-
-            rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
-                            rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
-
-            for (uint32_t i1 = ir0; i1 < i1_end; i1++) {  // attn-heads
-                const float * src      = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
-                float *       dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
-
-                const float * src_loc      = src;
-                float *       dst_data_loc = dst_data;
-
-                if (1 == opt_path) {
-                    if (is_neox) {
-                        hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
-                    } else {
-                        hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
-                    }
-
-                    src_loc += rope_ctx->n_dims;
-                    dst_data_loc += rope_ctx->n_dims;
-                } else {
-                    for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
-                        const float cos_theta = wp0[i0 + 0];
-                        const float sin_theta = wp0[i0 + 1];
-
-                        if (is_neox) {
-                            const float x0 = src_loc[0];
-                            const float x1 = src_loc[half_dims];
-
-                            dst_data_loc[0]         = x0 * cos_theta - x1 * sin_theta;
-                            dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
-
-                            src_loc += 1;
-                            dst_data_loc += 1;
-                        } else {
-                            const float x0 = src_loc[0];
-                            const float x1 = src_loc[1];
-
-                            dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
-                            dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
-
-                            src_loc += 2;
-                            dst_data_loc += 2;
-                        }
-                    }
-
-                    src_loc += (is_neox ? half_dims : 0);
-                    dst_data_loc += (is_neox ? half_dims : 0);
-                }
-
-                // TODO: use simd to speed up the remaining elements copy
-                memcpy(dst_data_loc, src_loc, remain_bytes);
-            }
-        }
-    }
-}
-
-static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
-    struct htp_ops_context * octx = rope_ctx->octx;
-
-    const struct htp_tensor * src0 = &octx->src0;
-    const struct htp_tensor * src1 = &octx->src1;
-    struct htp_tensor *       dst  = &octx->dst;
-
-    htp_rope_preamble;
-
-    const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows
-    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+    const uint32_t src0_nrows = rctx->src0_nrows;
+    const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -365,32 +271,114 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int
         return;
     }
 
-    uint64_t t1, t2;
-    t1 = HAP_perf_get_qtimer_count();
+    uint64_t tt = HAP_perf_get_qtimer_count();
 
-    int is_aligned = 1;
-    int opt_path   = 0;
-    if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
-        (0 == htp_is_aligned((void *) dst->data, VLEN))) {
-        FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
-        is_aligned = 0;
-    }
-    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
+    const int32_t mode    = rctx->mode;
+    const bool    is_neox = mode & HTP_ROPE_TYPE_NEOX;
+
+    // VTCM setup
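+    // per-thread src0 scratchpad layout: [theta cache][HTP_ROPE_SPAD_NROWS source rows]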
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    float *   theta_cache    = (float *) (src0_spad_base);
+              src0_spad_base = src0_spad_base + rctx->theta_cache_offset;
+    uint8_t * dst_spad_base  = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+
+    dma_queue * dma_queue = octx->ctx->dma[ith];
+    const int32_t * pos = (const int32_t *) src1->data;
+    const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
+
+    uint32_t ir = 0;
+    uint32_t prev_i2 = (uint32_t) -1;
+
+    for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
+            for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
+                if (ir < src0_start_row) { ir++; i1++; continue; }
+                if (ir >= src0_end_row) goto done;
+
+                // Rows in this block
+                const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);
+
+                // Depth before prefetch
+                uint32_t dma_depth = dma_queue_depth(dma_queue);
+
+                // FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth,
+                //             (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
+
+                // Prefetch loop
+                for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {
+                    pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);
+
+                    uint32_t pi1 = i1 + pr;
+                    uint32_t pir = ir + pr;
+
+                    // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
+                    dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);
+
+                    const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
+                          uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;
+                    dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
+                        rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
+
+                    // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
+                }
+
+                // Update theta cache
+                if (i2 != prev_i2) {
+                    prev_i2 = i2;
+
+                    const int32_t p = pos[i2];
+                    rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
+
+                    // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
+                    //         (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
+                }
+
+                // Skip DMA transactions from prev block (if any)
+                // No need to wait for these since the DMA is setup for in-order processing
+                for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
+
+                // Compute loop
+                for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {
+                    // Number of rows to compute
+                    cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);
+
+                    uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;
+                    uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;
+
+                    // FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr,
+                    //         (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
+
+                    if (is_neox) {
+                        rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
+                    } else {
+                        rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
+                    }
+
+                    uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;
+                    dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);
+
+                    // Prefetch more rows (if any)
+                    if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {
+                        uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);
+                        uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;
+                        uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;
+
+                        const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
+                        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
+                            rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
+
+                        // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
+                    }
+                }
+            }
+        }
     }
 
-    rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
+done:
+    dma_queue_flush(dma_queue);
+    tt = HAP_perf_get_qtimer_count() - tt;
 
-    t2 = HAP_perf_get_qtimer_count();
-
-    FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
-         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
-}
-
-static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
-    struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
-
-    rope_job_f32_per_thread(rope_ctx, n, i);
+    FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));
 }
 
 static int execute_op_rope_f32(struct htp_ops_context * octx) {
@@ -401,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
     const struct htp_tensor * src2 = &octx->src2;
     struct htp_tensor *       dst  = &octx->dst;
 
-    worker_callback_t op_func;
-    const char *      op_type = NULL;
-
-    struct rope_th_ctx rope_ctx;
+    const char * op_type = "rope-f32";
 
     switch (octx->op) {
         case HTP_OP_ROPE:
-            op_func = rope_job_dispatcher_f32;
-            op_type = "rope-f32";
-
-            init_rope_ctx(&rope_ctx, octx);
             break;
 
         default:
@@ -419,52 +400,81 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
             return HTP_STATUS_NO_SUPPORT;
     }
 
-    const uint32_t n_threads = octx->n_threads;
+    const uint32_t ne0 = dst->ne[0];
+    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
 
     const size_t src0_row_size = src0->nb[1];
-    const size_t src1_row_size = src0_row_size;
     const size_t dst_row_size  = dst->nb[1];
 
-    // VTCM scratchpads for all tensors
-    // N rows per thread, padded to HVX vector size
-    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * n_threads;
-    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
-    octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
+    // Aligned row sizes for VTCM
+    const size_t src0_row_size_aligned    = hex_round_up(src0_row_size, VLEN);
+    const size_t dst_row_size_aligned     = hex_round_up(dst_row_size, VLEN);
+    const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
 
-    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+    // Calculate spad sizes per thread
+    size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
+    size_t dst_spad_per_thread  = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;
+    size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;
 
-    if (src2->ne[0]) {
-        FARF(HIGH,
-             "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
-             "dst-spad-size %u\n",
-             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
-             src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
-             dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
-    } else {
-        FARF(HIGH,
-             "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
-             op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
-             src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
-             octx->dst_spad.size);
-    }
-
-    // Make sure the reserved vtcm size is sufficient
-    if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
-             spad_size);
+    // Check if we fit in VTCM
+    size_t total_vtcm_needed = spad_per_thread * n_threads;
+    if (octx->ctx->vtcm_size < total_vtcm_needed) {
+        FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed);
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
-    octx->src0_spad.data = octx->ctx->vtcm_base;
-    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
-    octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+    // Assign sizes
+    octx->src0_spad.size_per_thread = src0_spad_per_thread;
+    octx->dst_spad.size_per_thread  = dst_spad_per_thread;
+    octx->src0_spad.size = n_threads * src0_spad_per_thread;
+    octx->dst_spad.size  = n_threads * dst_spad_per_thread;
+    octx->src1_spad.size = 0;
 
-    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    // Assign pointers
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = NULL;
+    octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;
+
+    // Fill context
+    struct htp_rope_context rctx;
+    memset(&rctx, 0, sizeof(struct htp_rope_context));
+
+    rctx.t_start = HAP_perf_get_qtimer_count();
+
+    rctx.octx = octx;
+
+    const int32_t * op_params = &octx->op_params[0];
+    rctx.n_dims     = ((const int32_t *) op_params)[1];
+    rctx.mode       = ((const int32_t *) op_params)[2];
+    rctx.n_ctx_orig = ((const int32_t *) op_params)[4];
+
+    memcpy(&rctx.freq_base,   (int32_t *) op_params + 5,  sizeof(float));
+    memcpy(&rctx.freq_scale,  (int32_t *) op_params + 6,  sizeof(float));
+    memcpy(&rctx.ext_factor,  (int32_t *) op_params + 7,  sizeof(float));
+    memcpy(&rctx.attn_factor, (int32_t *) op_params + 8,  sizeof(float));
+    memcpy(&rctx.beta_fast,   (int32_t *) op_params + 9,  sizeof(float));
+    memcpy(&rctx.beta_slow,   (int32_t *) op_params + 10, sizeof(float));
+    memcpy(&rctx.sections,    (int32_t *) op_params + 11, sizeof(int) * 4);
+
+    rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);
+
+    rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);
+
+    rctx.src0_row_size = src0_row_size;
+    rctx.dst_row_size  = dst_row_size;
+    rctx.src0_row_size_aligned = src0_row_size_aligned;
+    rctx.dst_row_size_aligned  = dst_row_size_aligned;
+    rctx.theta_cache_offset    = theta_cache_size_aligned;
+
+    rctx.src0_nrows = src0_nrows;
+    rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
+
+    FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
+         rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs             = MIN(n_threads, src0_nrows);
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads);
     }
 
     return err;
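
Note on the rope parameter unpacking above: theta_scale = powf(freq_base, -2.0f / n_dims) is precomputed so the kernel can generate the per-pair rotation angles by repeated multiplication instead of calling powf per element. A minimal standalone sketch (not part of the patch; freq_base, n_dims and pos are illustrative values) checking that the incremental form matches the closed form:

#include <math.h>
#include <stdio.h>

int main(void) {
    const float freq_base = 10000.0f;  /* illustrative assumption */
    const int   n_dims    = 128;
    const float pos       = 42.0f;

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    float theta = pos;  /* angle for pair i is pos * theta_scale^i */
    for (int i = 0; i < n_dims / 2; i++) {
        const float ref = pos / powf(freq_base, (2.0f * i) / n_dims);
        if (fabsf(theta - ref) > 1e-3f * fabsf(ref) + 1e-5f) {
            printf("mismatch at pair %d: %f vs %f\n", i, theta, ref);
            return 1;
        }
        theta *= theta_scale;  /* advance to the next dimension pair */
    }
    printf("incremental theta generation matches the closed form\n");
    return 0;
}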
diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c
index bdd64fcc..4b696774 100644
--- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c
+++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c
@@ -2,24 +2,20 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
 #include 
-#include 
 #include 
-#include 
-#include 
+
 #include 
 #include 
 
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
 
 #define set_rows_preamble \
     const uint32_t ne00 = octx->src0.ne[0]; \
@@ -47,11 +43,21 @@
                                             \
     const uint32_t nr  = ne01;
 
-static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+struct htp_set_rows_context {
+    struct htp_ops_context * octx;
+    struct fastdiv_values div_ne12;
+    struct fastdiv_values div_ne11;
+    uint32_t src0_nrows_per_thread;
+};
+
+static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
+    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
+    struct htp_ops_context * octx = srctx->octx;
+
     set_rows_preamble;
 
     // parallelize by rows of src0
-    const uint32_t dr  = octx->src0_nrows_per_thread;
+    const uint32_t dr  = srctx->src0_nrows_per_thread;
     const uint32_t ir0 = dr * ith;
     const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 
@@ -60,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
     for (uint32_t i03 = 0; i03 < ne03; ++i03) {
         for (uint32_t i02 = 0; i02 < ne02; ++i02) {
             for (uint32_t i = ir0; i < ir1; ++i) {
-                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
-                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+                const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
+                const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
                 const uint32_t i10 = i;
 
                 const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@@ -76,19 +82,20 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
                 const uintptr_t dst_ptr  = octx->dst.data  + i1*nb1 + i02*nb2  + i03*nb3;
 
                 // copy row
-                hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+                hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
             }
         }
     }
-
-    return HTP_STATUS_OK;
 }
 
-static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
+    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
+    struct htp_ops_context * octx = srctx->octx;
+
     set_rows_preamble;
 
     // parallelize by rows of src0
-    const uint32_t dr  = octx->src0_nrows_per_thread;
+    const uint32_t dr  = srctx->src0_nrows_per_thread;
     const uint32_t ir0 = dr * ith;
     const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
 
@@ -97,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
     for (uint32_t i03 = 0; i03 < ne03; ++i03) {
         for (uint32_t i02 = 0; i02 < ne02; ++i02) {
             for (uint32_t i = ir0; i < ir1; ++i) {
-                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
-                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+                const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
+                const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
                 const uint32_t i10 = i;
 
                 const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
@@ -112,25 +119,17 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
                 const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
                 uint8_t*       dst_ptr  = (uint8_t *)       octx->dst.data  + i1*nb1 + i02*nb2  + i03*nb3;
 
-                hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
+                hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
             }
         }
     }
-
-    return HTP_STATUS_OK;
-}
-
-static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
-    set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
-}
-
-static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
-    set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
 }
 
 int op_set_rows(struct htp_ops_context * octx) {
     set_rows_preamble;
 
+    const uint32_t n_threads = MIN(nr, octx->n_threads);
+
     if (octx->src0.type != HTP_TYPE_F32) {
         return HTP_STATUS_NO_SUPPORT;
     }
@@ -147,18 +146,19 @@ int op_set_rows(struct htp_ops_context * octx) {
         return HTP_STATUS_OK;
     }
 
-    octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
-    octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
+    struct htp_set_rows_context srctx;
+    srctx.octx = octx;
+    srctx.div_ne12 = init_fastdiv_values(ne12);
+    srctx.div_ne11 = init_fastdiv_values(ne11);
 
-    const uint32_t n_jobs = MIN(nr, octx->n_threads);
-    octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+    srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
 
     switch(octx->dst.type) {
     case HTP_TYPE_F32:
-        worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads);
         break;
     case HTP_TYPE_F16:
-        worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads);
         break;
     default:
         return HTP_STATUS_NO_SUPPORT;
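
The fastmodulo calls in the set_rows workers implement plain broadcast indexing: src1's outer dims (ne11, ne12) may be smaller than src0's (ne02, ne03) and are wrapped with a modulo. A small reference sketch (not part of the patch; shapes are illustrative) showing the mapping without the Hexagon fastdiv helpers:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint32_t ne02 = 4, ne03 = 6;  /* src0 outer dims (illustrative) */
    const uint32_t ne11 = 2, ne12 = 3;  /* src1 outer dims being broadcast */

    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
            const uint32_t i12 = i03 % ne12;  /* what fastmodulo(i03, ne12, ...) returns */
            const uint32_t i11 = i02 % ne11;  /* what fastmodulo(i02, ne11, ...) returns */
            printf("src0 (i02=%u,i03=%u) -> src1 (i11=%u,i12=%u)\n", i02, i03, i11, i12);
        }
    }
    return 0;
}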
diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c
index 80d249a2..8dae7f1e 100644
--- a/ggml/src/ggml-hexagon/htp/softmax-ops.c
+++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c
@@ -2,27 +2,21 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
+
 #include 
-#include 
 #include 
 
+#include "hex-dma.h"
+#include "hvx-utils.h"
+#include "hex-fastdiv.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
-#include "htp-dma.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
 
 #define htp_softmax_preamble3                              \
     const uint32_t ne00 = src0->ne[0];                     \
@@ -55,7 +49,7 @@
     const uint32_t nb2 = dst->nb[2];                       \
     const uint32_t nb3 = dst->nb[3];
 
-struct softmax_th_ctx {
+struct htp_softmax_context {
     bool     use_f16;
     bool     use_src1;
     uint32_t n_head;
@@ -66,28 +60,48 @@ struct softmax_th_ctx {
     float m0;
     float m1;
 
+    uint32_t src0_nrows_per_thread;
+    struct fastdiv_values fastdiv_ne01;
+    struct fastdiv_values fastdiv_ne02;
+    struct fastdiv_values fastdiv_ne12; // For mask broadcasting
+    struct fastdiv_values fastdiv_ne13; // For mask broadcasting
+    size_t spad_stride;
+
     struct htp_ops_context * octx;
 };
 
-static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
+static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
 
-    memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
+    memset(smctx, 0, sizeof(struct htp_softmax_context));
 
-    memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
-    memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
+    memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
+    memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
 
-    softmax_ctx->n_head      = src0->ne[2];
-    softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
+    smctx->n_head      = src0->ne[2];
+    smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));
 
-    softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
-    softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
+    smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
+    smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
 
-    softmax_ctx->use_src1 = (src1->ne[0] != 0);
-    softmax_ctx->use_f16  = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
+    smctx->use_src1 = (src1->ne[0] != 0);
+    smctx->use_f16  = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
 
-    softmax_ctx->octx = octx;
+    smctx->octx = octx;
+
+    // Initialize fastdiv values
+    const uint32_t ne01 = src0->ne[1];
+    const uint32_t ne02 = src0->ne[2];
+
+    if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
+    if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
+
+    const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
+    const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
+
+    if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
+    if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
 }
 
 static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
@@ -100,8 +114,8 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
     uint8_t * restrict dst_curr        = dst;
     const uint8_t * restrict mask_curr = mask;
 
-    HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
-    HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
+    HVX_Vector scale_vec = hvx_vec_splat_f32(scale);
+    HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
 
     int step_of_1 = num_elems >> 5;
 
@@ -134,9 +148,9 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
     HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;
 
     HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
-    HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]);
+    HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]);
     HVX_Vector zero_v  = Q6_V_vzero();
-    HVX_Vector one_v   = hvx_vec_splat_fp32(1.0);
+    HVX_Vector one_v   = hvx_vec_splat_f32(1.0);
 
     int step_of_1 = num_elems >> 5;
 
@@ -146,26 +160,24 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
         max_vec       = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
     }
 
-    HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec);
-    max_vec      = hvx_vec_repl4(v);
+    max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes
 
     #pragma unroll(4)
     for (int i = 0; i < step_of_1; i++) {
         HVX_Vector v1 = v_src[i];
         HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);
 
-        HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2));
+        HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2));
 
         sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);
 
         v_pad[i] = v3;
     }
 
-    v       = hvx_vec_qf32_reduce_sum(sum_vec);
-    sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
+    sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes
 
     HVX_VectorPred pos_sum   = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
-    HVX_Vector     v4        = hvx_vec_inverse_fp32(sum_vec);
+    HVX_Vector     v4        = hvx_vec_inverse_f32(sum_vec);
     HVX_Vector     scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);
 
     #pragma unroll(4)
@@ -181,92 +193,18 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
                              uint8_t * restrict spad,
                              const int   num_elems,
                              const float max) {
-    hvx_sub_scalar_f32(src, max, spad, num_elems);
+    hvx_sub_scalar_f32(spad, src, max, num_elems);
 
     hvx_exp_f32(spad, dst, num_elems, false);
 
-    float sum = hvx_self_sum_f32(dst, num_elems);
+    float sum = hvx_reduce_sum_f32(dst, num_elems);
 
     return sum;
 }
 
-static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
-    struct htp_ops_context * octx = softmax_ctx->octx;
-
-    const struct htp_tensor * src0 = &octx->src0;
-    const struct htp_tensor * src1 = &octx->src1;
-    const struct htp_tensor * dst  = &octx->dst;
-
-    htp_softmax_preamble3;
-
-    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
-    uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
-    uint8_t * dst_spad_data  = octx->dst_spad.data + (ith * nb1);
-
-    float * wp0 = (float *) src0_spad_data;
-    float * wp1 = (float *) src1_spad_data;
-    float * wp2 = (float *) dst_spad_data;
-
-    for (uint32_t i03 = 0; i03 < ne03; i03++) {
-        for (uint32_t i02 = 0; i02 < ne02; i02++) {
-            for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
-                const uint32_t i11 = i01;
-                const uint32_t i12 = i02 % ne12;
-                const uint32_t i13 = i03 % ne13;
-
-                // ALiBi
-                const uint32_t h = i02;  // head
-
-                const float slope = (softmax_ctx->max_bias > 0.0f) ?
-                                        h < softmax_ctx->n_head_log2 ?
-                                        powf(softmax_ctx->m0, h + 1) :
-                                        powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
-                                        1.0f;
-
-                float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
-                float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
-
-                // broadcast the mask across rows
-                __fp16 * mp_f16 = (softmax_ctx->use_src1) ?
-                                      (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
-                                      NULL;
-                float *  mp_f32 = (softmax_ctx->use_src1) ?
-                                      (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
-                                      NULL;
-
-                if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
-                    hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
-                                              (const uint8_t *) mp_f32, slope);
-                } else {
-                    hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
-                    if (mp_f32) {
-                        if (softmax_ctx->use_f16) {
-                            for (int i = 0; i < ne00; ++i) {
-                                wp0[i] += slope * (float) mp_f16[i];
-                            }
-                        } else {
-                            for (int i = 0; i < ne00; ++i) {
-                                wp0[i] += slope * mp_f32[i];
-                            }
-                        }
-                    }
-                }
-
-                if (1 == opt_path) {
-                    hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
-                } else {
-                    float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
-                    float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
-                    sum       = sum > 0.0 ? (1.0 / sum) : 1;
-                    hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
-                }
-            }
-        }
-    }
-}
-
-static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
-    struct htp_ops_context * octx = softmax_ctx->octx;
+static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
+    struct htp_ops_context * octx = smctx->octx;
 
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
@@ -275,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
     htp_softmax_preamble3;
 
     const uint32_t src0_nrows            = ne01 * ne02 * ne03;  // src0 rows
-    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+    const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;
 
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
@@ -290,7 +228,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
 
     int is_aligned = 1;
     int opt_path   = 0;
-    if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
+    if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) {
         is_aligned = 0;
         FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
     }
@@ -298,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
         opt_path = 1;
     }
 
-    softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
+    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
+    uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
+    uint8_t * dst_spad_data  = octx->dst_spad.data + (ith * smctx->spad_stride);
+
+    float * wp0 = (float *) src0_spad_data;
+    float * wp1 = (float *) src1_spad_data;
+    float * wp2 = (float *) dst_spad_data;
+
+    uint32_t prev_i2 = (uint32_t)-1;
+    float slope = 1.0f;
+
+    for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {
+        uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);
+        uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);
+        uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);
+        uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);
+
+        // Map to original logic indices
+        // i01 = i1
+        // i02 = i2
+        // i03 = i3
+
+        const uint32_t i11 = i1;
+        // const uint32_t i12 = i2 % ne12;
+        // const uint32_t i13 = i3 % ne13;
+
+        uint32_t i12, i13;
+        if (ne12 == ne02) {
+             i12 = i2;
+        } else {
+             i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);
+        }
+
+        if (ne13 == ne03) {
+             i13 = i3;
+        } else {
+             i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);
+        }
+
+        // ALiBi
+        if (i2 != prev_i2) {
+            const uint32_t h = i2;  // head
+
+            slope = (smctx->max_bias > 0.0f) ?
+                        h < smctx->n_head_log2 ?
+                        powf(smctx->m0, h + 1) :
+                        powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
+                        1.0f;
+            prev_i2 = i2;
+        }
+
+        float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
+        float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
+
+        // broadcast the mask across rows
+        __fp16 * mp_f16 = (smctx->use_src1) ?
+                              (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+                              NULL;
+        float *  mp_f32 = (smctx->use_src1) ?
+                              (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+                              NULL;
+
+        if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
+            hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
+                                      (const uint8_t *) mp_f32, slope);
+        } else {
+            hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
+            if (mp_f32) {
+                if (smctx->use_f16) {
+                    for (int i = 0; i < ne00; ++i) {
+                        wp0[i] += slope * (float) mp_f16[i];
+                    }
+                } else {
+                    for (int i = 0; i < ne00; ++i) {
+                        wp0[i] += slope * mp_f32[i];
+                    }
+                }
+            }
+        }
+
+        if (1 == opt_path) {
+            hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
+        } else {
+            float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
+            float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
+            sum       = sum > 0.0 ? (1.0 / sum) : 1;
+            hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
+        }
+    }
 
     t2 = HAP_perf_get_qtimer_count();
 
     FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
-         softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
+         smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
          ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
-    struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
-    softmax_job_f32_per_thread(p_softmax_ctx, n, i);
-}
-
 static int execute_op_softmax_f32(struct htp_ops_context * octx) {
     int err = HTP_STATUS_OK;
 
@@ -319,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
     const struct htp_tensor * src1 = &octx->src1;
     struct htp_tensor *       dst  = &octx->dst;
 
-    worker_callback_t op_func;
-    const char *      op_type = NULL;
-
-    struct softmax_th_ctx softmax_ctx;
+    struct htp_softmax_context smctx;
+    const char * op_type = "softmax-f32";
 
     switch (octx->op) {
         case HTP_OP_SOFTMAX:
-            op_func = softmax_job_dispatcher_f32;
-            op_type = "softmax-f32";
-
-            init_softmax_ctx(&softmax_ctx, octx);
+            init_softmax_ctx(&smctx, octx);
             break;
 
         default:
@@ -337,7 +353,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
             return HTP_STATUS_NO_SUPPORT;
     }
 
-    const uint32_t n_threads = octx->n_threads;
+    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);
 
     const size_t src0_row_size = src0->nb[1];
     const size_t src1_row_size = src0_row_size;
@@ -345,9 +362,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
 
     // VTCM scratchpads for all tensors
     // N rows per thread, padded to HVX vector size
-    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * n_threads;
-    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
-    octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
+    octx->dst_spad.size  = hex_round_up(dst_row_size, 128) * n_threads;
+    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
+    octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
+
+    // Per-thread scratchpad stride used to compute each thread's offset
+    smctx.spad_stride = hex_round_up(src0_row_size, 128);
 
     size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
 
@@ -374,12 +394,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
     octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
     octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
-    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
-
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs             = MIN(n_threads, src0_nrows);
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
+        smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
+        worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
     }
 
     return err;
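
The row loop in softmax_job_f32 above walks a flat row index r and recovers (i1, i2, i3) with fastdiv/fastmodulo, recomputing the ALiBi slope only when the head index i2 changes. A standalone sketch (not part of the patch; shape and max_bias are illustrative, plain / and % stand in for the fastdiv helpers):

#include <math.h>
#include <stdint.h>
#include <stdio.h>

static float alibi_slope(float max_bias, uint32_t n_head_log2, float m0, float m1, uint32_t h) {
    if (max_bias <= 0.0f) return 1.0f;
    return (h < n_head_log2) ? powf(m0, (float)(h + 1))
                             : powf(m1, (float)(2 * (h - n_head_log2) + 1));
}

int main(void) {
    const uint32_t ne01 = 5, ne02 = 8, ne03 = 2;  /* illustrative shape */
    const float    max_bias = 8.0f;
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) ne02));
    const float    m0 = powf(2.0f, -max_bias / n_head_log2);
    const float    m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t r = 0; r < ne01 * ne02 * ne03; ++r) {
        const uint32_t i1 = r % ne01;  /* fastmodulo(r, ne01, ...) */
        const uint32_t q  = r / ne01;  /* fastdiv(r, ...)          */
        const uint32_t i2 = q % ne02;  /* head index               */
        const uint32_t i3 = q / ne02;  /* batch index              */
        const float slope = alibi_slope(max_bias, n_head_log2, m0, m1, i2);
        if (r < 3) printf("r=%u -> (i1=%u,i2=%u,i3=%u) slope=%f\n", r, i1, i2, i3, slope);
    }
    return 0;
}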
diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c
new file mode 100644
index 00000000..b3c1ef95
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c
@@ -0,0 +1,339 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "hex-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+
+#define htp_ssm_conv_tensors_preamble                        \
+    struct htp_tensor * restrict src0    = &octx->src0;      \
+    struct htp_tensor * restrict src1    = &octx->src1;      \
+    struct htp_tensor * restrict dst     = &octx->dst;       \
+    struct htp_spad * restrict src0_spad = &octx->src0_spad; \
+    struct htp_spad * restrict src1_spad = &octx->src1_spad; \
+    struct htp_spad * restrict dst_spad  = &octx->dst_spad;  \
+                                                             \
+    const uint32_t ne00 = src0->ne[0];                       \
+    const uint32_t ne01 = src0->ne[1];                       \
+    const uint32_t ne02 = src0->ne[2];                       \
+    const uint32_t ne03 = src0->ne[3];                       \
+                                                             \
+    const uint32_t ne10 = src1->ne[0];                       \
+    const uint32_t ne11 = src1->ne[1];                       \
+    const uint32_t ne12 = src1->ne[2];                       \
+    const uint32_t ne13 = src1->ne[3];                       \
+                                                             \
+    const uint32_t ne0 = dst->ne[0];                         \
+    const uint32_t ne1 = dst->ne[1];                         \
+    const uint32_t ne2 = dst->ne[2];                         \
+    const uint32_t ne3 = dst->ne[3];                         \
+                                                             \
+    const uint32_t nb00 = src0->nb[0];                       \
+    const uint32_t nb01 = src0->nb[1];                       \
+    const uint32_t nb02 = src0->nb[2];                       \
+    const uint32_t nb03 = src0->nb[3];                       \
+                                                             \
+    const uint32_t nb10 = src1->nb[0];                       \
+    const uint32_t nb11 = src1->nb[1];                       \
+    const uint32_t nb12 = src1->nb[2];                       \
+    const uint32_t nb13 = src1->nb[3];                       \
+                                                             \
+    const uint32_t nb0 = dst->nb[0];                         \
+    const uint32_t nb1 = dst->nb[1];                         \
+    const uint32_t nb2 = dst->nb[2];                         \
+    const uint32_t nb3 = dst->nb[3];
+
+struct htp_ssm_conv_context {
+    struct htp_ops_context * octx;
+    uint32_t nrows_per_thread;
+    uint64_t t_start;
+};
+
+#define htp_ssm_conv_preamble                            \
+    struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \
+    struct htp_ops_context * octx = scctx->octx;         \
+    htp_ssm_conv_tensors_preamble;                       \
+    dma_queue * dma_queue         = octx->ctx->dma[ith];
+
+// Scalar FP32 SSM_CONV implementation
+static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
+    htp_ssm_conv_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t d_conv  = src1->ne[0];
+    const uint32_t d_inner = src0->ne[1];
+    const uint32_t n_t     = dst->ne[1];
+    const uint32_t n_s     = dst->ne[2];
+
+    const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension
+    const uint32_t src0_stride_seq   = src0->nb[2] / sizeof(float); // stride for sequence dimension
+    const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension
+    const uint32_t dst_stride_token  = dst->nb[1]  / sizeof(float); // stride for token dimension
+    const uint32_t dst_stride_seq    = dst->nb[2]  / sizeof(float); // stride for sequence dimension
+
+    const float * src0_data = (const float *) src0->data;
+    const float * src1_data = (const float *) src1->data;
+    float *       dst_data  = (float *) dst->data;
+
+    // Calculate row range for this thread
+    const uint32_t d_inner_per_thread = scctx->nrows_per_thread;
+    const uint32_t d_inner_start = d_inner_per_thread * ith;
+    const uint32_t d_inner_end   = MIN(d_inner_start + d_inner_per_thread, d_inner);
+
+    // No work for this thread
+    if (d_inner_start >= d_inner_end) {
+        return;
+    }
+
+    for (uint32_t i3 = 0; i3 < n_s; ++i3) {
+        for (uint32_t i2 = 0; i2 < n_t; ++i2) {
+            for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) {
+                float sumf = 0.0f;
+
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq;
+                    const uint32_t src1_idx = i0 + i1 * src1_stride_inner;
+
+                    sumf += src0_data[src0_idx] * src1_data[src1_idx];
+                }
+
+                const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq;
+                dst_data[dst_idx] = sumf;
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end,
+         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
+         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension
+static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) {
+    htp_ssm_conv_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+
+    const uint32_t d_conv  = src1->ne[0];
+    const uint32_t d_inner = src0->ne[1];
+    const uint32_t n_t     = dst->ne[1];
+    const uint32_t n_s     = dst->ne[2];
+
+    const float * src0_data = (const float *) src0->data;
+    const float * src1_data = (const float *) src1->data;
+    float *       dst_data  = (float *) dst->data;
+
+    // Calculate row range for this thread
+    const int dr = scctx->nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = MIN(ir0 + dr, d_inner);
+    const int      ir  = ir1 - ir0;
+
+    if (ir0 >= ir1) {
+        return;  // No work for this thread
+    }
+
+    // src0 and src1 gather offsets
+    uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 };
+    uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 };
+
+    for (uint32_t i = 0; i < VLEN_FP32; ++i) {
+        src0_offsets[i] = i * (ncs)    * sizeof(float);
+        src1_offsets[i] = i * (d_conv) * sizeof(float);
+    }
+
+    const uint32_t src0_gather_len = VLEN * ncs;
+    const uint32_t src1_gather_len = VLEN * d_conv;
+
+    // gather scratchpads
+    HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0);
+    HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN);
+
+    float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]);
+    float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]);
+
+    uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread;
+    uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread;
+
+    // copy src1 workload to VTCM
+    dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir);
+
+    // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir);
+
+    for (uint32_t i3 = 0; i3 < n_s; ++i3) {
+        float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2]));
+
+        // copy src0 workload to VTCM
+        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir);
+
+        // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir);
+
+        dma_queue_flush(dma_queue);
+
+        for (uint32_t i2 = 0; i2 < n_t; ++i2) {
+            float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2]));
+
+            const uint32_t nvec = ir / VLEN_FP32;
+            const uint32_t nloe = ir % VLEN_FP32;
+            uint32_t i1 = 0;
+
+            for (uint32_t vi1 = 0; vi1 < nvec; vi1++) {
+                HVX_Vector acc_vec = Q6_V_vsplat_R(0);
+
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
+                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));
+                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
+                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));
+
+                    HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
+                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
+                }
+
+                *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec);
+                i1 += VLEN_FP32;
+            }
+
+            if (nloe) {
+                HVX_Vector acc_vec = Q6_V_vsplat_R(0);
+
+                for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
+                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
+                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));
+                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
+                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));
+
+                    HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
+                    acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
+                }
+
+                hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec));
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n",
+         ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1,
+         src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1],
+         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+int op_ssm_conv_f32(struct htp_ops_context * octx) {
+    htp_ssm_conv_tensors_preamble;
+
+    if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) {
+        FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported");
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    struct htp_ssm_conv_context scctx = { 0 };
+    scctx.octx = octx;
+
+    const uint32_t d_conv  = src1->ne[0];
+    const uint32_t d_inner = src0->ne[1];
+    const uint32_t n_t     = dst->ne[1];  // tokens per sequence
+    const uint32_t n_s     = dst->ne[2];  // number of sequences in the batch
+
+    const uint32_t n_threads = MIN(octx->n_threads, d_inner);
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        uint32_t use_hvx = 0;
+        if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) {
+            int is_aligned = hex_is_aligned((void *) src0->data, VLEN) &&
+                             hex_is_aligned((void *) src1->data, VLEN) &&
+                             hex_is_aligned((void *) dst->data, VLEN);
+
+            if (is_aligned) {
+                use_hvx = 1;
+            }
+        }
+
+        if (use_hvx) {
+            scctx.nrows_per_thread  = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread
+            scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even
+
+            octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256);
+            octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256);
+            octx->dst_spad.size_per_thread  = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256);
+
+            octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads;
+            octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads;
+            octx->dst_spad.size  = octx->dst_spad.size_per_thread  * n_threads;
+
+            // Compute gather scratchpad size for src0 and src1
+            const size_t gather_spad_size = n_threads * VLEN * 2;
+
+            octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size;
+            octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+            octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
+
+            FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n",
+                gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread,
+                octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size,
+                octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data);
+
+            const size_t total_spad_size =
+                gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+
+            if (total_spad_size > octx->ctx->vtcm_size) {
+                FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size,
+                     octx->ctx->vtcm_size);
+                use_hvx = 0;
+            }
+        }
+
+        FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0],
+             src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0],
+             dst->ne[1], dst->ne[2], dst->ne[3], use_hvx);
+
+        if (use_hvx) {
+            worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads);
+        } else {
+            worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads);
+        }
+    }
+
+    return HTP_STATUS_OK;
+}
+
+int op_ssm_conv(struct htp_ops_context * octx) {
+    int                 err = HTP_STATUS_OK;
+    struct htp_tensor * dst = &octx->dst;
+
+    switch (dst->type) {
+        case HTP_TYPE_F32:
+            err = op_ssm_conv_f32(octx);
+            break;
+        default:
+            err = HTP_STATUS_NO_SUPPORT;
+            break;
+    }
+
+    return err;
+}
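
For reference, the VTCM carve-up done in op_ssm_conv_f32 places the per-thread gather scratchpads at the base, followed by the per-thread src0/src1/dst scratchpads. A host-side arithmetic sketch (not part of the patch; row sizes and thread counts are illustrative, and the usual 128-byte HVX vector length is assumed):

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#define VLEN 128  /* assumed HVX vector length in bytes */

static size_t round_up(size_t v, size_t a) { return (v + a - 1) / a * a; }

int main(void) {
    const uint32_t n_threads        = 4;    /* illustrative */
    const uint32_t nrows_per_thread = 32;   /* illustrative */
    const size_t   nb01 = 512, nb11 = 16;   /* row sizes in bytes */

    const size_t src0_per_thread = round_up(nrows_per_thread * nb01, 256);
    const size_t src1_per_thread = round_up(nrows_per_thread * nb11, 256);
    const size_t dst_per_thread  = round_up(nrows_per_thread * sizeof(float), 256);

    const size_t gather   = (size_t) n_threads * VLEN * 2;
    const size_t src0_off = gather;
    const size_t src1_off = src0_off + src0_per_thread * n_threads;
    const size_t dst_off  = src1_off + src1_per_thread * n_threads;
    const size_t total    = dst_off  + dst_per_thread  * n_threads;

    printf("gather @0 (%zu B), src0 @%zu, src1 @%zu, dst @%zu, total %zu B\n",
           gather, src0_off, src1_off, dst_off, total);
    return 0;
}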
diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c
new file mode 100644
index 00000000..352650b6
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c
@@ -0,0 +1,128 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include 
+#include 
+
+#include 
+#include 
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#define sum_rows_preamble                       \
+    struct htp_tensor *src0 =  &octx->src0;\
+    struct htp_tensor *dst  = &octx->dst;  \
+                                           \
+    const uint32_t ne00 = src0->ne[0];     \
+    const uint32_t ne01 = src0->ne[1];     \
+    const uint32_t ne02 = src0->ne[2];     \
+    const uint32_t ne03 = src0->ne[3];     \
+                                           \
+    const uint32_t nb00 = src0->nb[0];     \
+    const uint32_t nb01 = src0->nb[1];     \
+    const uint32_t nb02 = src0->nb[2];     \
+    const uint32_t nb03 = src0->nb[3];     \
+                                           \
+    const uint32_t  ne0 = dst->ne[0];      \
+    const uint32_t  ne1 = dst->ne[1];      \
+    const uint32_t  ne2 = dst->ne[2];      \
+    const uint32_t  ne3 = dst->ne[3];      \
+                                           \
+    const uint32_t  nb0 = dst->nb[0];      \
+    const uint32_t  nb1 = dst->nb[1];      \
+    const uint32_t  nb2 = dst->nb[2];      \
+    const uint32_t  nb3 = dst->nb[3];      \
+
+struct sum_rows_context {
+    const uint8_t * src_data;
+    uint8_t       * dst_data;
+    uint32_t        ne00;
+    size_t          src_stride;
+    size_t          dst_stride;
+    uint32_t        rows_per_thread;
+    uint32_t        total_rows;
+    bool            opt_path;
+};
+
+static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) {
+    const struct sum_rows_context * smctx = (const struct sum_rows_context *) data;
+
+    const uint32_t rows_per_thread = smctx->rows_per_thread;
+    const uint32_t total_rows      = smctx->total_rows;
+
+    const uint32_t start_row = rows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + rows_per_thread, total_rows);
+
+    if (start_row >= end_row) {
+        return;
+    }
+
+    const size_t   src_stride = smctx->src_stride;
+    const size_t   dst_stride = smctx->dst_stride;
+    const uint32_t ne00       = smctx->ne00;
+    const bool     opt_path   = smctx->opt_path;
+
+    const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride));
+    float       * restrict dst_th = (float *)       (smctx->dst_data + (start_row * dst_stride));
+
+    // Calculate actual number of rows for this thread
+    const uint32_t n_rows = end_row - start_row;
+
+    for (uint32_t ir = 0; ir < n_rows; ir++) {
+        const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float)));
+
+        if (ir + 1 < n_rows) {
+            hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1);
+        }
+
+        if (opt_path) {
+            dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
+        } else {
+            dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
+        }
+    }
+}
+
+int op_sum_rows(struct htp_ops_context * octx) {
+    sum_rows_preamble;
+
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+    const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
+    const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
+
+    bool opt_path = false;
+    if (hex_is_aligned((void *) src0->data, VLEN) && !(nb01 & (VLEN - 1))) {
+        opt_path = true;
+    }
+
+    struct sum_rows_context smctx = {
+        .src_data        = (const uint8_t *) src0->data,
+        .dst_data        = (uint8_t *) dst->data,
+        .ne00            = ne00,
+        .src_stride      = nb01,
+        .dst_stride      = nb1,
+        .rows_per_thread = rows_per_thread,
+        .total_rows      = src0_nrows,
+        .opt_path        = opt_path,
+    };
+
+    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads);
+
+    return HTP_STATUS_OK;
+}
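
The partitioning in op_sum_rows is a static ceil-divide: each thread gets rows_per_thread = ceil(nrows / n_threads) rows, and trailing threads may end up with an empty range, which is why the worker returns early when start_row >= end_row. A minimal sketch (not part of the patch; counts are illustrative):

#include <stdint.h>
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const uint32_t nrows = 9, n_threads = 4;  /* illustrative */
    const uint32_t rows_per_thread = (nrows + n_threads - 1) / n_threads;

    for (uint32_t ith = 0; ith < n_threads; ++ith) {
        const uint32_t start = rows_per_thread * ith;
        const uint32_t end   = MIN(start + rows_per_thread, nrows);
        if (start >= end) {
            printf("thread %u: no work\n", ith);   /* the early-return case */
        } else {
            printf("thread %u: rows [%u, %u)\n", ith, start, end);
        }
    }
    return 0;
}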
diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c
index 8ed1e5b6..5bbd5040 100644
--- a/ggml/src/ggml-hexagon/htp/unary-ops.c
+++ b/ggml/src/ggml-hexagon/htp/unary-ops.c
@@ -2,28 +2,42 @@
 #pragma clang diagnostic ignored "-Wunused-function"
 #pragma clang diagnostic ignored "-Wunused-but-set-variable"
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
-
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
+
 #include 
-#include 
 #include 
 
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
 #include "htp-ctx.h"
-#include "htp-dma.h"
 #include "htp-msg.h"
 #include "htp-ops.h"
-#include "hvx-utils.h"
-#include "ops-utils.h"
+
+struct htp_unary_context {
+    struct htp_ops_context * octx;
+
+    // Precomputed values
+    const uint8_t *           data_src0;
+    uint8_t *                 data_dst;
+
+    size_t                    src0_row_size;
+    size_t                    dst_row_size;
+
+    size_t                    src0_row_size_aligned;
+    size_t                    dst_row_size_aligned;
+
+    size_t                    src0_spad_half_size;
+    size_t                    dst_spad_half_size;
+
+    uint32_t                  block;
+    uint32_t                  src0_nrows;
+    uint32_t                  src0_nrows_per_thread;
+    uint32_t                  nc;
+};
 
 #define htp_unary_preamble            \
     const uint32_t ne00 = src->ne[0]; \
@@ -55,7 +69,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
     HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;
 
     HVX_Vector sum_v     = Q6_V_vsplat_R(0x00000000);
-    HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon);
+    HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
 
     int step_of_1 = num_elems >> 5;
     #pragma unroll(4)
@@ -65,15 +79,14 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
         sum_v         = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
     }
 
-    HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v);
-    sum_v                  = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
+    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
 
-    HVX_Vector t_v            = hvx_vec_splat_fp32((float) num_elems);
-    HVX_Vector denom_v        = hvx_vec_inverse_fp32(t_v);
+    HVX_Vector t_v            = hvx_vec_splat_f32((float) num_elems);
+    HVX_Vector denom_v        = hvx_vec_inverse_f32(t_v);
     HVX_Vector mean_v         = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
     HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
 
-    HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
+    HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
 
     #pragma unroll(4)
     for (int i = 0; i < step_of_1; i++) {
@@ -83,78 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
     }
 }
 
-static void scale_htp_f32(const float * restrict src,
-                          float * restrict dst,
-                          uint8_t * restrict spad,
-                          const uint32_t num_rows,
-                          const uint32_t row_elems,
-                          const size_t   row_size,
-                          int32_t *      op_params,
-                          int            opt_path) {
+static void scale_f32(const float * restrict src,
+                      float * restrict dst,
+                      uint8_t * restrict spad,
+                      const uint32_t num_rows,
+                      const uint32_t row_elems,
+                      const size_t   row_size,
+                      int32_t *      op_params) {
     float scale = 0.f;
     float bias  = 0.f;
     memcpy(&scale, &op_params[0], sizeof(float));
     memcpy(&bias,  &op_params[1], sizeof(float));
 
     for (uint32_t ir = 0; ir < num_rows; ir++) {
-        const float * restrict src_local = src + (ir * row_elems);
-        float * restrict dst_local       = dst + (ir * row_elems);
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
+        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);
 
-        if (ir + 1 < num_rows) {
-            htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
-        }
-
-        hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
+        hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
     }
 }
 
-static void rms_norm_htp_f32(const float * restrict src,
-                             float * restrict dst,
-                             uint8_t * restrict spad,
-                             const uint32_t num_rows,
-                             const uint32_t row_elems,
-                             const size_t   row_size,
-                             int32_t *      op_params,
-                             int            opt_path) {
+static void rms_norm_f32(const float * restrict src,
+                         float * restrict dst,
+                         uint8_t * restrict spad,
+                         const uint32_t num_rows,
+                         const uint32_t row_elems,
+                         const size_t   row_size,
+                         int32_t *      op_params) {
     float epsilon = 0.f;
     memcpy(&epsilon, op_params, sizeof(float));
 
     for (uint32_t ir = 0; ir < num_rows; ir++) {
-        const float * restrict src_local = src + (ir * row_elems);
-        float * restrict dst_local       = dst + (ir * row_elems);
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
+        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);
 
-        if (ir + 1 < num_rows) {
-            htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
-        }
-
-        if (1 == opt_path) {
-            hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
-        } else {
-            float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
-
-            const float mean  = sum / row_elems;
-            const float scale = 1.0f / sqrtf(mean + epsilon);
-
-            hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
-        }
+        hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
     }
 }
 
-static void unary_job_f32_per_thread(const struct htp_tensor * src,
-                                     struct htp_tensor *       dst,
-                                     uint8_t *                 spad,
-                                     int                       htp_op,
-                                     int32_t *                 op_params,
-                                     uint32_t                  nth,
-                                     uint32_t                  ith,
-                                     uint32_t                  src0_nrows_per_thread) {
+static void sqr_f32(const float * restrict src,
+                    float * restrict dst,
+                    uint8_t * restrict spad,
+                    const uint32_t num_rows,
+                    const uint32_t row_elems,
+                    const size_t   row_size,
+                    int32_t *      op_params) {
+
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
+        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);
+
+        hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+    }
+}
+
+static void sqrt_f32(const float * restrict src,
+                     float * restrict dst,
+                     uint8_t * restrict spad,
+                     const uint32_t num_rows,
+                     const uint32_t row_elems,
+                     const size_t   row_size,
+                     int32_t *      op_params) {
+
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
+        uint8_t * restrict dst_local       = (uint8_t *)dst + (ir * row_size);
+
+        hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+    }
+}
+
+static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
+    struct htp_ops_context * octx = uctx->octx;
+    const struct htp_tensor * src = &octx->src0;
+    const struct htp_tensor * dst = &octx->dst;
+
     htp_unary_preamble;
 
-    const size_t src0_row_size = nb01;
-    const size_t dst_row_size  = nb1;
+    int                       htp_op = octx->op;
+    int32_t *                 op_params = octx->op_params;
+    uint32_t                  src0_nrows_per_thread = uctx->src0_nrows_per_thread;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    const size_t src0_row_size = uctx->src0_row_size;
+    const size_t dst_row_size  = uctx->dst_row_size;
 
+    const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
+    const size_t dst_row_size_aligned  = uctx->dst_row_size_aligned;
+
+    const uint32_t src0_nrows = uctx->src0_nrows;
     const uint32_t src0_start_row = src0_nrows_per_thread * ith;
     const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
 
@@ -166,66 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
-    int is_aligned = 1;
-    int opt_path   = 0;
-    if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) {
-        is_aligned = 0;
-        FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n");
-    }
-    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
+    const uint8_t * restrict data_src = uctx->data_src0;
+    uint8_t * restrict       data_dst = uctx->data_dst;
+
+    uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_data  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+
+    size_t src0_spad_half_size = uctx->src0_spad_half_size;
+    size_t dst_spad_half_size  = uctx->dst_spad_half_size;
+
+    const int BLOCK = uctx->block;
+    if (BLOCK == 0) {
+        FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+             octx->src0_spad.size_per_thread, src0_row_size_aligned);
+        return;
     }
 
-    const uint8_t * restrict data_src = (const uint8_t *) src->data;
-    uint8_t * restrict data_dst       = (uint8_t *) dst->data;
+    dma_queue * dma_queue = octx->ctx->dma[ith];
 
-    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
-    float * restrict dst_th       = (float *) (data_dst + (src0_start_row * dst_row_size));
-    uint8_t * restrict spad_th    = (uint8_t *) spad + (ith * nb01);
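+    // prime the pipeline: queue up to two blocks so the DMA for the next block overlaps with compute on the current one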
+    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
 
-    switch (htp_op) {
-        case HTP_OP_RMS_NORM:
-            rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
-            break;
-        case HTP_OP_SCALE:
-            scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
-            break;
+        // Dummy DMA transaction for sequencing (interleaving dst, src, dst, ...)
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+            dst_row_size, dst_row_size_aligned, 0);
 
-        default:
-            break;
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
+            src0_row_size_aligned, src0_row_size, block_size);
     }
 
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
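+        // pops return the transfers in the order they were queued: the dst scratch buffer first, then the src rows now resident in VTCM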
+        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;
+        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+        // Process block in VTCM
+        switch (htp_op) {
+            case HTP_OP_RMS_NORM:
+                rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_SCALE:
+                scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_SQR:
+                sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_SQRT:
+                sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            default:
+                break;
+        }
+
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
+            dst_row_size, dst_row_size_aligned, block_size);
+
+        // prefetch the block for the N+2 loop iteration, if any
+        const uint32_t pref_block = (ir + BLOCK * 2);
+        if (pref_block < src0_end_row) {
+            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+            dma_queue_push_ddr_to_vtcm(dma_queue,
+                dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
+                src0_row_size_aligned, src0_row_size, pref_block_size);
+        }
+    }
+
+    dma_queue_flush(dma_queue);
+
     t2 = HAP_perf_get_qtimer_count();
 
-    FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
+    FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
          src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
          dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
-
-    unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
-                             octx->src0_nrows_per_thread);
-}
-
 static int execute_op_unary_f32(struct htp_ops_context * octx) {
     int err = HTP_STATUS_OK;
 
     const struct htp_tensor * src0 = &octx->src0;
     struct htp_tensor *       dst  = &octx->dst;
 
-    worker_callback_t unary_op_func;
-    const char *      op_type = NULL;
+    const char * op_type = NULL;
 
     switch (octx->op) {
         case HTP_OP_RMS_NORM:
-            unary_op_func = unary_job_dispatcher_f32;
-            op_type       = "rmsnorm-f32";
+            op_type = "rmsnorm-f32";
             break;
         case HTP_OP_SCALE:
-            unary_op_func = unary_job_dispatcher_f32;
-            op_type       = "scale-f32";
+            op_type = "scale-f32";
+            break;
+        case HTP_OP_SQR:
+            op_type = "sqr-f32";
+            break;
+        case HTP_OP_SQRT:
+            op_type = "sqrt-f32";
             break;
 
         default:
@@ -233,38 +301,65 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
             return HTP_STATUS_NO_SUPPORT;
     }
 
-    const int      n_threads  = octx->n_threads;
     const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);
 
     const size_t src0_row_size = src0->nb[1];
     const size_t dst_row_size  = dst->nb[1];
 
-    // VTCM scratchpads for all tensors
-    octx->dst_spad.size  = htp_round_up(dst_row_size, 128) * n_threads;
-    octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
+    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
 
-    size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
+    // VTCM scratchpads for all tensors
+    // N rows per thread, padded to HVX vector size
+    // Double buffering requires 2x size per buffer
+
+    size_t spad_size_per_row   = 2 * (src0_row_size_aligned + dst_row_size_aligned);
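+    // how many rows fit in each thread's share of VTCM, one double-buffered src+dst pair per row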
+    size_t vtcm_row_per_thread = (octx->ctx->vtcm_size) / (n_threads * spad_size_per_row);
+
+    // Make sure the reserved vtcm size is sufficient
+    if (vtcm_row_per_thread == 0) {
+        FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
+             spad_size_per_row * n_threads);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
+    octx->dst_spad.size_per_thread  = dst_row_size_aligned * vtcm_row_per_thread * 2;
+
+    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
+    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;
 
     FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
          src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
          octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
 
-    // Make sure the reserved vtcm size is sufficient
-    if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
-             spad_size);
-        return HTP_STATUS_VTCM_TOO_SMALL;
-    }
-
-    octx->src0_spad.data = octx->ctx->vtcm_base;
-    octx->dst_spad.data  = octx->src0_spad.data + octx->src0_spad.size;
-
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs = MIN(n_threads, src0_nrows);
+        struct htp_unary_context uctx = {
+            .octx                  = octx,
+            .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads,
+            .src0_nrows            = src0_nrows,
 
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+            .data_src0             = (const uint8_t *)src0->data,
+            .data_dst              = (uint8_t *)dst->data,
 
-        worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
+            .src0_row_size         = src0_row_size,
+            .dst_row_size          = dst_row_size,
+
+            .src0_row_size_aligned = src0_row_size_aligned,
+            .dst_row_size_aligned  = dst_row_size_aligned,
+
+            .src0_spad_half_size   = octx->src0_spad.size_per_thread / 2,
+            .dst_spad_half_size    = octx->dst_spad.size_per_thread / 2,
+
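+            // rows per DMA block: how many src rows fit in one half of the per-thread scratchpad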
+            .block                 = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
+            .nc                    = src0->ne[0],
+        };
+
+        worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
     }
 
     return err;
diff --git a/ggml/src/ggml-hexagon/htp/worker-pool.c b/ggml/src/ggml-hexagon/htp/worker-pool.c
index cd38c212..172e2890 100644
--- a/ggml/src/ggml-hexagon/htp/worker-pool.c
+++ b/ggml/src/ggml-hexagon/htp/worker-pool.c
@@ -7,10 +7,6 @@
 #include 
 #include 
 
-#ifdef HTP_DEBUG
-#    define FARF_HIGH 1
-#endif
-
 #include "HAP_farf.h"
 
 #define WORKER_THREAD_STACK_SZ  (2 * 16384)
@@ -60,7 +56,7 @@ static void worker_pool_main(void * context) {
         unsigned int n = atomic_load(&pool->n_jobs);
         unsigned int i = atomic_fetch_add(&pool->next_job, 1);
         if (i >= n) {
-            // Spurios wakeup
+            // Spurious wakeup
             continue;
         }
 
diff --git a/ggml/src/ggml-hexagon/libdl.h b/ggml/src/ggml-hexagon/libdl.h
new file mode 100644
index 00000000..8ca5016f
--- /dev/null
+++ b/ggml/src/ggml-hexagon/libdl.h
@@ -0,0 +1,79 @@
+#pragma once
+
+#ifdef _WIN32
+#   define WIN32_LEAN_AND_MEAN
+#   ifndef NOMINMAX
+#       define NOMINMAX
+#   endif
+#   include <windows.h>
+#   include 
+#else
+#    include <dlfcn.h>
+#    include 
+#endif
+#include <filesystem>
+
+namespace fs = std::filesystem;
+
+#ifdef _WIN32
+
+using dl_handle = std::remove_pointer_t<HMODULE>;
+
+struct dl_handle_deleter {
+    void operator()(HMODULE handle) {
+        FreeLibrary(handle);
+    }
+};
+
+static inline dl_handle * dl_load_library(const fs::path & path) {
+    // suppress error dialogs for missing DLLs
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    HMODULE handle = LoadLibraryW(path.wstring().c_str());
+
+    SetErrorMode(old_mode);
+
+    return handle;
+}
+
+static inline void * dl_get_sym(dl_handle * handle, const char * name) {
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    void * p = (void *) GetProcAddress(handle, name);
+
+    SetErrorMode(old_mode);
+
+    return p;
+}
+
+static inline const char * dl_error() {
+    return "";
+}
+
+#else
+
+using dl_handle = void;
+
+struct dl_handle_deleter {
+    void operator()(void * handle) {
+        dlclose(handle);
+    }
+};
+
+static inline dl_handle * dl_load_library(const fs::path & path) {
+    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
+    return handle;
+}
+
+static inline void * dl_get_sym(dl_handle * handle, const char * name) {
+    return dlsym(handle, name);
+}
+
+static inline const char * dl_error() {
+    const char *rslt = dlerror();
+    return rslt != nullptr ? rslt : "";
+}
+
+#endif
diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf
new file mode 100644
index 00000000..656d2d9a
--- /dev/null
+++ b/ggml/src/ggml-hexagon/libggml-htp.inf
@@ -0,0 +1,38 @@
+[Version]
+Signature   = "$WINDOWS NT$"
+Class       = ComputeAccelerator
+ClassGuid   = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C}
+Provider    = %GGML%
+DriverVer   = 01/01/2026,1.0.0.0
+CatalogFile = libggml-htp.cat
+PnpLockDown = 1
+
+[DestinationDirs]
+Drivers_Dir = 6
+
+[SourceDisksNames]
+1 = %DiskId%
+
+[SourceDisksFiles]
+libggml-htp-v68.so = 1
+libggml-htp-v69.so = 1
+libggml-htp-v73.so = 1
+libggml-htp-v75.so = 1
+libggml-htp-v81.so = 1
+
+[ControlFlags]
+ExcludeFromSelect = *
+
+[DefaultInstall.NTarm64]
+CopyFiles=Drivers_Dir
+
+[Drivers_Dir]
+libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+
+[Strings]
+GGML   = 'GGML'
+DiskId = 'GGML HTP library'
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index 23b68899..b44ed0f7 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -11,6 +11,10 @@ endif()
 list(APPEND CMAKE_PREFIX_PATH  ${ROCM_PATH})
 list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
 
+if (NOT DEFINED CMAKE_HIP_FLAGS_DEBUG)
+    set(CMAKE_HIP_FLAGS_DEBUG "-g -O2")
+endif()
+
 # CMake on Windows doesn't support the HIP language yet
 if (WIN32)
     set(CXX_IS_HIPCC TRUE)
@@ -62,6 +66,8 @@ file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
+file(GLOB   SRCS "../ggml-cuda/template-instances/mmf*.cu")
+list(APPEND GGML_SOURCES_ROCM ${SRCS})
 
 if (GGML_CUDA_FA_ALL_QUANTS)
     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 80e0fd2f..92568655 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) {
     }
 }
 
+static inline bool ggml_impl_is_view(const struct ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
 static inline float ggml_compute_softplus_f32(float input) {
     return (input > 20.0f) ? input : logf(1 + expf(input));
 }
@@ -487,6 +491,61 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
 #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
 #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
 
+// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits
+// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float)
+static inline float ggml_ue4m3_to_fp32(uint8_t x) {
+    if (x == 0 || x == 0x7F) {
+        return 0.0f;
+    }
+    int   exp = (x >> 3) & 0xF;
+    int   man = x & 0x7;
+    float raw;
+    if (exp == 0) {
+        raw = ldexpf((float) man, -9);
+    } else {
+        raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
+    }
+    return raw * 0.5f;
+}
+
+static inline uint8_t ggml_fp32_to_ue4m3(float x) {
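+    // non-positive and NaN inputs encode to 0; values above 448 are clamped before encoding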
+    if (!(x > 0.0f)) {
+        return 0;
+    }
+    if (x > 448.0f) {
+        x = 448.0f;
+    }
+    uint32_t bits;
+    memcpy(&bits, &x, 4);
+    int fp32_exp  = ((bits >> 23) & 0xFF) - 127;
+    int fp32_man  = (bits >> 20) & 0x7;
+    int ue4m3_exp = fp32_exp + 7;
+    if (ue4m3_exp <= 0) {
+        // subnormal: value = man * 2^-9, man = round(x * 2^9)
+        int man = (int) (x * 512.0f + 0.5f);
+        if (man > 7) {
+            man = 7;
+        }
+        if (man < 1) {
+            return 0;
+        }
+        return (uint8_t) man;
+    }
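+    // overflow: clamp to 0x7E, the largest regular code (0x7F is treated specially and decodes to 0)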
+    if (ue4m3_exp >= 15) {
+        return 0x7E;
+    }
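+    // round to nearest: bit 19 is the first fp32 mantissa bit dropped when keeping only 3 mantissa bits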
+    int round_bit = (bits >> 19) & 1;
+    int ue4m3_man = fp32_man + round_bit;
+    if (ue4m3_man > 7) {
+        ue4m3_man = 0;
+        ue4m3_exp++;
+        if (ue4m3_exp >= 15) {
+            return 0x7E;
+        }
+    }
+    return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man);
+}
+
 /**
  * Converts brain16 to float32.
  *
@@ -611,6 +670,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in
         if (node->op != ops[i]) {
             return false;
         }
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            return false;
+        }
         if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
             return false;
         }
diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt
index 63418fe1..42054d84 100644
--- a/ggml/src/ggml-metal/CMakeLists.txt
+++ b/ggml/src/ggml-metal/CMakeLists.txt
@@ -23,11 +23,6 @@ if (GGML_METAL_NDEBUG)
     add_compile_definitions(GGML_METAL_NDEBUG)
 endif()
 
-# copy metal files to bin directory
-configure_file(../ggml-common.h  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h     COPYONLY)
-configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)
-configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
-
 set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
 if (GGML_METAL_EMBED_LIBRARY)
     enable_language(ASM)
@@ -37,12 +32,12 @@ if (GGML_METAL_EMBED_LIBRARY)
     set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
     set(METALLIB_IMPL   "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
 
-    file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
+    file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
 
     # merge ggml-common.h and ggml-metal.metal into a single file
-    set(METALLIB_EMBED_ASM        "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
-    set(METALLIB_SOURCE_EMBED     "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
-    set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
+    set(METALLIB_EMBED_ASM        "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
+    set(METALLIB_SOURCE_EMBED     "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+    set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
 
     add_custom_command(
         OUTPUT "${METALLIB_EMBED_ASM}"
@@ -62,6 +57,11 @@ if (GGML_METAL_EMBED_LIBRARY)
 
     target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
 else()
+    # copy metal files to bin directory
+    configure_file(../ggml-common.h  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h     COPYONLY)
+    configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)
+    configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
+
     if (GGML_METAL_SHADER_DEBUG)
         # custom command to do the following:
         #   xcrun -sdk macosx metal    -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
@@ -71,7 +71,7 @@ else()
         #       disabling fast math is needed in order to pass tests/test-backend-ops
         # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
         # note: unfortunately, we have to call it default.metallib instead of ggml.metallib
-        #       ref: https://github.com/ggerganov/whisper.cpp/issues/1720
+        #       ref: https://github.com/ggml-org/whisper.cpp/issues/1720
         # note: adding -g causes segmentation fault during compile
         #set(XC_FLAGS -fno-fast-math -fno-inline -g)
         set(XC_FLAGS -fno-fast-math -fno-inline)
diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp
index 95627d38..2eb9820b 100644
--- a/ggml/src/ggml-metal/ggml-metal-common.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-common.cpp
@@ -264,15 +264,26 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector ggml_metal_graph_optimize_reorder(const std::vector capture_scope;
@@ -71,6 +75,10 @@ struct ggml_metal {
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
     void *              abort_callback_data;
+
+    // error state - set when a command buffer fails during synchronize
+    // once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated
+    bool has_error;
 };
 
 ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
@@ -117,7 +125,11 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
         }
     }
 
-    //const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
+    res->ev_cpy = ggml_metal_device_event_init(dev);
+
+    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
+
+    snprintf(res->name, sizeof(res->name), "%s", props_dev->name);
 
     res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
@@ -146,10 +158,19 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
     GGML_LOG_INFO("%s: use concurrency    = %s\n", __func__, res->use_concurrency    ? "true" : "false");
     GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");
 
-    res->capture_next_compute = false;
+    res->capture_compute = 0;
     res->capture_started = false;
     res->capture_scope = nil;
 
+    {
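+        // GGML_METAL_CAPTURE_COMPUTE=N requests a GPU capture of the N-th graph compute (1-based)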
+        const char * val = getenv("GGML_METAL_CAPTURE_COMPUTE");
+        if (val) {
+            res->capture_compute = atoi(val);
+        }
+    }
+
+    res->has_error = false;
+
     res->gf = nil;
     res->encode_async = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
@@ -206,9 +227,15 @@ void ggml_metal_free(ggml_metal_t ctx) {
 
     dispatch_release(ctx->d_queue);
 
+    ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy);
+
     free(ctx);
 }
 
+const char * ggml_metal_get_name(ggml_metal_t ctx) {
+    return ctx->name;
+}
+
 void ggml_metal_synchronize(ggml_metal_t ctx) {
     // wait for any backend operations to finish
     if (ctx->cmd_buf_last) {
@@ -232,7 +259,8 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
                 if (status == MTLCommandBufferStatusError) {
                     GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                 }
-                GGML_ABORT("fatal error");
+                ctx->has_error = true;
+                return;
             }
         }
     }
@@ -248,7 +276,15 @@ void ggml_metal_synchronize(ggml_metal_t ctx) {
                 if (status == MTLCommandBufferStatusError) {
                     GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                 }
-                GGML_ABORT("fatal error");
+
+                // release this and all remaining command buffers before returning
+                for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) {
+                    [ctx->cmd_bufs_ext[j] release];
+                }
+                [ctx->cmd_bufs_ext removeAllObjects];
+
+                ctx->has_error = true;
+                return;
             }
 
             [cmd_buf release];
@@ -273,8 +309,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor,
         // wrap the source data into a Metal buffer
         id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
         id<MTLBuffer> buf_src = [device newBufferWithBytes:data
-                                                         length:size
-                                                        options:MTLResourceStorageModeShared];
+                                                    length:size
+                                                   options:MTLResourceStorageModeShared];
 
         GGML_ASSERT(buf_src);
 
@@ -316,9 +352,9 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
     @autoreleasepool {
         id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
         id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data
-                                                               length:size
-                                                              options:MTLResourceStorageModeShared
-                                                          deallocator:nil];
+                                                          length:size
+                                                         options:MTLResourceStorageModeShared
+                                                     deallocator:nil];
 
         GGML_ASSERT(buf_dst);
 
@@ -356,9 +392,57 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
     }
 }
 
+bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    @autoreleasepool {
+        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src);
+        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst);
+
+        if (bid_src.metal == nil || bid_dst.metal == nil) {
+            return false;
+        }
+
+        // queue the copy operation into the Metal context
+        // this will be queued at the end, after any currently ongoing GPU operations
+        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx_src->dev);
+        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
+        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
+
+        [encoder copyFromBuffer:bid_src.metal
+                   sourceOffset:bid_src.offs
+                       toBuffer:bid_dst.metal
+              destinationOffset:bid_dst.offs
+                           size:ggml_nbytes(src)];
+
+        [encoder endEncoding];
+
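+        // signal a shared event on the source queue and have the destination context wait on it,
+        // so work later submitted by ctx_dst observes the completed copy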
+        ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
+        ggml_metal_event_encode_signal(ev_cpy, cmd_buf);
+
+        [cmd_buf commit];
+
+        // do not wait here for completion
+        //[cmd_buf waitUntilCompleted];
+
+        // instead, remember a reference to the command buffer and wait for it later if needed
+        [ctx_src->cmd_bufs_ext addObject:cmd_buf];
+        ctx_src->cmd_buf_last = cmd_buf;
+
+        [cmd_buf retain];
+
+        ggml_metal_event_wait(ctx_dst, ev_cpy);
+
+        return true;
+    }
+}
+
 enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
+    if (ctx->has_error) {
+        GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__);
+        return GGML_STATUS_FAILED;
+    }
+
     // number of nodes encoded by the main thread (empirically determined)
-    const int n_main = 64;
+    const int n_main = MAX(64, 0.1*gf->n_nodes);
 
     // number of threads in addition to the main thread
     const int n_cb = ctx->n_cb;
@@ -381,9 +465,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
 
         ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
 
-        const bool use_capture = ctx->capture_next_compute;
+        if (ctx->capture_compute >= 0) {
+            ctx->capture_compute--;
+        }
+
+        const bool use_capture = ctx->capture_compute == 0;
         if (use_capture) {
-            ctx->capture_next_compute = false;
+            ctx->capture_compute = -1;
 
             // make sure all previous computations have finished before starting the capture
             if (ctx->cmd_buf_last) {
@@ -392,6 +480,10 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
             }
 
             if (!ctx->capture_started) {
+                NSString * path = [NSString stringWithFormat:@"/tmp/perf-metal-%d.gputrace", getpid()];
+
+                GGML_LOG_WARN("%s: capturing graph in %s\n", __func__, [path UTF8String]);
+
                 // create capture scope
                 id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
                 ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];
@@ -399,7 +491,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
                 MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                 descriptor.captureObject = ctx->capture_scope;
                 descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
-                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
+                descriptor.outputURL = [NSURL fileURLWithPath:path];
 
                 NSError * error = nil;
                 if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
@@ -462,7 +554,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
 
         // enter here only when capturing in order to wait for all computation to finish
         // otherwise, we leave the graph to compute asynchronously
-        if (!use_capture && ctx->capture_started) {
+        if (use_capture && ctx->capture_started) {
             // wait for completion and check status of each command buffer
             // needed to detect if the device ran out-of-memory for example (#1881)
             {
@@ -514,6 +606,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph *
 
             [ctx->capture_scope endScope];
             [[MTLCaptureManager sharedCaptureManager] stopCapture];
+
+            ctx->capture_started = false;
         }
     }
 
@@ -530,6 +624,42 @@ void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
     //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
 }
 
+void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) {
+    @autoreleasepool {
+        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
+        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
+
+        ggml_metal_event_encode_signal(ev, cmd_buf);
+
+        [cmd_buf commit];
+
+        [ctx->cmd_bufs_ext addObject:cmd_buf];
+        ctx->cmd_buf_last = cmd_buf;
+
+        [cmd_buf retain];
+    }
+}
+
+void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) {
+    @autoreleasepool {
+        id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
+        id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
+
+        ggml_metal_event_encode_wait(ev, cmd_buf);
+
+        [cmd_buf commit];
+
+        [ctx->cmd_bufs_ext addObject:cmd_buf];
+        ctx->cmd_buf_last = cmd_buf;
+
+        [cmd_buf retain];
+    }
+}
+
+ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) {
+    return ctx->ev_cpy;
+}
+
 void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
     if (ctx->n_cb != n_cb) {
         ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
@@ -570,7 +700,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
             idx_end,
             ctx->use_fusion,
             ctx->use_concurrency,
-            ctx->capture_next_compute,
+            ctx->capture_compute,
             ctx->debug_graph,
             ctx->debug_fusion);
 
@@ -605,5 +735,5 @@ bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
 }
 
 void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
-    ctx->capture_next_compute = true;
+    ctx->capture_compute = 1;
 }
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index b0734797..72ad876d 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -17,10 +17,12 @@ struct ggml_metal_device_deleter {
 
 typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;
 
-ggml_metal_device_t ggml_metal_device_get(void) {
-    static ggml_metal_device_ptr ctx { ggml_metal_device_init() };
+ggml_metal_device_t ggml_metal_device_get(int device) {
+    static std::vector<ggml_metal_device_ptr> devs;
 
-    return ctx.get();
+    devs.emplace_back(ggml_metal_device_init(device));
+
+    return devs.back().get();
 }
 
 struct ggml_metal_pipelines {
@@ -94,6 +96,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
+
+    const char * pool_str = "undefined";
+    switch (op_pool) {
+        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
+        case GGML_OP_POOL_MAX: pool_str = "max"; break;
+        default: GGML_ASSERT(false && "not implemented");
+    };
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
+    snprintf(name, sizeof(name), "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
@@ -149,6 +176,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_me
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    const int n = op->src[0]->ne[0];
+
+    snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_n=%d", base, n);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    res.nsg  = 1;
+    res.smem = 0;
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
     char base[256];
     char name[256];
@@ -165,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
 }
 
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
-    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-
     char base[256];
     char name[256];
 
-    const int64_t n = ggml_nelements(op);
+    int op_num = -1;
 
-    const char * op_str = "undefined";
     switch (op->op) {
-        case GGML_OP_SCALE:      op_str = "scale";      break;
-        case GGML_OP_FILL:       op_str = "fill";       break;
-        case GGML_OP_CLAMP:      op_str = "clamp";      break;
-        case GGML_OP_SQR:        op_str = "sqr";        break;
-        case GGML_OP_SQRT:       op_str = "sqrt";       break;
-        case GGML_OP_SIN:        op_str = "sin";        break;
-        case GGML_OP_COS:        op_str = "cos";        break;
-        case GGML_OP_LOG:        op_str = "log";        break;
-        case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
+        case GGML_OP_SCALE:      op_num = OP_UNARY_NUM_SCALE;      break;
+        case GGML_OP_FILL:       op_num = OP_UNARY_NUM_FILL;       break;
+        case GGML_OP_CLAMP:      op_num = OP_UNARY_NUM_CLAMP;      break;
+        case GGML_OP_SQR:        op_num = OP_UNARY_NUM_SQR;        break;
+        case GGML_OP_SQRT:       op_num = OP_UNARY_NUM_SQRT;       break;
+        case GGML_OP_SIN:        op_num = OP_UNARY_NUM_SIN;        break;
+        case GGML_OP_COS:        op_num = OP_UNARY_NUM_COS;        break;
+        case GGML_OP_LOG:        op_num = OP_UNARY_NUM_LOG;        break;
+        case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
-                case GGML_UNARY_OP_TANH:        op_str = "tanh";        break;
-                case GGML_UNARY_OP_RELU:        op_str = "relu";        break;
-                case GGML_UNARY_OP_SIGMOID:     op_str = "sigmoid";     break;
-                case GGML_UNARY_OP_GELU:        op_str = "gelu";        break;
-                case GGML_UNARY_OP_GELU_ERF:    op_str = "gelu_erf";    break;
-                case GGML_UNARY_OP_GELU_QUICK:  op_str = "gelu_quick";  break;
-                case GGML_UNARY_OP_SILU:        op_str = "silu";        break;
-                case GGML_UNARY_OP_ELU:         op_str = "elu";         break;
-                case GGML_UNARY_OP_NEG:         op_str = "neg";         break;
-                case GGML_UNARY_OP_ABS:         op_str = "abs";         break;
-                case GGML_UNARY_OP_SGN:         op_str = "sgn";         break;
-                case GGML_UNARY_OP_STEP:        op_str = "step";        break;
-                case GGML_UNARY_OP_HARDSWISH:   op_str = "hardswish";   break;
-                case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
-                case GGML_UNARY_OP_EXP:         op_str = "exp";         break;
-                case GGML_UNARY_OP_SOFTPLUS:    op_str = "softplus";    break;
-                case GGML_UNARY_OP_EXPM1:       op_str = "expm1";       break;
+                case GGML_UNARY_OP_TANH:        op_num = OP_UNARY_NUM_TANH;        break;
+                case GGML_UNARY_OP_RELU:        op_num = OP_UNARY_NUM_RELU;        break;
+                case GGML_UNARY_OP_SIGMOID:     op_num = OP_UNARY_NUM_SIGMOID;     break;
+                case GGML_UNARY_OP_GELU:        op_num = OP_UNARY_NUM_GELU;        break;
+                case GGML_UNARY_OP_GELU_ERF:    op_num = OP_UNARY_NUM_GELU_ERF;    break;
+                case GGML_UNARY_OP_GELU_QUICK:  op_num = OP_UNARY_NUM_GELU_QUICK;  break;
+                case GGML_UNARY_OP_SILU:        op_num = OP_UNARY_NUM_SILU;        break;
+                case GGML_UNARY_OP_ELU:         op_num = OP_UNARY_NUM_ELU;         break;
+                case GGML_UNARY_OP_NEG:         op_num = OP_UNARY_NUM_NEG;         break;
+                case GGML_UNARY_OP_ABS:         op_num = OP_UNARY_NUM_ABS;         break;
+                case GGML_UNARY_OP_SGN:         op_num = OP_UNARY_NUM_SGN;         break;
+                case GGML_UNARY_OP_STEP:        op_num = OP_UNARY_NUM_STEP;        break;
+                case GGML_UNARY_OP_HARDSWISH:   op_num = OP_UNARY_NUM_HARDSWISH;   break;
+                case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
+                case GGML_UNARY_OP_EXP:         op_num = OP_UNARY_NUM_EXP;         break;
+                case GGML_UNARY_OP_SOFTPLUS:    op_num = OP_UNARY_NUM_SOFTPLUS;    break;
+                case GGML_UNARY_OP_EXPM1:       op_num = OP_UNARY_NUM_EXPM1;       break;
                 default: GGML_ABORT("fatal error");
             } break;
         default: GGML_ABORT("fatal error");
     };
 
-    const char * suffix = "";
-    if (n % 4 == 0) {
-        suffix = "_4";
-    }
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+    const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
+
+    snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
+        ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
+    res.c4  = is_c4;
+    res.cnt = is_cnt;
+
     return res;
 }
 
@@ -273,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
 }
 
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
-    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 
     char base[256];
     char name[256];
 
-    const char * op_str = "undefined";
+    int op_num = -1;
+
     switch (op->op) {
-        case GGML_OP_SUM_ROWS:
-            op_str = "sum_rows"; break;
-        case GGML_OP_MEAN:
-            op_str = "mean"; break;
+        case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
+        case GGML_OP_MEAN:     op_num = OP_SUM_ROWS_NUM_MEAN;     break;
         default: GGML_ABORT("fatal error");
     };
 
-    snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+    snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d", base, op_num);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
     res.smem = 32*sizeof(float);
 
+    if (is_c4) {
+        res.smem *= 4;
+    }
+
+    res.c4  = is_c4;
+
     return res;
 }
 
@@ -507,6 +577,71 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    // v is src[2], dimensions: S_v = ne[0], H = ne[1]
+    const int ne20 = op->src[2]->ne[0]; // S_v
+    const int ne21 = op->src[2]->ne[1]; // H
+    const int ne30 = op->src[3]->ne[0]; // G
+
+    const int nsg = op->src[2]->ne[0]/32;
+
+    GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->ne[0] == ne20 * ne21);
+    GGML_ASSERT(ne20 % 32 == 0);
+
+    snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
+    snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
+        ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.nsg = nsg;
+
+    return res;
+}
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    const int nsg = 8;
+    const int n   = op->src[1]->ne[1];
+    const int k   = op->src[1]->ne[0];
+
+    snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
+        ggml_metal_cv_set_int16(cv, n,   FC_SOLVE_TRI + 1);
+        ggml_metal_cv_set_int16(cv, k,   FC_SOLVE_TRI + 2);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.nsg  = nsg;
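+    // shared memory: one row of n floats (padded to a multiple of 32) per simdgroup, 16-byte aligned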
+    res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
     char base[256];
     char name[256];
@@ -1315,34 +1450,80 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
     GGML_UNUSED(op);
 }
 
-ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
-        ggml_metal_library_t lib,
-        ggml_op op,
-        int32_t n_fuse,
-        bool row) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
     char base[256];
     char name[256];
 
-    const char * op_str = "undefined";
-    switch (op) {
-        case GGML_OP_ADD:   op_str = "add";   break;
-        case GGML_OP_SUB:   op_str = "sub";   break;
-        case GGML_OP_MUL:   op_str = "mul";   break;
-        case GGML_OP_DIV:   op_str = "div";   break;
+    int op_num = -1;
+
+    switch (op->op) {
+        case GGML_OP_ADD: op_num = 0; break;
+        case GGML_OP_SUB: op_num = 1; break;
+        case GGML_OP_MUL: op_num = 2; break;
+        case GGML_OP_DIV: op_num = 3; break;
         default: GGML_ABORT("fatal error");
     };
 
-    if (row) {
-        snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
-    } else {
-        snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
-    }
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t1_str = ggml_type_name(op->src[1]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
+
+    const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0];
+    const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
+
+    snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
+    snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
+        ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
+        ggml_metal_cv_set_bool (cv, is_rb,  FC_BIN + 2);
+        ggml_metal_cv_set_bool (cv, is_cb,  FC_BIN + 3);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.c4  = is_c4;
+    res.cnt = is_rb;
+
+    return res;
+}
+
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
+    char base[256];
+    char name[256];
+
+    int op_num = -1;
+
+    switch (op) {
+        case GGML_OP_ADD: op_num = 0; break;
+        case GGML_OP_SUB: op_num = 1; break;
+        case GGML_OP_MUL: op_num = 2; break;
+        case GGML_OP_DIV: op_num = 3; break;
+        default: GGML_ABORT("fatal error");
+    };
+
+    snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
+    snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
+        ggml_metal_cv_set_int16(cv, 1,      FC_BIN + 1);
+        ggml_metal_cv_set_bool (cv, false,  FC_BIN + 2);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
     return res;
@@ -1351,13 +1532,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_L2_NORM);
 
-    GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
-    GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
-
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_l2_norm_f32");
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
+
+    snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@@ -1365,6 +1548,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
         res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
+    res.c4   = is_c4;
     res.smem = 32*sizeof(float);
 
     return res;
@@ -1570,12 +1754,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
-    snprintf(name, 256, "%s", base);
+    const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
+    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
+
+    const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
+
+    if (mode == GGML_SCALE_MODE_BILINEAR) {
+        snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type));
+    } else if (mode == GGML_SCALE_MODE_BICUBIC) {
+        snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type));
+    } else {
+        snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type));
+    }
+    snprintf(name, 256, "%s_aa=%d", base, antialias);
 
     ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
     if (!res.pipeline) {
-        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
     }
 
     return res;
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 9c3b0014..fd2b3dde 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params {
     int nr1;
 
     size_t smem;
+
+    bool c4;
+    bool cnt;
 };
 
 int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
@@ -104,9 +107,11 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
 
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base              (ggml_metal_library_t lib, enum ggml_op op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy               (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag              (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -120,6 +125,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched  (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net   (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv            (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -131,7 +138,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);
-struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one           (ggml_metal_library_t lib, enum ggml_op op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
@@ -204,7 +212,9 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets);
 //
 
 struct ggml_metal_device_props {
+    int device;
     char name[128];
+    char desc[128];
 
     size_t max_buffer_size;
     size_t max_working_set_size;
@@ -223,11 +233,15 @@ struct ggml_metal_device_props {
     int op_offload_min_batch_size;
 };
 
-ggml_metal_device_t ggml_metal_device_init(void);
+typedef struct ggml_metal_event * ggml_metal_event_t;
+
+void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);
+void ggml_metal_event_encode_wait  (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);
+
+ggml_metal_device_t ggml_metal_device_init(int device);
 void ggml_metal_device_free(ggml_metal_device_t dev);
 
-// return a singleton that is automatically destroyed when the program exits
-ggml_metal_device_t ggml_metal_device_get(void);
+ggml_metal_device_t ggml_metal_device_get(int device);
 
 void * ggml_metal_device_get_obj  (ggml_metal_device_t dev); // id<MTLDevice>
 void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue>
@@ -239,6 +253,10 @@ void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset
 
 void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev);
 
+ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev);
+void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev);
+void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev);
+
 void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);
 bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index ff899a81..82101f47 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -24,9 +24,6 @@
 static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
 static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
 
-// virtual address for GPU memory allocations
-static atomic_uintptr_t g_addr_device = 0x000000400ULL;
-
 #if !GGML_METAL_EMBED_LIBRARY
 // Here to assist with NSBundle Path Hack
 @interface GGMLMetalClass : NSObject
@@ -349,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
 
     struct ggml_metal_pipeline_with_params res = {
         /*.pipeline =*/ nil,
+        /*.nsg      =*/ 0,
         /*.nr0      =*/ 0,
         /*.nr1      =*/ 0,
-        /*.nsg      =*/ 0,
         /*.smem     =*/ 0,
+        /*.c4       =*/ false,
+        /*.cnt      =*/ false,
     };
 
     res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
@@ -365,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
 struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
     struct ggml_metal_pipeline_with_params res = {
         /*.pipeline =*/ nil,
+        /*.nsg      =*/ 0,
         /*.nr0      =*/ 0,
         /*.nr1      =*/ 0,
-        /*.nsg      =*/ 0,
         /*.smem     =*/ 0,
+        /*.c4       =*/ false,
+        /*.cnt      =*/ false,
     };
 
     [lib->lock lock];
@@ -523,6 +524,9 @@ struct ggml_metal_device {
     ggml_metal_library_t library;
 
     struct ggml_metal_device_props props;
+
+    // virtual address for GPU memory allocations
+    atomic_uintptr_t addr_virt;
 };
 
 //
@@ -618,7 +622,7 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {
     free(rsets);
 }
 
-ggml_metal_device_t ggml_metal_device_init(void) {
+ggml_metal_device_t ggml_metal_device_init(int device) {
     ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device));
 
     assert(dev != NULL);
@@ -632,6 +636,9 @@ ggml_metal_device_t ggml_metal_device_init(void) {
                 GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
             }
 
+            dev->addr_virt = 0x000000400ULL;
+
+            dev->props.device = device;
             dev->props.has_simdgroup_reduction  = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
             dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
@@ -785,10 +792,15 @@ ggml_metal_device_t ggml_metal_device_init(void) {
             dev->props.op_offload_min_batch_size  = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
 
             dev->props.max_buffer_size            = dev->mtl_device.maxBufferLength;
-            dev->props.max_working_set_size       = dev->mtl_device.recommendedMaxWorkingSetSize;
             dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;
+            if (@available(macOS 10.12, iOS 16.0, *)) {
+                dev->props.max_working_set_size   = dev->mtl_device.recommendedMaxWorkingSetSize;
+            } else {
+                dev->props.max_working_set_size   = dev->mtl_device.maxBufferLength;
+            }
 
-            strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1);
+            snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device);
+            snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]);
 
             dev->library = ggml_metal_library_init(dev);
             if (!dev->library) {
@@ -918,6 +930,59 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) {
     atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed);
 }
 
+struct ggml_metal_event {
+    void * obj; // id<MTLEvent>
+
+    atomic_int value;
+};
+
+void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) {
+    id<MTLEvent> event = (id<MTLEvent>)ev->obj;
+
+    id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw;
+
+    [cmd_buf encodeSignalEvent:event value:atomic_fetch_add_explicit(&ev->value, 1, memory_order_relaxed) + 1];
+}
+
+void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) {
+    id<MTLEvent> event = (id<MTLEvent>)ev->obj;
+
+    id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw;
+
+    [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)];
+}
+
+ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) {
+    id<MTLEvent> event = [dev->mtl_device newEvent];
+
+    ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event));
+
+    ev->obj = (__bridge void *)event;
+    ev->value = 0;
+
+    return ev;
+}
+
+void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) {
+    id<MTLEvent> event = ev->obj;
+    [event release];
+
+    free(ev);
+
+    GGML_UNUSED(dev);
+}
+
+void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) {
+    @autoreleasepool {
+        id<MTLEvent> event = ev->obj;
+
+        id<MTLCommandBuffer> cmd_buf = [dev->mtl_queue commandBuffer];
+        [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)];
+        [cmd_buf commit];
+        [cmd_buf waitUntilCompleted];
+    }
+}
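+
+// illustrative usage sketch (comment only, not required by this patch); "cmd_buf_producer"
+// and "cmd_buf_consumer" are placeholder command buffers:
+//
+//   ggml_metal_event_t ev = ggml_metal_device_event_init(dev);
+//
+//   ggml_metal_event_encode_signal(ev, cmd_buf_producer); // bump ev->value and signal the new value
+//   ggml_metal_event_encode_wait  (ev, cmd_buf_consumer); // wait for the most recently signaled value
+//
+//   ggml_metal_device_event_synchronize(dev, ev);         // CPU-side wait via a helper command buffer
+//   ggml_metal_device_event_free(dev, ev);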
+
 void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) {
     if (@available(macOS 10.12, iOS 16.0, *)) {
         *total = dev->mtl_device.recommendedMaxWorkingSetSize;
@@ -946,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
     }
 
     switch (op->op) {
+        case GGML_OP_SCALE:
+        case GGML_OP_FILL:
+        case GGML_OP_CLAMP:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+        case GGML_OP_LOG:
+            return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
@@ -965,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_SOFTPLUS:
                 case GGML_UNARY_OP_EXPM1:
-                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                    return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
                 default:
                     return false;
             }
@@ -993,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_ADD_ID:
-            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
+            return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_REPEAT:
-        case GGML_OP_SCALE:
-        case GGML_OP_FILL:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -1005,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
                 op->src[1]->type == GGML_TYPE_F32 &&
                 op->type == GGML_TYPE_F32;
-        case GGML_OP_CLAMP:
-            return op->src[0]->type == GGML_TYPE_F32;
-        case GGML_OP_SQR:
-        case GGML_OP_SQRT:
-        case GGML_OP_SIN:
-        case GGML_OP_COS:
-        case GGML_OP_LOG:
-            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_TRI:
@@ -1022,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_L2_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+            return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_COUNT_EQUAL:
             return has_simdgroup_reduction &&
                 op->src[0]->type == GGML_TYPE_I32 &&
@@ -1044,10 +1107,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                    op->src[1]->type == GGML_TYPE_F32 &&
                    op->type == GGML_TYPE_F32 &&
                    (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
-        case GGML_OP_POOL_1D:
-            return false;
         case GGML_OP_UPSCALE:
-            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
+            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_POOL_1D:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
@@ -1078,12 +1141,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 op->src[0]->ne[0] != 112 &&
                 op->src[0]->ne[0] != 128 &&
                 op->src[0]->ne[0] != 192 &&
-                op->src[0]->ne[0] != 256) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 576) {
-                // DeepSeek sizes
-                // TODO: disabled for now, until optmized
+                op->src[0]->ne[0] != 256 &&
+                op->src[0]->ne[0] != 320 &&
+                op->src[0]->ne[0] != 576) {
                 return false;
             }
             if (op->src[1]->type != op->src[2]->type) {
@@ -1096,9 +1156,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true;
+        case GGML_OP_GATED_DELTA_NET:
+            return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0;
+        case GGML_OP_SOLVE_TRI:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return has_simdgroup_reduction;
+            return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4;
+        case GGML_OP_SET:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
@@ -1155,7 +1219,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 };
             }
         case GGML_OP_GET_ROWS:
-            return true;
+            return op->src[0]->type != GGML_TYPE_NVFP4;
         case GGML_OP_SET_ROWS:
             {
                 if (op->src[0]->type != GGML_TYPE_F32) {
@@ -1177,6 +1241,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                         return false;
                 };
             }
+        case GGML_OP_DIAG:
+            return true;
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
             return has_simdgroup_reduction;
@@ -1218,7 +1284,7 @@ struct ggml_metal_buffer {
     bool use_residency_sets;
 
     // optional MTLResidencySet
-    // note: cannot use explicity "id<MTLResidencySet>" here because it is not available on certain OSes
+    // note: cannot use explicitly "id<MTLResidencySet>" here because it is not available on certain OSes
     id rset;
 
     // pointers to global device
@@ -1344,8 +1410,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
         res->all_data = ggml_metal_host_malloc(size_aligned);
         res->is_shared = true;
     } else {
-        // use virtual address from g_addr_device counter
-        res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
+        // use virtual address
+        res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed);
         res->is_shared = false;
     }
     res->all_size = size_aligned;
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index d3b0e732..53437b23 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -35,7 +35,7 @@
 #define N_R0_Q4_K 2
 #define N_SG_Q4_K 2
 
-#define N_R0_Q5_K 2
+#define N_R0_Q5_K 1
 #define N_SG_Q5_K 2
 
 #define N_R0_Q6_K 2
@@ -78,15 +78,52 @@
 #define FC_MUL_MM                      700
 #define FC_ROPE                        800
 #define FC_SSM_CONV                    900
-#define FC_COUNT_EQUAL                 1000
+#define FC_SOLVE_TRI                   1000
+#define FC_COUNT_EQUAL                 1100
+#define FC_UNARY                       1200
+#define FC_BIN                         1300
+#define FC_SUM_ROWS                    1400
+#define FC_UPSCALE                     1500
+#define FC_GATED_DELTA_NET             1600
 
 // op-specific constants
-#define OP_FLASH_ATTN_EXT_NQPTG 8
+#define OP_FLASH_ATTN_EXT_NQPSG 8
 #define OP_FLASH_ATTN_EXT_NCPSG 64
 
-#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
+#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
 
+#define OP_UNARY_NUM_SCALE      10
+#define OP_UNARY_NUM_FILL       11
+#define OP_UNARY_NUM_CLAMP      12
+#define OP_UNARY_NUM_SQR        13
+#define OP_UNARY_NUM_SQRT       14
+#define OP_UNARY_NUM_SIN        15
+#define OP_UNARY_NUM_COS        16
+#define OP_UNARY_NUM_LOG        17
+#define OP_UNARY_NUM_LEAKY_RELU 18
+
+#define OP_UNARY_NUM_TANH        100
+#define OP_UNARY_NUM_RELU        101
+#define OP_UNARY_NUM_SIGMOID     102
+#define OP_UNARY_NUM_GELU        103
+#define OP_UNARY_NUM_GELU_ERF    104
+#define OP_UNARY_NUM_GELU_QUICK  105
+#define OP_UNARY_NUM_SILU        106
+#define OP_UNARY_NUM_ELU         107
+#define OP_UNARY_NUM_NEG         108
+#define OP_UNARY_NUM_ABS         109
+#define OP_UNARY_NUM_SGN         110
+#define OP_UNARY_NUM_STEP        111
+#define OP_UNARY_NUM_HARDSWISH   112
+#define OP_UNARY_NUM_HARDSIGMOID 113
+#define OP_UNARY_NUM_EXP         114
+#define OP_UNARY_NUM_SOFTPLUS    115
+#define OP_UNARY_NUM_EXPM1       116
+
+#define OP_SUM_ROWS_NUM_SUM_ROWS 10
+#define OP_SUM_ROWS_NUM_MEAN     11
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@@ -122,6 +159,31 @@ typedef struct {
     int32_t  dim;
 } ggml_metal_kargs_concat;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    slope;
+    float    scale;
+    float    bias;
+    float    val;
+    float    min;
+    float    max;
+} ggml_metal_kargs_unary;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -179,20 +241,6 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_repeat;
 
-typedef struct {
-    float scale;
-    float bias;
-} ggml_metal_kargs_scale;
-
-typedef struct {
-    float val;
-} ggml_metal_kargs_fill;
-
-typedef struct {
-    float min;
-    float max;
-} ggml_metal_kargs_clamp;
-
 typedef struct {
     int64_t  nk0;
     int64_t  ne00;
@@ -496,8 +544,21 @@ typedef struct {
 
 typedef struct {
     int32_t  ne00;
-    int32_t  ne00_4;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
     uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
     float    eps;
 } ggml_metal_kargs_l2_norm;
 
@@ -733,6 +794,71 @@ typedef struct {
     uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne20;
+    int32_t  ne21;
+    int32_t  ne22;
+    int32_t  ne23;
+    uint64_t nb20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ns02;
+    int32_t  ns12;
+    int32_t  ns22;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_gated_delta_net;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
 typedef struct {
     int32_t  ne00t;
     int32_t  ne00;
@@ -764,6 +890,25 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_set_rows;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_diag;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
@@ -785,6 +930,7 @@ typedef struct {
     float    sf1;
     float    sf2;
     float    sf3;
+    float    poffs;
 } ggml_metal_kargs_upscale;
 
 typedef struct {
@@ -833,10 +979,6 @@ typedef struct {
     int      max_period;
 } ggml_metal_kargs_timestep_embedding;
 
-typedef struct {
-    float    slope;
-} ggml_metal_kargs_leaky_relu;
-
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -928,6 +1070,15 @@ typedef struct {
     int64_t  np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t  k0;
+    int32_t  s0;
+    int32_t  p0;
+    int64_t  IW;
+    int64_t  OW;
+    int64_t  np;
+} ggml_metal_kargs_pool_1d;
+
 typedef struct {
      int64_t ne00;
     uint64_t nb01;
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index a50b12b6..c0bcad39 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
         GGML_ABORT("unsupported op");
     }
 
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return 1;
+    }
+
     int n_fuse = 1;
 
     // check if the current node can run concurrently with other nodes before it
@@ -283,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
                 n_fuse = ggml_metal_op_acc(ctx, idx);
             } break;
         case GGML_OP_SCALE:
-            {
-                n_fuse = ggml_metal_op_scale(ctx, idx);
-            } break;
         case GGML_OP_FILL:
-            {
-                n_fuse = ggml_metal_op_fill(ctx, idx);
-            } break;
         case GGML_OP_CLAMP:
-            {
-                n_fuse = ggml_metal_op_clamp(ctx, idx);
-            } break;
+        case GGML_OP_LEAKY_RELU:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
@@ -337,6 +333,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_rwkv(ctx, idx);
             } break;
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
+            } break;
+        case GGML_OP_SOLVE_TRI:
+            {
+                n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 n_fuse = ggml_metal_op_mul_mat(ctx, idx);
@@ -353,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_set_rows(ctx, idx);
             } break;
+        case GGML_OP_DIAG:
+            {
+                n_fuse = ggml_metal_op_diag(ctx, idx);
+            } break;
         case GGML_OP_L2_NORM:
             {
                 n_fuse = ggml_metal_op_l2_norm(ctx, idx);
@@ -414,10 +422,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_top_k(ctx, idx);
             } break;
-        case GGML_OP_LEAKY_RELU:
-            {
-                n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
-            } break;
         case GGML_OP_TRI:
             {
                 n_fuse = ggml_metal_op_tri(ctx, idx);
@@ -426,12 +430,20 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
             } break;
+        case GGML_OP_SET:
+            {
+                n_fuse = ggml_metal_op_set(ctx, idx);
+            } break;
         case GGML_OP_DUP:
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
                 n_fuse = ggml_metal_op_cpy(ctx, idx);
             } break;
+        case GGML_OP_POOL_1D:
+            {
+                n_fuse = ggml_metal_op_pool_1d(ctx, idx);
+            } break;
         case GGML_OP_POOL_2D:
             {
                 n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -612,8 +624,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->type         == GGML_TYPE_F32);
 
-    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
 
     const size_t pnb1 = ((const int32_t *) op->op_params)[0];
     const size_t pnb2 = ((const int32_t *) op->op_params)[1];
@@ -623,7 +635,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
 
     if (!inplace) {
-        // run a separete kernel to cpy src->dst
+        // run a separate kernel to cpy src->dst
         // not sure how to avoid this
         // TODO: make a simpler cpy_bytes kernel
 
@@ -663,10 +675,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     }
 
     ggml_metal_kargs_bin args = {
-        /*.ne00 =*/ ne00,
-        /*.ne01 =*/ ne01,
-        /*.ne02 =*/ ne02,
-        /*.ne03 =*/ ne03,
+        /*.ne00 =*/ ne10,
+        /*.ne01 =*/ ne11,
+        /*.ne02 =*/ ne12,
+        /*.ne03 =*/ ne13,
         /*.nb00 =*/ nb00,
         /*.nb01 =*/ pnb1,
         /*.nb02 =*/ pnb2,
@@ -679,10 +691,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         /*.nb11 =*/ nb11,
         /*.nb12 =*/ nb12,
         /*.nb13 =*/ nb13,
-        /*.ne0  =*/ ne0,
-        /*.ne1  =*/ ne1,
-        /*.ne2  =*/ ne2,
-        /*.ne3  =*/ ne3,
+        /*.ne0  =*/ ne10,
+        /*.ne1  =*/ ne11,
+        /*.ne2  =*/ ne12,
+        /*.ne3  =*/ ne13,
         /*.nb0  =*/ nb0,
         /*.nb1  =*/ pnb1,
         /*.nb2  =*/ pnb2,
@@ -691,7 +703,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         /*.o1   =*/ { 0 },
     };
 
-    auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
+    auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -699,126 +711,19 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
 
-    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+    const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    int nth = 1;
+
+    while (2*nth < args.ne0 && nth < nth_max) {
+        nth *= 2;
+    }
 
     ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
 
     return 1;
 }
 
-int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    float scale;
-    float bias;
-    memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
-    memcpy(&bias,  ((const int32_t *) op->op_params) + 1, sizeof(float));
-
-    ggml_metal_kargs_scale args = {
-        /*.scale =*/ scale,
-        /*.bias  =*/ bias,
-    };
-
-    int64_t n = ggml_nelements(op);
-
-    if (n % 4 == 0) {
-        n /= 4;
-    }
-
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
-int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    const float val = ggml_get_op_params_f32(op, 0);
-
-    ggml_metal_kargs_fill args = {
-        /*.val =*/ val
-    };
-
-    int64_t n = ggml_nelements(op);
-
-    if (n % 4 == 0) {
-        n /= 4;
-    }
-
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
-int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    float min;
-    float max;
-    memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
-    memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
-
-    ggml_metal_kargs_clamp args = {
-        /*.min =*/ min,
-        /*.max =*/ max,
-    };
-
-    int64_t n = ggml_nelements(op);
-
-    if (n % 4 == 0) {
-        n /= 4;
-    }
-
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
 int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -830,19 +735,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    int64_t n = ggml_nelements(op);
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 
-    if (n % 4 == 0) {
-        n /= 4;
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_kargs_unary args = {
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.slope =*/ 0.0,
+        /*.scale =*/ 0.0,
+        /*.bias  =*/ 0.0,
+        /*.val   =*/ 0.0,
+        /*.min   =*/ 0.0,
+        /*.max   =*/ 0.0,
+    };
+
+    if (op->op == GGML_OP_LEAKY_RELU) {
+        args.slope = ggml_get_op_params_f32(op, 0);
+    }
+
+    if (op->op == GGML_OP_SCALE) {
+        args.scale = ggml_get_op_params_f32(op, 0);
+        args.bias  = ggml_get_op_params_f32(op, 1);
+    }
+
+    if (op->op == GGML_OP_FILL) {
+        args.val = ggml_get_op_params_f32(op, 0);
+    }
+
+    if (op->op == GGML_OP_CLAMP) {
+        args.min = ggml_get_op_params_f32(op, 0);
+        args.max = ggml_get_op_params_f32(op, 1);
     }
 
     auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
 
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         1);
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+    if (pipeline.cnt) {
+        const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+    } else {
+        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+        const int nth = MIN(args.ne00, nth_max);
+
+        const int nk0 = (args.ne00 + nth - 1)/nth;
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
+    }
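+
+    // illustrative numbers (not part of this patch): for the scalar (non-c4) variant with
+    // args.ne00 = 1000 and a pipeline that allows at least 256 threads per threadgroup,
+    // nth_max = 256, nth = 256 and nk0 = (1000 + 255)/256 = 4, so each row is covered by
+    // 4 threadgroups of 256 threads along the first grid dimension.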
 
     return 1;
 }
@@ -953,6 +918,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
     ggml_metal_kargs_sum_rows args = {
         /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
@@ -974,21 +944,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
 
     auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
 
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
+
     int nth = 32; // SIMD width
 
-    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+    while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00);
+    nth = std::min(nth, (int) args.ne00);
 
     const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
@@ -1247,6 +1222,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS(int32_t,  ne, op, ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+    ggml_metal_kargs_diag args = {
+        /*.ne00 =*/ne00,
+        /*.ne01 =*/ne01,
+        /*.ne02 =*/ne02,
+        /*.ne03 =*/ne03,
+        /*.nb00 =*/nb00,
+        /*.nb01 =*/nb01,
+        /*.nb02 =*/nb02,
+        /*.nb03 =*/nb03,
+        /*.ne0  =*/ne0,
+        /*.ne1  =*/ne1,
+        /*.ne2  =*/ne2,
+        /*.ne3  =*/ne3,
+        /*.nb0  =*/nb0,
+        /*.nb1  =*/nb1,
+        /*.nb2  =*/nb2,
+        /*.nb3  =*/nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -1549,6 +1566,266 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
+
+    int ida = 0;
+
+    ggml_metal_kargs_gated_delta_net args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne10 =*/ ne10,
+        /*.ne11 =*/ ne11,
+        /*.ne12 =*/ ne12,
+        /*.ne13 =*/ ne13,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.ne20 =*/ ne20,
+        /*.ne21 =*/ ne21,
+        /*.ne22 =*/ ne22,
+        /*.ne23 =*/ ne23,
+        /*.nb20 =*/ nb20,
+        /*.nb21 =*/ nb21,
+        /*.nb22 =*/ nb22,
+        /*.nb23 =*/ nb23,
+        /*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
+        /*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
+        /*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args),                  ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++); // dst
+
+    const int nsg = pipeline.nsg;
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
+
+    return 1;
+}
+
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_kargs_solve_tri args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne10 =*/ ne10,
+        /*.ne11 =*/ ne11,
+        /*.ne12 =*/ ne12,
+        /*.ne13 =*/ ne13,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
+
+    const int nsg = pipeline.nsg;
+
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
+
+    return 1;
+}
+
+int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    const size_t pnb1 = ((const int32_t *) op->op_params)[0];
+    const size_t pnb2 = ((const int32_t *) op->op_params)[1];
+    const size_t pnb3 = ((const int32_t *) op->op_params)[2];
+    const size_t offs = ((const int32_t *) op->op_params)[3];
+
+    const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
+
+    if (!inplace) {
+        // run a separate kernel to cpy src->dst
+        // not sure how to avoid this
+        // TODO: make a simpler cpy_bytes kernel
+
+        //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
+        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+
+        ggml_metal_kargs_cpy args = {
+            /*.nk0  =*/ ne00,
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.ne0  =*/ ne0,
+            /*.ne1  =*/ ne1,
+            /*.ne2  =*/ ne2,
+            /*.ne3  =*/ ne3,
+            /*.nb0  =*/ nb0,
+            /*.nb1  =*/ nb1,
+            /*.nb2  =*/ nb2,
+            /*.nb3  =*/ nb3,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline);
+        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+        const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+        ggml_metal_op_concurrency_reset(ctx);
+    }
+
+    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
+
+    GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
+
+    int64_t nk0 = ne10;
+    if (ggml_is_quantized(op->src[1]->type)) {
+        nk0 = ne10/16;
+    } else if (ggml_is_quantized(op->type)) {
+        nk0 = ne10/ggml_blck_size(op->type);
+    }
+
+    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    // when rows are small, we can batch them together in a single threadgroup
+    int nrptg = 1;
+
+    // TODO: relax this constraint in the future
+    if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth   = nk0;
+
+            if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+                nrptg--;
+            }
+        }
+    }
+
+    nth = std::min<int>(nth, nk0);
+
+    ggml_metal_kargs_cpy args = {
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne10,
+        /*.ne01 =*/ ne11,
+        /*.ne02 =*/ ne12,
+        /*.ne03 =*/ ne13,
+        /*.nb00 =*/ nb10,
+        /*.nb01 =*/ nb11,
+        /*.nb02 =*/ nb12,
+        /*.nb03 =*/ nb13,
+        /*.ne0  =*/ ne10,
+        /*.ne1  =*/ ne11,
+        /*.ne2  =*/ ne12,
+        /*.ne3  =*/ ne13,
+        /*.nb0  =*/ ggml_element_size(op),
+        /*.nb1  =*/ pnb1,
+        /*.nb2  =*/ pnb2,
+        /*.nb3  =*/ pnb3,
+    };
+
+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
+    bid_dst.offs += offs;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, bid_src1, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
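+
+    // illustrative numbers (not part of this patch): for nk0 = 4096 F32 elements per row and a
+    // 1024-thread pipeline, nth = 1024, nrptg = 1 and nw0 = (4096 + 1023)/1024 = 4, so the grid
+    // above is (4*ne11, ne12, ne13) threadgroups of 1024 threads, i.e. 4 threadgroups per destination row.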
+
+    return 1;
+}
+
 int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -1622,6 +1899,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const int32_t * opts = op->op_params;
+    ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+    const int32_t k0 = opts[1];
+    const int32_t s0 = opts[2];
+    const int32_t p0 = opts[3];
+
+    const int64_t IW = op->src[0]->ne[0];
+    const int64_t OW = op->ne[0];
+
+    const int64_t np = ggml_nelements(op);
+
+    ggml_metal_kargs_pool_1d args_pool_1d = {
+        /* .k0 = */  k0,
+        /* .s0 = */  s0,
+        /* .p0 = */  p0,
+        /* .IW = */  IW,
+        /* .OW = */  OW,
+        /* .np = */  np
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+    const int ntg = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args_pool_1d, sizeof(args_pool_1d),  0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
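+
+// for reference (not part of this patch): with the usual ggml pooling relation the output width is
+// OW = (IW + 2*p0 - k0)/s0 + 1, e.g. IW = 10, k0 = 2, s0 = 2, p0 = 0 gives OW = 5; np then counts
+// all output elements and is split into ntg threadgroups of nth threads each.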
+
 int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
@@ -1717,6 +2042,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
           (
            op->src[0]->type == GGML_TYPE_F32  || // TODO: helper function
            op->src[0]->type == GGML_TYPE_F16  ||
+           op->src[0]->type == GGML_TYPE_BF16 ||
            op->src[0]->type == GGML_TYPE_Q4_0 ||
            op->src[0]->type == GGML_TYPE_Q4_1 ||
            op->src[0]->type == GGML_TYPE_Q5_0 ||
@@ -1731,6 +2057,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
            op->src[0]->type == GGML_TYPE_Q4_K ||
            op->src[0]->type == GGML_TYPE_Q5_K ||
            op->src[0]->type == GGML_TYPE_Q6_K ||
+           op->src[0]->type == GGML_TYPE_Q2_K ||
+           op->src[0]->type == GGML_TYPE_Q3_K ||
            false) && (ne11 >= 4 && ne11 <= 8)
          )
         )
@@ -1759,7 +2087,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
         const int16_t r0ptg  = nypsg*nsg;         // num src0 rows per threadgroup
               int16_t r1ptg  = 4;                 // num src1 rows per threadgroup
 
-        // note: not sure how optimal are those across all different hardware. there might be someting cleverer
+        // note: not sure how optimal are those across all different hardware. there might be something cleverer
         switch (ne11) {
             case 2:
                 r1ptg = 2; break;
@@ -2239,7 +2567,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
     //    return res;
     //}
 
-    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
     const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
 
     const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
@@ -2355,7 +2683,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
-        const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+        const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
         const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
 
         GGML_ASSERT(nqptg <= 32);
@@ -2464,7 +2792,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
         // simdgroups per threadgroup (a.k.a. warps)
         //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
-        int32_t nsg = 4;
+        int32_t nsg = ne00 >= 512 ? 8 : 4;
 
         const size_t smem = FATTN_SMEM(nsg);
 
@@ -2522,9 +2850,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 #undef FATTN_SMEM
     } else {
         // half4x4 kernel
-        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
         const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
-        const int nkpsg = 1*ncpsg;
+        const int nhptg = 1;                           // heads per threadgroup
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg  % 1  == 0);
@@ -2576,6 +2904,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             ggml_metal_op_concurrency_reset(ctx);
         }
 
+        // note: for simplicity, assume that the K head size is greater than or equal to the V head size
+        GGML_ASSERT(ne10 >= ne20);
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
@@ -2583,28 +2914,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         // ne20*(nsg)
         // each simdgroup has a full f32 head vector in shared mem to accumulate results
         //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
-
-        int64_t nsgmax = 2;
-        while (true) {
-            const size_t smem = FATTN_SMEM(nsgmax);
-            // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
-            if (smem > props_dev->max_theadgroup_memory_size/2) {
-                break;
-            }
-            nsgmax *= 2;
-        }
-        nsgmax /= 2;
-
-        // simdgroups per threadgroup (a.k.a. warps)
-        //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
-        const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
+#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
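+// illustrative numbers (not part of this patch): for ne00 = ne20 = 128, ncpsg = 32 and nsg = 1 this
+// evaluates to (128 + 4*32 + 2*128)*1 = 512 half-sized elements, i.e. 1024 bytes of threadgroup
+// memory - well below the typical 32 KiB limit on Apple GPUs.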
 
         int64_t nsg = 1;
-        while (nsg <= nsgt) {
-            nsg *= 2;
-        }
-        nsg /= 2;
 
         // workgroups
         // each workgroup handles nsg*nkpsg cache values
@@ -2617,7 +2929,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         } else {
             nwg = 32;
             nsg = 1;
-            while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
+            while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
                 nsg *= 2;
             }
         }
@@ -2683,7 +2995,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
         } else {
             // sanity checks
             assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
@@ -2696,7 +3008,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
-            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
 
             // sync the 2 kernels
             ggml_metal_op_concurrency_reset(ctx);
@@ -2748,8 +3060,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
     GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
 
-    bool bcast_row = false;
-
     ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
     ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
     ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
@@ -2843,18 +3153,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
 
     struct ggml_metal_pipeline_with_params pipeline;
 
-    if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-        GGML_ASSERT(ggml_is_contiguous(op->src[0]));
-
-        // src1 is a row
-        GGML_ASSERT(ne11 == 1);
-
-        pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
-
-        bcast_row = true;
-    } else {
-        pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
-    }
+    pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
 
     if (n_fuse > 1) {
         bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
@@ -2868,20 +3167,26 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
         }
     }
 
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne10 = ne10/4;
+        args.ne0  = ne0/4;
+    }
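+
+    // note (not part of this patch): "c4" presumably selects a float4-vectorized kernel variant,
+    // so the per-row counts passed to it are divided by 4 and indexing stays per (vectorized) element.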
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
     ggml_metal_encoder_set_buffer  (enc, bid_src1, 2);
     ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);
 
-    if (bcast_row) {
-        const int64_t n = ggml_nelements(op)/4;
-
-        ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
+    if (pipeline.cnt) {
+        ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1);
     } else {
-        int nth = 32;
+        const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
-        while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        int nth = 1;
+
+        while (2*nth < args.ne0 && nth < nth_max) {
             nth *= 2;
         }
 
@@ -2902,39 +3207,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
 
-    int nth = 32; // SIMD width
-
     ggml_metal_kargs_l2_norm args = {
-        /*.ne00   =*/ ne00,
-        /*.ne00_4 =*/ ne00/4,
-        /*.nb01   =*/ nb01,
-        /*.eps    =*/ eps,
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.eps   =*/ eps,
     };
 
     auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
 
-    while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+    if (pipeline.c4) {
+        args.ne00 = ne00/4;
+        args.ne0  = ne0/4;
+    }
+
+    int nth = 32; // SIMD width
+
+    while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
     }
 
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
-    nth = std::min(nth, ne00/4);
 
     const size_t smem = pipeline.smem;
 
-    const int64_t nrows = ggml_nrows(op->src[0]);
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+    ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+    ggml_metal_encoder_set_buffer  (enc, bid_dst,  2);
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
 
     return 1;
 }
@@ -3484,32 +3809,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    const float sf0 = (float)ne0/op->src[0]->ne[0];
-    const float sf1 = (float)ne1/op->src[0]->ne[1];
-    const float sf2 = (float)ne2/op->src[0]->ne[2];
-    const float sf3 = (float)ne3/op->src[0]->ne[3];
+    float sf0 = (float)ne0/op->src[0]->ne[0];
+    float sf1 = (float)ne1/op->src[0]->ne[1];
+    float sf2 = (float)ne2/op->src[0]->ne[2];
+    float sf3 = (float)ne3/op->src[0]->ne[3];
+
+    const int32_t mode_flags = ggml_get_op_params_i32(op, 0);
+
+    float poffs = 0.5f;
+
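+    // align-corners sampling: make the first and last samples of src and dst coincide by using
+    // (ne - 1)/(ne_src - 1) scale factors and dropping the 0.5 sampling offset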
+    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+        poffs = 0.0f;
+        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
+        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
+    }
 
     ggml_metal_kargs_upscale args = {
-        /*.ne00 =*/ ne00,
-        /*.ne01 =*/ ne01,
-        /*.ne02 =*/ ne02,
-        /*.ne03 =*/ ne03,
-        /*.nb00 =*/ nb00,
-        /*.nb01 =*/ nb01,
-        /*.nb02 =*/ nb02,
-        /*.nb03 =*/ nb03,
-        /*.ne0 =*/ ne0,
-        /*.ne1 =*/ ne1,
-        /*.ne2 =*/ ne2,
-        /*.ne3 =*/ ne3,
-        /*.nb0 =*/ nb0,
-        /*.nb1 =*/ nb1,
-        /*.nb2 =*/ nb2,
-        /*.nb3 =*/ nb3,
-        /*.sf0 =*/ sf0,
-        /*.sf1 =*/ sf1,
-        /*.sf2 =*/ sf2,
-        /*.sf3 =*/ sf3
+        /*.ne00  =*/ ne00,
+        /*.ne01  =*/ ne01,
+        /*.ne02  =*/ ne02,
+        /*.ne03  =*/ ne03,
+        /*.nb00  =*/ nb00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne0   =*/ ne0,
+        /*.ne1   =*/ ne1,
+        /*.ne2   =*/ ne2,
+        /*.ne3   =*/ ne3,
+        /*.nb0   =*/ nb0,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
+        /*.sf0   =*/ sf0,
+        /*.sf1   =*/ sf1,
+        /*.sf2   =*/ sf2,
+        /*.sf3   =*/ sf3,
+        /*.poffs =*/ poffs,
     };
 
     auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
@@ -3942,42 +4278,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
-int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
-    ggml_tensor * op = ctx->node(idx);
-
-    ggml_metal_library_t lib = ctx->lib;
-    ggml_metal_encoder_t enc = ctx->enc;
-
-    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
-    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
-
-    float slope;
-    memcpy(&slope, op->op_params, sizeof(float));
-
-    ggml_metal_kargs_leaky_relu args = {
-        /*.slope =*/ slope
-    };
-
-    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
-
-    int64_t n = ggml_nelements(op);
-
-    if (n % 4 == 0) {
-        n /= 4;
-    }
-
-    ggml_metal_encoder_set_pipeline(enc, pipeline);
-    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
-    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
-
-    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
-
-    return 1;
-}
-
 int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index c1025d35..019f2fec 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 int ggml_metal_op_concat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_repeat            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_acc               (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_scale             (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_fill              (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_clamp             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_unary             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_sum               (ggml_metal_op_t ctx, int idx);
@@ -56,11 +53,16 @@ int ggml_metal_op_sum_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cumsum            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows          (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_diag              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_gated_delta_net   (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_set               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_pool_1d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat_id        (ggml_metal_op_t ctx, int idx);
@@ -83,7 +85,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_top_k             (ggml_metal_op_t ctx, int idx);
-int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_tri               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index 56b59f0a..9382ce53 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -7,11 +7,15 @@
 #include "ggml-metal-context.h"
 #include "ggml-metal-ops.h"
 
-// globals
+#include <cstdlib>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <vector>
 
-// initialized in ggml_backend_metal_reg
-static ggml_backend_reg    g_ggml_metal_reg;
-static ggml_backend_device g_ggml_metal_device;
+#define GGML_METAL_NAME "MTL"
+#define GGML_METAL_MAX_DEVICES 16
+
+// number of Metal devices
+// note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices
+static int g_devices = 1;
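+//
+// for example, a minimal host-side sketch (not part of this file; the value 2 is arbitrary and
+// ggml_backend_reg_dev_count() is the generic ggml-backend helper) of simulating two devices:
+//
+//   setenv("GGML_METAL_DEVICES", "2", 1);              // must be set before the registry is first used
+//   ggml_backend_reg_t reg = ggml_backend_metal_reg();
+//   const size_t n_dev = ggml_backend_reg_dev_count(reg);  // -> 2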
 
 ////////////////////////////////////////////////////////////////////////////////
 // backend interface
@@ -165,10 +169,28 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
     /* .reset           = */ NULL,
 };
 
+static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {
+    return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer ||
+           buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer;
+}
+
 //
 // buffer types
 //
 
+struct ggml_backend_metal_buffer_type {
+    int device;
+    std::string name;
+};
+
+struct ggml_backend_metal_buffer_type_deleter {
+    void operator()(ggml_backend_metal_buffer_type * ctx) const {
+        delete ctx;
+    }
+};
+
+typedef std::unique_ptr<ggml_backend_metal_buffer_type, ggml_backend_metal_buffer_type_deleter> ggml_backend_metal_buffer_type_ptr;
+
 // common method for allocating shared or private Metal buffers
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) {
     ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
@@ -218,9 +240,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
 // default (shared) buffer type
 
 static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
-    return "Metal";
+    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
 
-    GGML_UNUSED(buft);
+    return ctx->name.c_str();
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -249,29 +271,54 @@ static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_ty
     GGML_UNUSED(buft);
 }
 
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
-    static ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_metal_buffer_type_shared_get_name,
-            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_metal_buffer_type_shared_get_alignment,
-            /* .get_max_size     = */ ggml_backend_metal_buffer_type_shared_get_max_size,
-            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
-            /* .is_host          = */ ggml_backend_metal_buffer_type_shared_is_host,
-        },
-        /* .device  = */ &g_ggml_metal_device,
-        /* .context = */ NULL,
-    };
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
 
-    return &ggml_backend_buffer_type_metal;
+    static std::vector<ggml_backend_buffer_type>           bufts;
+    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;
+
+    static bool initialized = false;
+    if (!initialized) {
+        bufts.reserve(g_devices);
+        ctxs.reserve(g_devices);
+
+        for (int i = 0; i < g_devices; ++i) {
+            ggml_backend_metal_buffer_type * raw_ctx =
+                new ggml_backend_metal_buffer_type {
+                    /* .device = */ i,
+                    /* .name   = */ GGML_METAL_NAME + std::to_string(i),
+                };
+            ctxs.emplace_back(raw_ctx);
+
+            ggml_backend_buffer_type buft = {
+                /* .iface = */ {
+                    /* .get_name         = */ ggml_backend_metal_buffer_type_shared_get_name,
+                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
+                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_shared_get_alignment,
+                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_shared_get_max_size,
+                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
+                    /* .is_host          = */ ggml_backend_metal_buffer_type_shared_is_host,
+                },
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
+                /* .context = */ raw_ctx,
+            };
+
+            bufts.emplace_back(buft);
+        }
+
+        initialized = true;
+    }
+
+    return &bufts[device];
 }
 
 // default (private) buffer type
 
 static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) {
-    return "Metal_Private";
+    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
 
-    GGML_UNUSED(buft);
+    return ctx->name.c_str();
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -300,29 +347,53 @@ static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_t
     GGML_UNUSED(buft);
 }
 
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
-    static ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_metal_buffer_type_private_get_name,
-            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_metal_buffer_type_private_get_alignment,
-            /* .get_max_size     = */ ggml_backend_metal_buffer_type_private_get_max_size,
-            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
-            /* .is_host          = */ ggml_backend_metal_buffer_type_private_is_host,
-        },
-        /* .device  = */ &g_ggml_metal_device,
-        /* .context = */ NULL,
-    };
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
 
-    return &ggml_backend_buffer_type_metal;
+    static std::vector<ggml_backend_buffer_type>           bufts;
+    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;
+
+    static bool initialized = false;
+    if (!initialized) {
+        bufts.reserve(g_devices);
+        ctxs.reserve(g_devices);
+
+        for (int i = 0; i < g_devices; ++i) {
+            ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{
+                /* .device = */ i,
+                /* .name   = */ GGML_METAL_NAME + std::to_string(i) + "_Private"
+            };
+            ctxs.emplace_back(raw_ctx);
+
+            ggml_backend_buffer_type buft = {
+                /* .iface = */ {
+                    /* .get_name         = */ ggml_backend_metal_buffer_type_private_get_name,
+                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
+                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_private_get_alignment,
+                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_private_get_max_size,
+                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
+                    /* .is_host          = */ ggml_backend_metal_buffer_type_private_is_host,
+                },
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
+                /* .context = */ raw_ctx,
+            };
+
+            bufts.emplace_back(buft);
+        }
+
+        initialized = true;
+    }
+
+    return &bufts[device];
 }
 
 // mapped buffer type
 
 static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) {
-    return "Metal_Mapped";
+    ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
 
-    GGML_UNUSED(buft);
+    return ctx->name.c_str();
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -352,31 +423,55 @@ static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_ty
     GGML_UNUSED(buft);
 }
 
-static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
-    // note: not obvious, but this buffer type still needs to implement .alloc_buffer:
-    //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
-    static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = {
-        /* .iface = */ {
-            /* .get_name         = */ ggml_backend_metal_buffer_type_mapped_get_name,
-            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
-            /* .get_max_size     = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
-            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
-            /* .is_host          = */ ggml_backend_metal_buffer_type_mapped_is_host,
-        },
-        /* .device  = */ &g_ggml_metal_device,
-        /* .context = */ NULL,
-    };
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
 
-    return &ggml_backend_buffer_type_mapped_metal;
+    static std::vector<ggml_backend_buffer_type>           bufts;
+    static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;
+
+    static bool initialized = false;
+    if (!initialized) {
+        bufts.reserve(g_devices);
+        ctxs.reserve(g_devices);
+
+        for (int i = 0; i < g_devices; ++i) {
+            ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{
+                /* .device = */ i,
+                /* .name   = */ GGML_METAL_NAME + std::to_string(i) + "_Mapped"
+            };
+            ctxs.emplace_back(raw_ctx);
+
+            // note: not obvious, but this buffer type still needs to implement .alloc_buffer:
+            //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
+            ggml_backend_buffer_type buft = {
+                /* .iface = */ {
+                    /* .get_name         = */ ggml_backend_metal_buffer_type_mapped_get_name,
+                    /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
+                    /* .get_alignment    = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
+                    /* .get_max_size     = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
+                    /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
+                    /* .is_host          = */ ggml_backend_metal_buffer_type_mapped_is_host,
+                },
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
+                /* .context = */ raw_ctx,
+            };
+
+            bufts.emplace_back(buft);
+        }
+
+        initialized = true;
+    }
+
+    return &bufts[device];
 }
 
 // backend
 
 static const char * ggml_backend_metal_name(ggml_backend_t backend) {
-    return "Metal";
+    ggml_metal_t ctx = (ggml_metal_t)backend->context;
 
-    GGML_UNUSED(backend);
+    return ggml_metal_get_name(ctx);
 }
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
@@ -409,12 +504,24 @@ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const gg
 }
 
 static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
-    return false;
+    if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) {
+        return false;
+    }
 
-    GGML_UNUSED(backend_src);
-    GGML_UNUSED(backend_dst);
-    GGML_UNUSED(src);
-    GGML_UNUSED(dst);
+    if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) {
+        return false;
+    }
+
+    ggml_metal_t ctx_src = (ggml_metal_t)backend_src->context;
+    ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context;
+
+    //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
+    //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+    //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context;
+    //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context;
+
+    return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst);
 }
 
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
@@ -423,6 +530,20 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend,
     return ggml_metal_graph_compute(ctx, cgraph);
 }
 
+static void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_metal_t ctx = (ggml_metal_t)backend->context;
+    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
+
+    ggml_metal_event_record(ctx, ev);
+}
+
+static void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_metal_t ctx = (ggml_metal_t)backend->context;
+    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
+
+    ggml_metal_event_wait(ctx, ev);
+}
+
 static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_metal_t ctx = (ggml_metal_t)backend->context;
 
@@ -435,7 +556,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
     ggml_metal_t ctx = (ggml_metal_t)backend->context;
 
     ggml_metal_set_n_cb(ctx, n_cb);
-
 }
 
 static ggml_backend_i ggml_backend_metal_i = {
@@ -450,12 +570,8 @@ static ggml_backend_i ggml_backend_metal_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
-
-    // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal
-    // in any case, these docs seem relevant if we ever decide to implement it:
-    // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
+    /* .event_record            = */ ggml_backend_metal_event_record,
+    /* .event_wait              = */ ggml_backend_metal_event_wait,
     /* .graph_optimize          = */ ggml_backend_metal_graph_optimize,
 };
 
@@ -519,15 +635,17 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
 // backend device
 
 static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
-    return "Metal";
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
 
-    GGML_UNUSED(dev);
+    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
+
+    return props_dev->name;
 }
 
 static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
     ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
 
-    return ggml_metal_device_get_props(ctx_dev)->name;
+    return ggml_metal_device_get_props(ctx_dev)->desc;
 }
 
 static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
@@ -550,14 +668,14 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac
     ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
 
     props->caps = {
-        /* .async                 = */ true,
-        /* .host_buffer           = */ false,
-        /* .buffer_from_host_ptr  = */ true,
-        /* .events                = */ false,
+        /* .async                = */ true,
+        /* .host_buffer          = */ false,
+        /* .buffer_from_host_ptr = */ true,
+        /* .events               = */ true,
     };
 }
 
-static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
+static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) {
     ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
 
     ggml_metal_t ctx = ggml_metal_init(ctx_dev);
@@ -587,7 +705,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml
 
     const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
 
-    return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private();
+    return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device);
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
@@ -595,7 +713,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backen
 
     ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size);
 
-    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size);
+    const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
+
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size);
 }
 
 static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
@@ -606,9 +726,10 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
 
 static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
     return
+        buft->device == dev && (
         buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name ||
         buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name ||
-        buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name;
+        buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name);
 
     GGML_UNUSED(dev);
 }
@@ -632,45 +753,97 @@ static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const g
             get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size;
 }
 
+static ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) {
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
+
+    ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev);
+    GGML_ASSERT(event);
+
+    ggml_backend_event_t ev = new ggml_backend_event {
+        /* .device  = */ dev,
+        /* .context = */ event,
+    };
+
+    return ev;
+}
+
+static void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
+
+    ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
+
+    ggml_metal_device_event_free(ctx_dev, ev);
+
+    delete event;
+}
+
+static void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
+
+    ggml_metal_event_t evt = (ggml_metal_event_t)event->context;
+
+    ggml_metal_device_event_synchronize(ctx_dev, evt);
+}
+
 static ggml_backend_device_i ggml_backend_metal_device_i = {
     /* .get_name             = */ ggml_backend_metal_device_get_name,
     /* .get_description      = */ ggml_backend_metal_device_get_description,
     /* .get_memory           = */ ggml_backend_metal_device_get_memory,
     /* .get_type             = */ ggml_backend_metal_device_get_type,
     /* .get_props            = */ ggml_backend_metal_device_get_props,
-    /* .init_backend         = */ ggml_backend_metal_device_init,
+    /* .init_backend         = */ ggml_backend_metal_device_init_backend,
     /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,
     /* .get_host_buffer_type = */ NULL,
     /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped,
     /* .supports_op          = */ ggml_backend_metal_device_supports_op,
     /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,
     /* .offload_op           = */ ggml_backend_metal_device_offload_op,
-    /* .event_new            = */ NULL,
-    /* .event_free           = */ NULL,
-    /* .event_synchronize    = */ NULL,
+    /* .event_new            = */ ggml_backend_metal_device_event_new,
+    /* .event_free           = */ ggml_backend_metal_device_event_free,
+    /* .event_synchronize    = */ ggml_backend_metal_device_event_synchronize,
 };
 
 // backend registry
 
+struct ggml_backend_metal_reg {
+    std::vector<ggml_backend_dev_t> devices;
+};
+
+typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t;
+
+static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) {
+    ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg;
+
+    return ctx;
+}
+
+static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) {
+    delete ctx;
+}
+
+struct ggml_backend_metal_reg_deleter {
+    void operator()(ggml_backend_metal_reg_t ctx) {
+        ggml_backend_metal_reg_free(ctx);
+    }
+};
+
+typedef std::unique_ptr<ggml_backend_metal_reg, ggml_backend_metal_reg_deleter> ggml_backend_metal_reg_ptr;
+
 static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
-    return "Metal";
+    return GGML_METAL_NAME;
 
     GGML_UNUSED(reg);
 }
 
 static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
-    return 1;
-
-    GGML_UNUSED(reg);
+    ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;
+    return ctx->devices.size();
 }
 
 static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
-    GGML_ASSERT(index == 0);
-
-    return &g_ggml_metal_device;
-
-    GGML_UNUSED(reg);
-    GGML_UNUSED(index);
+    ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;
+    GGML_ASSERT(index < ctx->devices.size());
+    return ctx->devices[index];
 }
 
 static ggml_backend_feature g_ggml_backend_metal_features[] = {
@@ -698,27 +871,67 @@ static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const
 
 static ggml_backend_reg_i ggml_backend_metal_reg_i = {
     /* .get_name         = */ ggml_backend_metal_reg_get_name,
-    /* .device_count     = */ ggml_backend_metal_reg_device_count,
-    /* .device_get       = */ ggml_backend_metal_reg_device_get,
+    /* .get_device_count = */ ggml_backend_metal_reg_device_count,
+    /* .get_device       = */ ggml_backend_metal_reg_device_get,
     /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
 };
 
-ggml_backend_reg_t ggml_backend_metal_reg(void) {
-    {
-        g_ggml_metal_reg = {
-            /* .api_version = */ GGML_BACKEND_API_VERSION,
-            /* .iface       = */ ggml_backend_metal_reg_i,
-            /* .context     = */ NULL,
-        };
+static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) {
+    return new ggml_backend_device {
+        /* .iface   = */ ggml_backend_metal_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ ggml_metal_device_get(device),
+    };
+}
 
-        g_ggml_metal_device = {
-            /* .iface   = */ ggml_backend_metal_device_i,
-            /* .reg     = */ &g_ggml_metal_reg,
-            /* .context = */ ggml_metal_device_get(),
-        };
+static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) {
+    delete dev;
+}
+
+struct ggml_backend_device_deleter {
+    void operator()(ggml_backend_dev_t ctx) {
+        ggml_backend_metal_device_free(ctx);
+    }
+};
+
+typedef std::unique_ptr<ggml_backend_device, ggml_backend_device_deleter> ggml_backend_device_ptr;
+
+ggml_backend_reg_t ggml_backend_metal_reg(void) {
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
+    {
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+
+        const char * env = getenv("GGML_METAL_DEVICES");
+        if (env) {
+            g_devices = atoi(env);
+        }
+
+        static std::vector<ggml_backend_device_ptr> devs;
+
+        if (!initialized) {
+            static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init());
+
+            for (int i = 0; i < g_devices; ++i) {
+                auto * dev = ggml_backend_metal_device_init(&reg, i);
+                devs.emplace_back(dev);
+
+                reg_ctx->devices.push_back(dev);
+            }
+
+            reg = {
+                /* .api_version = */ GGML_BACKEND_API_VERSION,
+                /* .iface       = */ ggml_backend_metal_reg_i,
+                /* .context     = */ reg_ctx.get(),
+            };
+        }
+
+        initialized = true;
     }
 
-    return &g_ggml_metal_reg;
+    return &reg;
 }
 
 GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 16d17d26..b2328605 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
     return x*y;
 }
 
+static inline float sum(float x) {
+    return x;
+}
+
+static inline float sum(float4 x) {
+    return x[0] + x[1] + x[2] + x[3];
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -895,60 +903,218 @@ enum ggml_sort_order {
     GGML_SORT_ORDER_DESC,
 };
 
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-template <short F>
-kernel void kernel_add_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf  = 0.3275911f;
+constant float a1_erf = 0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf = 1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf = 1.061405429f;
+
+template<typename T>
+inline T erf_approx(T x) {
+    T sign_x = sign(x);
+    x = fabs(x);
+    T t = 1.0f / (1.0f + p_erf * x);
+    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    return sign_x * y;
+}
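+
+// for reference, the expression above is Horner's evaluation of (for x >= 0, extended by odd symmetry):
+//
+//   t      = 1/(1 + p_erf*x)
+//   erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)
+//
+// with a maximum absolute error of about 1.5e-7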
+
+template <typename T> T elu_approx(T x);
+
+template<> inline float elu_approx(float x) {
+    return (x > 0.f) ? x : (exp(x) - 1);
+}
+
+template<> inline float4 elu_approx(float4 x) {
+    float4 res;
+
+    res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
+    res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
+    res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
+    res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
+
+    return res;
+}
+
+constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
+constant bool  FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
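+
+// FC_unary_op : selects which unary operation the specialized pipeline applies (one of OP_UNARY_NUM_*)
+// FC_unary_cnt: contiguous fast path - the data is addressed flat by grid position instead of
+//               through the per-dimension byte strides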
+
+template <typename T0, typename T, typename TC>
+kernel void kernel_unary_impl(
+        constant ggml_metal_kargs_unary & args,
         device const char * src0,
-        device const char * src1,
         device       char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+#define FC_OP  FC_unary_op
+#define FC_CNT FC_unary_cnt
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+    device const T0 * src0_ptr;
+    device       T  * dst_ptr;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
-    device       float * dst_ptr  = (device       float *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+    int i0;
 
-    device const float * src1_ptr[F];
-    for (short j = 0; j < F; ++j) {
-        src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+    if (FC_CNT) {
+        i0 = tgpig.x;
+
+        src0_ptr = (device const T0 *) (src0);
+        dst_ptr  = (device       T  *) (dst);
+    } else {
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int k0  = tgpig.x/args.ne01;
+        const int i01 = tgpig.x - k0*args.ne01;
+
+        i0 = k0*ntg.x + tpitg.x;
+
+        src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+        dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1 );
     }
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
+    {
+        //threadgroup_barrier(mem_flags::mem_none);
 
-        float res = src0_ptr[i0];
-
-#pragma unroll
-        for (short j = 0; j < F; ++j) {
-            res += src1_ptr[j][i10];
+        if (!FC_CNT) {
+            if (i0 >= args.ne0) {
+                return;
+            }
         }
 
-        dst_ptr[i0] = res;
+        const TC x = (TC) src0_ptr[i0];
+
+        if (FC_OP == OP_UNARY_NUM_SCALE) {
+            dst_ptr[i0] = (T) (args.scale * x + args.bias);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_FILL) {
+            dst_ptr[i0] = (T) args.val;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_CLAMP) {
+            dst_ptr[i0] = (T) clamp(x, args.min, args.max);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SQR) {
+            dst_ptr[i0] = (T) (x * x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SQRT) {
+            dst_ptr[i0] = (T) sqrt(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SIN) {
+            dst_ptr[i0] = (T) sin(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_COS) {
+            dst_ptr[i0] = (T) cos(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_LOG) {
+            dst_ptr[i0] = (T) log(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
+            dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_TANH) {
+            dst_ptr[i0] = (T) precise::tanh(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_RELU) {
+            dst_ptr[i0] = (T) fmax(0, x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SIGMOID) {
+            dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU) {
+            dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
+            dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
+            dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SILU) {
+            dst_ptr[i0] = (T) (x / (1 + exp(-x)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_ELU) {
+            dst_ptr[i0] = (T) elu_approx(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_NEG) {
+            dst_ptr[i0] = (T) -x;
+        }
+
+        if (FC_OP == OP_UNARY_NUM_ABS) {
+            dst_ptr[i0] = (T) fabs(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SGN) {
+            dst_ptr[i0] = T(x > 0) - T(x < 0);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_STEP) {
+            dst_ptr[i0] = T(x > 0);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
+            dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
+            dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
+        }
+
+        if (FC_OP == OP_UNARY_NUM_EXP) {
+            dst_ptr[i0] = (T) exp(x);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
+            dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
+        }
+
+        if (FC_OP == OP_UNARY_NUM_EXPM1) {
+            // TODO: precise implementation
+            dst_ptr[i0] = (T) (exp(x) - 1);
+        }
     }
+
+#undef FC_OP
+#undef FC_CNT
 }
 
-typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
+typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
 
-template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
-template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
-template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
-template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
-template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>;
-template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>;
-template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
-template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
+template [[host_name("kernel_unary_f32_f32")]]   kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f16_f16")]]   kernel kernel_unary_t kernel_unary_impl;
+template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl;
 
-kernel void kernel_sub_fuse_1(
+// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
+constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
+constant short FC_bin_f  [[function_constant(FC_BIN + 1)]];
+constant bool  FC_bin_rb [[function_constant(FC_BIN + 2)]];
+constant bool  FC_bin_cb [[function_constant(FC_BIN + 3)]];
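+
+// FC_bin_f : number of fused src1 operands - the result is src0 combined with src1 at each of the
+//            byte offsets args.o1[0..F-1]
+// FC_bin_rb: row-broadcast fast path - src1 is a single contiguous row broadcast over all rows of src0
+// FC_bin_cb: broadcast along dim 0 - src1 is indexed modulo ne10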
+
+template <typename T0, typename T1, typename T>
+kernel void kernel_bin_fuse_impl(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -956,89 +1122,154 @@ kernel void kernel_sub_fuse_1(
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+#define FC_OP FC_bin_op
+#define FC_F  FC_bin_f
+#define FC_RB FC_bin_rb
+#define FC_CB FC_bin_cb
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+    if (FC_RB) {
+        // row broadcast
+        const uint i0 = tgpig.y*args.ne00 + tgpig.x;
+        const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x;
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+        device const T0 * src0_row = (device const T0 *) (src0);
+        device       T  * dst_row  = (device       T  *) (dst);
 
-    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-        const int i10 = i0%args.ne10;
-        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
-    }
-}
+        if (FC_F == 1) {
+            device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
 
-kernel void kernel_mul_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+            if (FC_OP == 0) {
+                dst_row[i0] = src0_row[i0] + src1_row[i1];
+            }
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
+            if (FC_OP == 1) {
+                dst_row[i0] = src0_row[i0] - src1_row[i1];
+            }
 
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+            if (FC_OP == 2) {
+                dst_row[i0] = src0_row[i0] * src1_row[i1];
+            }
 
-    if (args.ne10 == 1) {
-        const float x = *((device float *)(src1_ptr));
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
+            if (FC_OP == 3) {
+                dst_row[i0] = src0_row[i0] / src1_row[i1];
+            }
+        } else {
+            T0 res = src0_row[i0];
+
+            if (FC_OP == 0) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res += ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
+
+            if (FC_OP == 1) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
+
+            if (FC_OP == 2) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
+
+            if (FC_OP == 3) {
+                FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                    res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
+                }
+            }
+
+            dst_row[i0] = res;
         }
     } else {
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            const int i10 = i0%args.ne10;
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+        const int i03 = tgpig.z;
+        const int i02 = tgpig.y;
+        const int i01 = tgpig.x;
+
+        if (i01 >= args.ne01) {
+            return;
+        }
+
+        const int i13 = i03%args.ne13;
+        const int i12 = i02%args.ne12;
+        const int i11 = i01%args.ne11;
+
+        device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
+        device       T  * dst_ptr  = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs);
+
+        if (FC_F == 1) {
+            device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+
+            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+                const int i10 = FC_CB ? i0%args.ne10 : i0;
+
+                if (FC_OP == 0) {
+                    dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
+                }
+
+                if (FC_OP == 1) {
+                    dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
+                }
+
+                if (FC_OP == 2) {
+                    dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
+                }
+
+                if (FC_OP == 3) {
+                    dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
+                }
+            }
+        } else {
+            device const T1 * src1_ptr[8];
+            FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
+            }
+
+            for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+                const int i10 = FC_CB ? i0%args.ne10 : i0;
+
+                T res = src0_ptr[i0];
+
+                if (FC_OP == 0) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res += src1_ptr[j][i10];
+                    }
+                }
+
+                if (FC_OP == 1) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res -= src1_ptr[j][i10];
+                    }
+                }
+
+                if (FC_OP == 2) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res *= src1_ptr[j][i10];
+                    }
+                }
+
+                if (FC_OP == 3) {
+                    FOR_UNROLL (short j = 0; j < FC_F; ++j) {
+                        res /= src1_ptr[j][i10];
+                    }
+                }
+
+                dst_ptr[i0] = res;
+            }
         }
     }
+
+#undef FC_OP
+#undef FC_F
+#undef FC_RB
+#undef FC_CB
 }
 
-kernel void kernel_div_fuse_1(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig.z;
-    const int i02 = tgpig.y;
-    const int i01 = tgpig.x;
+typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
 
-    const int i13 = i03%args.ne13;
-    const int i12 = i02%args.ne12;
-    const int i11 = i01%args.ne11;
-
-    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
-    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0];
-    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
-
-    if (args.ne10 == 1) {
-        const float x = 1.0f / *((device float *)(src1_ptr));
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x;
-        }
-    } else {
-        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
-            const int i10 = i0%args.ne10;
-            *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
-        }
-    }
-}
+template [[host_name("kernel_bin_fuse_f32_f32_f32")]]   kernel kernel_bin_fuse_t kernel_bin_fuse_impl;
+template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl;
 
 kernel void kernel_add_id(
         constant ggml_metal_kargs_add_id & args,
@@ -1057,7 +1288,7 @@ kernel void kernel_add_id(
     const size_t nb1 = args.ne0 * sizeof(float);
     const size_t nb2 = args.ne1 * nb1;
 
-    device       float * dst_row  = (device       float *)((device char *)dst + i1*nb1 + i2*nb2);
+    device       float * dst_row  = (device       float *)((device char *)dst  +  i1*nb1       + i2*nb2);
     device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
     device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
 
@@ -1098,549 +1329,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat
 template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat;
 template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat;
 
-// assumption: src1 is a row
-// broadcast src1 into src0
-template <short F>
-kernel void kernel_add_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res += ((device const float4 *) (src1 + args.o1[j]))[i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
-
-template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
-template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
-template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
-template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
-template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>;
-template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>;
-template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>;
-template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>;
-
-template <short F>
-kernel void kernel_sub_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res -= src1_row[j][i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
-
-template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
-
-template <short F>
-kernel void kernel_mul_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res *= src1_row[j][i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
-
-template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
-
-template <short F>
-kernel void kernel_div_row_c4_fuse_impl(
-        constant ggml_metal_kargs_bin & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint tpig[[thread_position_in_grid]]) {
-
-    const uint nb = args.ne00/4;
-    const uint i  = tpig % nb;
-
-    device const float4 * src0_row = (device const float4 *) (src0);
-    device       float4 *  dst_row = (device       float4 *) (dst);
-
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
-
-    float4 res = src0_row[tpig];
-
-#pragma unroll(F)
-    for (short j = 0; j < F; ++j) {
-        res /= src1_row[j][i];
-    }
-
-    dst_row[tpig] = res;
-}
-
-typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
-
-template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
-
-kernel void kernel_scale_f32(
-        constant ggml_metal_kargs_scale & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_scale_f32_4(
-        constant ggml_metal_kargs_scale & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * args.scale + args.bias;
-}
-
-kernel void kernel_fill_f32(
-        constant ggml_metal_kargs_fill & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = args.val;
-}
-
-kernel void kernel_fill_f32_4(
-        constant ggml_metal_kargs_fill & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = args.val;
-}
-
-kernel void kernel_clamp_f32(
-        constant ggml_metal_kargs_clamp & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_clamp_f32_4(
-        constant ggml_metal_kargs_clamp & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = clamp(src0[tpig], args.min, args.max);
-}
-
-kernel void kernel_relu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_relu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = max(0.0f, src0[tpig]);
-}
-
-kernel void kernel_sigmoid_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_sigmoid_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
-}
-
-kernel void kernel_tanh_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-kernel void kernel_tanh_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = precise::tanh(src0[tpig]);
-}
-
-constant float GELU_COEF_A     = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
-constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
-
-kernel void kernel_gelu_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    // BEWARE !!!
-    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
-    // This was observed with Falcon 7B and 40B models
-    //
-    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
-kernel void kernel_gelu_quick_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-kernel void kernel_gelu_quick_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
-
-// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
-// ref: https://www.johndcook.com/blog/python_erf/
-constant float p_erf  = 0.3275911f;
-constant float a1_erf = 0.254829592f;
-constant float a2_erf = -0.284496736f;
-constant float a3_erf = 1.421413741f;
-constant float a4_erf = -1.453152027f;
-constant float a5_erf = 1.061405429f;
-
-template
-T erf_approx(T x) {
-    T sign_x = sign(x);
-    x = fabs(x);
-    T t = 1.0f / (1.0f + p_erf * x);
-    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
-    return sign_x * y;
-}
-
-kernel void kernel_gelu_erf_f32(
-    device const float * src0,
-    device       float * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV));
-}
-
-kernel void kernel_gelu_erf_f32_4(
-    device const float4 * src0,
-    device       float4 * dst,
-    uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-
-    dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV));
-}
-
-kernel void kernel_silu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_silu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = x / (1.0f + exp(-x));
-}
-
-kernel void kernel_elu_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
-}
-
-kernel void kernel_elu_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
-    dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
-    dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
-    dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
-}
-
-kernel void kernel_sqr_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqr_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src0[tpig];
-}
-
-kernel void kernel_sqrt_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sqrt_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sqrt(src0[tpig]);
-}
-
-kernel void kernel_sin_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_sin_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sin(src0[tpig]);
-}
-
-kernel void kernel_cos_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_cos_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = cos(src0[tpig]);
-}
-
-kernel void kernel_log_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = log(src0[tpig]);
-}
-
-kernel void kernel_log_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = log(src0[tpig]);
-}
-
-kernel void kernel_neg_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = -src0[tpig];
-}
-
-kernel void kernel_neg_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = -src0[tpig];
-}
-
-kernel void kernel_abs_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = fabs(src0[tpig]);
-}
-
-kernel void kernel_abs_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = fabs(src0[tpig]);
-}
-
-kernel void kernel_sgn_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sign(src0[tpig]);
-}
-
-kernel void kernel_sgn_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = sign(src0[tpig]);
-}
-
-kernel void kernel_step_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = step(0.0f, src0[tpig]);
-}
-
-kernel void kernel_step_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = step(0.0f, src0[tpig]);
-}
-
-kernel void kernel_hardswish_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardswish_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardsigmoid_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_hardsigmoid_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f));
-}
-
-kernel void kernel_exp_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]);
-}
-
-kernel void kernel_exp_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]);
-}
-
-kernel void kernel_softplus_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float & x = src0[tpig];
-    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
-}
-
-kernel void kernel_softplus_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    device const float4 & x = src0[tpig];
-    dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f);
-}
-
-kernel void kernel_expm1_f32(
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
-
-kernel void kernel_expm1_f32_4(
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = exp(src0[tpig]) - 1.0f;
-}
-
 kernel void kernel_reglu_f32(
         constant ggml_metal_kargs_glu & args,
         device const char * src0,
@@ -1824,33 +1512,35 @@ kernel void kernel_op_sum_f32(
     }
 }
 
-template 
-kernel void kernel_sum_rows(
+constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
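+// the reduction variant (plain sum vs. mean) is selected at pipeline-creation time via this
+// function constant, replacing the previous compile-time `bool norm` template parameter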
+
+template <typename T0, typename T>
+kernel void kernel_sum_rows_impl(
         constant ggml_metal_kargs_sum_rows & args,
-        device const float * src0,
-        device       float * dst,
-        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        device const char * src0,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
-    int64_t i3 = tgpig.z;
-    int64_t i2 = tgpig.y;
-    int64_t i1 = tgpig.x;
+#define FC_OP  FC_sum_rows_op
 
-    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
-        return;
-    }
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
 
     if (sgitg == 0) {
-        shmem_f32[tiisg] = 0.0f;
+        shmem_t[tiisg] = 0.0f;
     }
 
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
+    device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       T  * dst_row = (device       T  *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
-    float sumf = 0;
+    T0 sumf = T0(0.0f);
 
     for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
         sumf += src_row[i0];
@@ -1861,23 +1551,33 @@ kernel void kernel_sum_rows(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     if (tiisg == 0) {
-        shmem_f32[sgitg] = sumf;
+        shmem_t[sgitg] = sumf;
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    sumf = shmem_f32[tiisg];
+    sumf = shmem_t[tiisg];
     sumf = simd_sum(sumf);
 
     if (tpitg.x == 0) {
-        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+        if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
+            if (is_same<T0, float4>::value) {
+                dst_row[0] = sum(sumf) / (4*args.ne00);
+            } else {
+                dst_row[0] = sum(sumf) / args.ne00;
+            }
+        } else {
+            dst_row[0] = sum(sumf);
+        }
     }
+
+#undef FC_OP
 }
 
-typedef decltype(kernel_sum_rows) kernel_sum_rows_t;
+typedef decltype(kernel_sum_rows_impl) kernel_sum_rows_t;
 
-template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows;
-template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows;
+template [[host_name("kernel_sum_rows_f32_f32")]]   kernel kernel_sum_rows_t kernel_sum_rows_impl;
+template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl;
 
 template
 kernel void kernel_cumsum_blk(
@@ -2737,6 +2437,302 @@ kernel void kernel_rwkv_wkv7_f32(
     }
 }
 
+constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
+constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
+
+#if 1
+template <short NSG>
+kernel void kernel_gated_delta_net_impl(
+        constant ggml_metal_kargs_gated_delta_net & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * g,
+        device const char * b,
+        device const char * s,
+        device       char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]])  {
+#define S_v FC_gated_delta_net_ne20
+#define G   FC_gated_delta_net_ne30
+
+    const uint tx = tpitg.x;
+    const uint ty = tpitg.y;
+
+    const uint i23 = tgpig.z; // B
+    const uint i21 = tgpig.y; // H
+    const uint i20 = tgpig.x*NSG + ty;
+
+    const uint i01 = i21 % args.ne01;
+    const uint i11 = i21 % args.ne11;
+
+    const float scale = 1.0f / sqrt((float)S_v);
+
+    // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
+
+    float ls[NSG];
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        ls[j] = s_ptr[is];
+    }
+
+    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
+
+    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
+    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
+    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
+
+    device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
+    device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
+
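+    // per-timestep recurrence implemented by the loop below, with the state row S[i20][:]
+    // kept in registers (ls) and split across the simdgroup (NSG elements per thread):
+    //   S        <- S * exp(g)                  (gated decay; per-element gates when G > 1, i.e. KDA)
+    //   d        <- (v[i20] - dot(S[i20,:], k)) * b
+    //   S[i20,:] <- S[i20,:] + k*d              (delta-rule update)
+    //   y        <- dot(S[i20,:], q) / sqrt(S_v)
+    // the dot products are reduced across the simdgroup with simd_sum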
+    for (short t = 0; t < args.ne22; t++) {
+        float s_k = 0.0f;
+
+        if (G == 1) {
+            const float g_exp = exp(g_ptr[0]);
+
+            FOR_UNROLL (short j = 0; j < NSG; j++) {
+                const short is = tx*NSG + j;
+                ls[j] *= g_exp;
+
+                s_k += ls[j]*k_ptr[is];
+            }
+        } else {
+            // KDA
+            FOR_UNROLL (short j = 0; j < NSG; j++) {
+                const short is = tx*NSG + j;
+                ls[j] *= exp(g_ptr[is]);
+
+                s_k += ls[j]*k_ptr[is];
+            }
+        }
+
+        s_k = simd_sum(s_k);
+
+        const float d = (v_ptr[i20] - s_k)*b_ptr[0];
+
+        float y = 0.0f;
+
+        FOR_UNROLL (short j = 0; j < NSG; j++) {
+            const short is = tx*NSG + j;
+            ls[j] += k_ptr[is]*d;
+
+            y += ls[j]*q_ptr[is];
+        }
+
+        y = simd_sum(y);
+
+        if (tx == 0) {
+            dst_attn[t*args.ne21*S_v] = y*scale;
+        }
+
+        q_ptr += args.ns02;
+        k_ptr += args.ns12;
+        v_ptr += args.ns22;
+
+        b_ptr += args.ne21;
+        g_ptr += args.ne21*G;
+    }
+
+    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        dst_state[is] = ls[j];
+    }
+
+#undef S_v
+#undef G
+}
+
+typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
+
+template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
+template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
+template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
+
+#else
+// a simplified version of the above
+// no performance improvement, so keep the above version for now
+
+template <typename T, short NSG>
+kernel void kernel_gated_delta_net_impl(
+        constant ggml_metal_kargs_gated_delta_net & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * g,
+        device const char * b,
+        device const char * s,
+        device       char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]])  {
+#define S_v FC_gated_delta_net_ne20
+#define G   FC_gated_delta_net_ne30
+
+    const uint tx = tpitg.x;
+    const uint ty = tpitg.y;
+
+    const uint i23 = tgpig.z; // B
+    const uint i21 = tgpig.y; // H
+    const uint i20 = tgpig.x*NSG + ty;
+
+    const uint i01 = i21 % args.ne01;
+    const uint i11 = i21 % args.ne11;
+
+    const float scale = 1.0f / sqrt((float)S_v);
+
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    float lsf[NSG];
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        lsf[j] = s_ptr[is*S_v];
+    }
+
+    thread T * ls = (thread T *) (lsf);
+
+    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
+
+    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
+    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
+    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
+
+    device const float * b_ptr  = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
+    device const float * g_ptr  = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
+
+    for (short t = 0; t < args.ne22; t++) {
+        device const T * qt_ptr = (device const T *) (q_ptr);
+        device const T * kt_ptr = (device const T *) (k_ptr);
+        device const T * gt_ptr = (device const T *) (g_ptr);
+
+        if (G == 1) {
+            *ls *= exp(g_ptr[0]);
+        } else {
+            // KDA
+            *ls *= exp(gt_ptr[tx]);
+        }
+
+        const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
+
+        const float d = (v_ptr[i20] - s_k)*b_ptr[0];
+
+        *ls += kt_ptr[tx]*d;
+
+        const float y = simd_sum(dot(*ls, qt_ptr[tx]));
+
+        if (tx == 0) {
+            *dst_attn = y*scale;
+        }
+
+        q_ptr += args.ns02;
+        k_ptr += args.ns12;
+        v_ptr += args.ns22;
+
+        b_ptr += args.ne21;
+        g_ptr += args.ne21*G;
+
+        dst_attn += args.ne21*S_v;
+    }
+
+    device float * dst_state  = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+    device T     * dstt_state = (device T     *) (dst_state);
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        dst_state[is*S_v] = lsf[j];
+    }
+
+#undef S_v
+#undef G
+}
+
+typedef decltype(kernel_gated_delta_net_impl) kernel_gated_delta_net_t;
+
+template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl;
+template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl;
+template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl;
+#endif
+
+constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
+constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
+constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];
+
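+// forward substitution, assuming src0 holds a lower-triangular matrix L: solves L*X = B
+// column by column (one simdgroup per column of X); NSG rows of L are staged in shared
+// memory per iteration and X[r] = (B[r] - sum_{c<r} L[r][c]*X[c]) / L[r][r]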
+kernel void kernel_solve_tri_f32(
+        constant ggml_metal_kargs_solve_tri & args,
+        device   const char * src0,
+        device   const char * src1,
+        device         char * dst,
+        threadgroup    char * shmem [[threadgroup(0)]],
+        ushort3 tgpig[[threadgroup_position_in_grid]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const short NSG = FC_solve_tri_nsg;
+    const short N   = FC_solve_tri_n;
+    const short K   = FC_solve_tri_k;
+    const short NP  = PAD2(N, NW);
+
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+    const int32_t i01 = tgpig.x*NSG + sgitg;
+
+    threadgroup float * sh0 = (threadgroup float *) shmem;
+
+    device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
+    device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
+    device       float * dst_ptr  = (device       float *)(dst  + i02 * args.nb2  + i03 * args.nb3)  + i01;
+
+    for (short rr = 0; rr < N; rr += NSG) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        {
+            threadgroup float * sh0_cur = sh0 + sgitg*NP;
+
+            for (short t = 0; t*NW < N; ++t) {
+                const short idx = t*NW + tiisg;
+                sh0_cur[idx] = src0_ptr[idx];
+            }
+
+            src0_ptr += NSG*N;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (i01 >= args.ne10) {
+            continue;
+        }
+
+        for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
+            const short r = rr + ir;
+
+            threadgroup float * sh0_cur = sh0 + ir*NP;
+
+            float sum = 0.0f;
+
+            for (short t = 0; t*NW < r; ++t) {
+                const short idx = t*NW + tiisg;
+                sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
+            }
+
+            sum = simd_sum(sum);
+
+            if (tiisg == 0) {
+                const float diag = sh0_cur[r];
+
+                dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
+            }
+        }
+    }
+}
+
 kernel void kernel_argmax_f32(
         constant ggml_metal_kargs_argmax & args,
         device   const char * src0,
@@ -2970,26 +2966,32 @@ template [[host_name("kernel_rms_norm_f32_4")]]         kernel kernel_rms_norm_f
 template [[host_name("kernel_rms_norm_mul_f32_4")]]     kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl;
 template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl;
 
-kernel void kernel_l2_norm_f32(
+template <typename T0, typename T>
+kernel void kernel_l2_norm_impl(
         constant ggml_metal_kargs_l2_norm & args,
         device const char * src0,
         device       char * dst,
         threadgroup float * shmem_f32 [[threadgroup(0)]],
-        uint   tgpig[[threadgroup_position_in_grid]],
-        ushort tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort   ntg[[threads_per_threadgroup]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
     if (sgitg == 0) {
         shmem_f32[tiisg] = 0.0f;
     }
 
-    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+    device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T  * y = (device       T  *) (dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1);
 
     float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         sumf += dot(x[i00], x[i00]);
     }
     sumf = simd_sum(sumf);
@@ -3005,14 +3007,18 @@ kernel void kernel_l2_norm_f32(
     sumf = shmem_f32[tiisg];
     sumf = simd_sum(sumf);
 
-    const float scale = 1.0f/sqrt(max(sumf, args.eps));
+    const float scale = 1.0f/max(sqrt(sumf), args.eps);
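+    // note: eps now clamps the norm itself (sqrt of the sum of squares) rather than the squared sum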
 
-    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
-    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+    for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         y[i00] = x[i00] * scale;
     }
 }
 
+typedef decltype(kernel_l2_norm_impl) kernel_l2_norm_t;
+
+template [[host_name("kernel_l2_norm_f32_f32")]]   kernel kernel_l2_norm_t kernel_l2_norm_impl;
+template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl;
+
 kernel void kernel_group_norm_f32(
         constant ggml_metal_kargs_group_norm & args,
         device const float * src0,
@@ -3700,6 +3706,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]    kernel mul_mv_ext_q4
 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4,        4,  dequantize_f16_t4>;
 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4,        4,  dequantize_f16_t4>;
 
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4,      4,  dequantize_bf16_t4>;
+template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4,      4,  dequantize_bf16_t4>;
+template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4,      4,  dequantize_bf16_t4>;
+template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4,      4,  dequantize_bf16_t4>;
+#endif
+
 template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0,   32, dequantize_q4_0_t4>;
 template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0,   32, dequantize_q4_0_t4>;
 template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0,   32, dequantize_q4_0_t4>;
@@ -3750,6 +3763,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
 
+template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>;
+template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>;
+template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>;
+template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>;
+
+template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>;
+template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>;
+template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>;
+template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>;
+
 template
 void kernel_mul_mv_t_t_impl(
         args_t args,
@@ -4437,7 +4460,7 @@ kernel void kernel_im2col(
 template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col;
 template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col;
 
-// TODO: obolete -- remove
+// TODO: obsolete -- remove
 //typedef void (im2col_ext_t)(
 //        constant ggml_metal_kargs_im2col & args,
 //        device const float * x,
@@ -4749,7 +4772,9 @@ kernel void kernel_conv_transpose_2d(
     uint3   tpitg[[thread_position_in_threadgroup]],
     uint3     ntg[[threads_per_threadgroup]]);
 
-kernel void kernel_upscale_f32(
+constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]];
+
+kernel void kernel_upscale_nearest_f32(
     constant ggml_metal_kargs_upscale & args,
     device  const char * src0,
     device        char * dst,
@@ -4775,6 +4800,156 @@ kernel void kernel_upscale_f32(
     }
 }
 
+static inline float bilinear_tri(float x) {
+    return MAX(0.0f, 1.0f - fabs(x));
+}
+
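+// when FC_upscale_aa is set, each output pixel averages the source samples under a triangle
+// (tent) filter scaled to the resampling footprint (useful when downscaling); otherwise this
+// is plain bilinear interpolation between the 4 nearest source pixels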
+kernel void kernel_upscale_bilinear_f32(
+    constant ggml_metal_kargs_upscale & args,
+    device  const char * src0,
+    device        char * dst,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 / args.sf3;
+    const int64_t i02 = i2 / args.sf2;
+
+    const float   f01  = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
+    const int64_t i01  = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01)));
+    const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1));
+    const float   fd1  = MAX(0.0f, MIN(1.0f, f01 - (float)i01));
+
+    src0 += i03*args.nb03 + i02*args.nb02;
+
+    device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
+
+    if (FC_upscale_aa) {
+        const float support0  = MAX(1.0f, 1.0f / args.sf0);
+        const float invscale0 = 1.0f / support0;
+        const float support1  = MAX(1.0f, 1.0f / args.sf1);
+        const float invscale1 = 1.0f / support1;
+
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
+
+            int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs));
+            int64_t x_max = MIN(args.ne00,  (int64_t)ceil (f00 + support0 + args.poffs));
+
+            int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs));
+            int64_t y_max = MIN(args.ne01,  (int64_t)ceil (f01 + support1 + args.poffs));
+
+            float sum = 0.0f;
+            float wsum = 0.0f;
+
+            for (int64_t sy = y_min; sy < y_max; ++sy) {
+                const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1);
+                for (int64_t sx = x_min; sx < x_max; ++sx) {
+                    const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
+                    const float w  = wx * wy;
+                    device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
+                    sum  += (*src_ptr) * w;
+                    wsum += w;
+                }
+            }
+
+            const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f;
+            dst_ptr[i0] = v;
+        }
+    } else {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            const float   f00  = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
+            const int64_t i00  = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00)));
+            const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1));
+            const float   fd0  = MAX(0.0f, MIN(1.0f, f00 - (float)i00));
+
+            device const float * src00 = (device const float *)(src0 + i01*args.nb01  + i00*args.nb00);
+            device const float * src10 = (device const float *)(src0 + i01*args.nb01  + i00p*args.nb00);
+            device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00);
+            device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00);
+
+            const float v =
+                (*src00) * (1.0f - fd0) * (1.0f - fd1) +
+                (*src10) * fd0          * (1.0f - fd1) +
+                (*src01) * (1.0f - fd0) * fd1 +
+                (*src11) * fd0          * fd1;
+
+            dst_ptr[i0] = v;
+        }
+    }
+}
+
+static inline float bicubic_weight1(float x) {
+    const float a = -0.75f;
+    return ((a + 2) * x - (a + 3)) * x * x + 1;
+}
+
+static inline float bicubic_weight2(float x) {
+    const float a = -0.75f;
+    return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a;
+}
+
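+// cubic convolution weights with a = -0.75: bicubic_weight1 covers |x| <= 1 and
+// bicubic_weight2 covers 1 < |x| <= 2, matching how they are applied in the 4x4 loop below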
+kernel void kernel_upscale_bicubic_f32(
+    constant ggml_metal_kargs_upscale & args,
+    device  const char * src0,
+    device        char * dst,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 / args.sf3;
+    const int64_t i02 = i2 / args.sf2;
+
+    const float   f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs;
+    const int64_t i01 = (int64_t)floor(f01);
+    const float   fd1 = f01 - (float)i01;
+
+    const float w_y0 = bicubic_weight2(fd1 + 1.0f);
+    const float w_y1 = bicubic_weight1(fd1);
+    const float w_y2 = bicubic_weight1(1.0f - fd1);
+    const float w_y3 = bicubic_weight2(2.0f - fd1);
+
+    device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02;
+
+    device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1);
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const float   f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs;
+        const int64_t i00 = (int64_t)floor(f00);
+        const float   fd0 = f00 - (float)i00;
+
+        const float w_x0 = bicubic_weight2(fd0 + 1.0f);
+        const float w_x1 = bicubic_weight1(fd0);
+        const float w_x2 = bicubic_weight1(1.0f - fd0);
+        const float w_x3 = bicubic_weight2(2.0f - fd0);
+
+        float sum = 0.0f;
+
+        for (int dy = -1; dy <= 2; ++dy) {
+            const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy));
+            const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3;
+
+            for (int dx = -1; dx <= 2; ++dx) {
+                const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
+                const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;
+
+                device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
+                sum += (*src_ptr) * wx * wy;
+            }
+        }
+
+        dst_ptr[i0] = sum;
+    }
+}
+
 kernel void kernel_pad_f32(
     constant ggml_metal_kargs_pad & args,
     device  const char * src0,
@@ -5114,24 +5289,6 @@ kernel void kernel_argsort_merge_f32_i32(
 template [[host_name("kernel_argsort_merge_f32_i32_asc")]]  kernel argsort_merge_t kernel_argsort_merge_f32_i32;
 template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32;
 
-kernel void kernel_leaky_relu_f32(
-        constant     ggml_metal_kargs_leaky_relu & args,
-        device const float * src0,
-        device       float * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float x = src0[tpig];
-    dst[tpig] = x > 0.0f ? x : x * args.slope;
-}
-
-kernel void kernel_leaky_relu_f32_4(
-        constant     ggml_metal_kargs_leaky_relu & args,
-        device const float4 * src0,
-        device       float4 * dst,
-        uint tpig[[thread_position_in_grid]]) {
-    const float4 x = src0[tpig];
-    dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
-}
-
 constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
 
 constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
@@ -5208,6 +5365,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
 // scan the blocks of the mask that are not masked
 // 0 -     masked (i.e. full of -INF, skip)
 // 1 - not masked (i.e. at least one element of the mask is not -INF)
+// 2 - all zero
 kernel void kernel_flash_attn_ext_blk(
         constant ggml_metal_kargs_flash_attn_ext_blk & args,
         device const char * mask,
@@ -5229,27 +5387,29 @@ kernel void kernel_flash_attn_ext_blk(
 
     device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
 
-    // fast route
-    if (res == 0) {
-        if (simd_max(*mask_src) > -MAXHALF/2) {
-            res = 1;
-        }
-    }
-
     // detailed check of the elements of the block
     if ((C > NW || Q > 1) && res == 0) {
-        half m = -MAXHALF;
+        half mmin =  MAXHALF;
+        half mmax = -MAXHALF;
 
         FOR_UNROLL (short j = 0; j < Q; ++j) {
             FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
-                m = max(m, mask_src[ii*NW]);
+                mmin = min(mmin, mask_src[ii*NW]);
+                mmax = max(mmax, mask_src[ii*NW]);
             }
 
             mask_src += args.nb31/2;
         }
 
-        if (simd_max(m) > -MAXHALF/2) {
-            res = 1;
+        mmin = simd_min(mmin);
+        mmax = simd_max(mmax);
+
+        if (mmax > -MAXHALF) {
+            if (mmin == 0.0 && mmax == 0.0) {
+                res = 2;
+            } else {
+                res = 1;
+            }
         }
     }
 
@@ -5491,9 +5651,13 @@ void kernel_flash_attn_ext_impl(
                 ic = 0;
             }
 
+            char blk_cur = 1;
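+            // blk_cur: 0 - block is fully masked (skip it)
+            //          1 - regular block (load the mask into shared memory)
+            //          2 - mask block is all zeros (skip both the load and the bias addition below)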
+
             // read the mask into shared mem
             if (FC_flash_attn_ext_has_mask) {
-                if (blk[ic0] == 0) {
+                blk_cur = blk[ic0];
+
+                if (blk_cur == 0) {
                     FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
                         pm2[jj] += NW;
                     }
@@ -5501,16 +5665,22 @@ void kernel_flash_attn_ext_impl(
                     continue;
                 }
 
-                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
-                    const short j = jj*NSG + sgitg;
+                if (blk_cur == 1) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        const short j = jj*NSG + sgitg;
 
-                    if (FC_flash_attn_ext_bc_mask) {
-                        sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
-                    } else {
-                        sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                        if (FC_flash_attn_ext_bc_mask) {
+                            sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
+                        } else {
+                            sm2[j*SH + tiisg] = pm2[jj][tiisg];
+                        }
+
+                        pm2[jj] += NW;
+                    }
+                } else if (blk_cur == 2) {
+                    FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                        pm2[jj] += NW;
                     }
-
-                    pm2[jj] += NW;
                 }
 
 #if 0
@@ -5552,9 +5722,7 @@ void kernel_flash_attn_ext_impl(
 
                 constexpr short NC = (C/8)/NSG;
 
-                // note: do not unroll for large heads
-                #pragma unroll (DK <= 64 ? NC : 1)
-                for (short cc = 0; cc < NC; ++cc) {
+                FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                     qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f);
 
                     if (DK % 16 != 0) {
@@ -5575,7 +5743,9 @@ void kernel_flash_attn_ext_impl(
                         k8x8_t mk[2];
                         q8x8_t mq[2];
 
-                        FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
+                        // note: too much unroll can tank the performance for large heads
+                        #pragma unroll (MIN(DK8/2, 4*NSG))
+                        for (short i = 0; i < DK8/2; ++i) {
                             simdgroup_barrier(mem_flags::mem_none);
 
                             simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -5675,10 +5845,12 @@ void kernel_flash_attn_ext_impl(
                 }
 
                 // mqk = mqk + slope*mask
-                if (FC_flash_attn_ext_has_bias) {
-                    s2 += s2_t(sm2[j*SH + tiisg])*slope;
-                } else {
-                    s2 += s2_t(sm2[j*SH + tiisg]);
+                if (blk_cur != 2) {
+                    if (FC_flash_attn_ext_has_bias) {
+                        s2 += s2_t(sm2[j*SH + tiisg])*slope;
+                    } else {
+                        s2 += s2_t(sm2[j*SH + tiisg]);
+                    }
                 }
 
                 M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
@@ -5749,7 +5921,9 @@ void kernel_flash_attn_ext_impl(
                                 pv  += 8*NS20;
                             }
                         } else {
-                            FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
+                            constexpr short NC = (C/8)/2;
+
+                            FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                                 s8x8_t vs[2];
 
                                 simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -5929,7 +6103,7 @@ template<
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
     short DK,         // K head size
     short DV,         // V head size
-    short Q  = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
+    short Q  = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
     short C  = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
         constant ggml_metal_kargs_flash_attn_ext & args,
@@ -5952,6 +6126,7 @@ kernel void kernel_flash_attn_ext(
       //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break;
       //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break;
         case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break;
+        case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break;
     }
 #undef FWD_TMPL
 #undef FWD_ARGS
@@ -6001,6 +6176,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]]  kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32"  )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6015,6 +6191,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]]  kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #if defined(GGML_METAL_HAS_BF16)
@@ -6030,6 +6207,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 #endif
 
@@ -6045,6 +6223,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6059,6 +6238,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6073,6 +6253,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6087,6 +6268,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -6101,6 +6283,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #undef FA_TYPES
@@ -6138,11 +6321,10 @@ template<
     void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
     short DK,       // K head size
     short DV,       // V head size
-    short NE,       // head elements per thread
-    short Q,        // queries per threadgroup
-    short C,        // cache items per threadgroup
-    short NSG>      // number of simd groups
-void kernel_flash_attn_ext_vec_impl(
+    short NE = 4,   // head elements per thread
+    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPSG,  // queries per threadgroup
+    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
         constant ggml_metal_kargs_flash_attn_ext_vec & args,
         device const char * q,
         device const char * k,
@@ -6159,6 +6341,7 @@ void kernel_flash_attn_ext_vec_impl(
     static_assert(DV % 32 == 0, "DV must be divisible by 32");
 
 #define NWG  (FC_flash_attn_ext_vec_nwg)
+#define NSG  (FC_flash_attn_ext_vec_nsg)
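+// note: the number of simdgroups is now taken from a function constant instead of a template
+// parameter, so the runtime dispatch over NSG values is no longer needed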
 
 #define NS10 (FC_flash_attn_ext_vec_ns10)
 #define NS20 (FC_flash_attn_ext_vec_ns20)
@@ -6185,14 +6368,14 @@ void kernel_flash_attn_ext_vec_impl(
     static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
     static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
 
-    const short T = PK + NSG*SH; // shared memory size per query in (half)
+  //const short T = PK + NSG*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                    0*PK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                    0*PK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + Q*PK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + Q*PK); // same as above but in s4_t
-    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
-    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                      0*PK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                      0*PK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + NSG*PK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + NSG*PK); // same as above but in s4_t
+    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
+    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + NSG*PK + NSG*SH); // scratch buffer for the results
 
     // store the result for all queries in shared memory (the O matrix from the paper)
     so4 += tiisg;
@@ -6210,11 +6393,13 @@ void kernel_flash_attn_ext_vec_impl(
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q);
 
-    for (short i = tiisg; i < PK4; i += NW) {
-        if (iq1 < args.ne01 && i < DK4) {
-            sq4[i] = (q4_t) q4[i];
-        } else {
-            sq4[i] = (q4_t) 0.0f;
+    if (iq1 < args.ne01) {
+        for (short i = tiisg; i < PK4; i += NW) {
+            if (i < DK4) {
+                sq4[i] = (q4_t) q4[i];
+            } else {
+                sq4[i] = (q4_t) 0.0f;
+            }
         }
     }
 
@@ -6292,7 +6477,7 @@ void kernel_flash_attn_ext_vec_impl(
             }
 
             // skip -INF blocks
-            if (simd_max(sm[tiisg]) == -INFINITY) {
+            if (simd_max(sm[tiisg]) <= -MAXHALF) {
                 continue;
             }
 
@@ -6566,57 +6751,11 @@ void kernel_flash_attn_ext_vec_impl(
     }
 
 #undef NWG
+#undef NSG
 #undef NS10
 #undef NS20
 }
 
-template<
-    typename q4_t,  // query types in shared memory
-    typename k4_t,  // key types in shared memory
-    typename v4_t,  // value types in shared memory
-    typename qk_t,  // Q*K types
-    typename s_t,   // soft-max types
-    typename s4_t,
-    typename o4_t,  // attention accumulation types
-    typename kd4_t, // key type in device memory
-    short nl_k,
-    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
-    typename vd4_t, // value type in device memory
-    short nl_v,
-    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
-    short DK,       // K head size
-    short DV,       // V head size
-    short NE = 4,   // head elements per thread
-    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPTG,  // queries per threadgroup
-    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
-        constant ggml_metal_kargs_flash_attn_ext_vec & args,
-        device const char * q,
-        device const char * k,
-        device const char * v,
-        device const char * mask,
-        device const char * sinks,
-        device const char * pad,
-        device       char * dst,
-        threadgroup  half * shmem_f16 [[threadgroup(0)]],
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort  tiisg[[thread_index_in_simdgroup]],
-        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
-#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
-    switch (FC_flash_attn_ext_vec_nsg) {
-      // note: disabled cases to reduce library load time
-        case 1:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-        case 2:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-        case 4:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 8:  kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-      //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break;
-    }
-#undef FWD_TMPL
-#undef FWD_ARGS
-}
-
 // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
 //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
@@ -6715,6 +6854,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
 template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 #if defined(GGML_METAL_HAS_BF16)
@@ -8779,6 +8929,26 @@ kernel void kernel_set_rows_f(
     }
 }
 
+kernel void kernel_diag_f32(
+        constant ggml_metal_kargs_diag & args,
+        device   const char * src0,
+        device         char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    const int32_t i3 = tgpig.z;
+    const int32_t i2 = tgpig.y;
+    const int32_t i1 = tgpig.x;
+
+    device const float * src0_ptr = (device const float *)(src0 +                i2*args.nb02 + i3*args.nb03);
+    device       float * dst_ptr  = (device       float *)(dst  + i1*args.nb01 + i2*args.nb2  + i3*args.nb3);
+
+    for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
+        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
+    }
+}
+
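
For reference, the new kernel_diag_f32 writes, for every (i2, i3) slice, a square matrix whose diagonal is the corresponding row of src0 and whose remaining entries are zero (GGML_OP_DIAG). A minimal CPU sketch of that semantics, assuming a plain contiguous float layout rather than the nb* strides the kernel takes from its kargs:

```cpp
#include <cstddef>
#include <vector>

// CPU sketch of the behaviour of kernel_diag_f32:
// src is a vector of length ne0; dst becomes an ne0 x ne0 matrix with src on the diagonal.
static void diag_f32_ref(const std::vector<float> & src, std::vector<float> & dst, size_t ne0) {
    dst.assign(ne0 * ne0, 0.0f);
    for (size_t i1 = 0; i1 < ne0; ++i1) {
        for (size_t i0 = 0; i0 < ne0; ++i0) {
            dst[i1 * ne0 + i0] = (i0 == i1) ? src[i0] : 0.0f;
        }
    }
}
```
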
 constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
 constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
 
@@ -8797,7 +8967,9 @@ kernel void kernel_mul_mm(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -8920,8 +9092,8 @@ kernel void kernel_mul_mm(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+          //const short dx = sx;
+          //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
 
@@ -9153,6 +9325,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_
 template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
 template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
+template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
 
 template
 kernel void kernel_mul_mm_id(
@@ -9170,7 +9343,9 @@ kernel void kernel_mul_mm_id(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -9305,8 +9480,8 @@ kernel void kernel_mul_mm_id(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+          //const short dx = sx;
+          //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
 
@@ -9869,6 +10044,74 @@ kernel void kernel_pool_2d_avg_f32(
     o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
+
+kernel void kernel_pool_1d_max_f32(
+        constant        ggml_metal_kargs_pool_1d & args,
+        device  const   float * src,
+        device          float * dst,
+        uint            gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = -INFINITY;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        int j = base + ki;
+        if (j < 0 || j >= args.IW){
+            continue;
+        }
+        float v = src[src_off + j];
+        acc = max(acc, v);
+    }
+
+    dst[dst_off + ow] = acc;
+}
+
+kernel void kernel_pool_1d_avg_f32(
+        constant        ggml_metal_kargs_pool_1d & args,
+        device  const   float * src,
+        device          float * dst,
+        uint            gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = 0.0f;
+    int   cnt = 0;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        const int j = base + ki;
+        if (j < 0 || j >= args.IW) {
+            continue;
+        }
+        acc += src[src_off + j];
+        cnt += 1;
+    }
+
+    dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
+}
+
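
The two 1D pooling kernels above share the same indexing: each output element ow reads a window of k0 taps starting at ow*s0 - p0, skips taps that fall outside [0, IW), takes the maximum for max pooling, and for average pooling divides by the number of in-bounds taps (cnt), not by k0. A CPU sketch that mirrors exactly that indexing for a single row:

```cpp
#include <algorithm>
#include <cfloat>

// CPU sketch mirroring kernel_pool_1d_max_f32 / kernel_pool_1d_avg_f32:
// one row of IW inputs -> OW outputs, window size k0, stride s0, left padding p0.
static void pool_1d_row_ref(const float * src, float * dst_max, float * dst_avg,
                            int IW, int OW, int k0, int s0, int p0) {
    for (int ow = 0; ow < OW; ++ow) {
        const int base = ow * s0 - p0;

        float mx  = -FLT_MAX; // the kernel uses -INFINITY; -FLT_MAX is close enough for a sketch
        float sum = 0.0f;
        int   cnt = 0;

        for (int ki = 0; ki < k0; ++ki) {
            const int j = base + ki;
            if (j < 0 || j >= IW) {
                continue; // out-of-bounds taps are skipped, as in the kernels above
            }
            mx   = std::max(mx, src[j]);
            sum += src[j];
            cnt += 1;
        }

        dst_max[ow] = mx;
        dst_avg[ow] = cnt > 0 ? sum / (float)cnt : 0.0f; // average over in-bounds taps only
    }
}
```
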
 kernel void kernel_opt_step_adamw_f32(
         constant    ggml_metal_kargs_opt_step_adamw & args,
         device       float * x,
@@ -9919,7 +10162,7 @@ kernel void kernel_opt_step_sgd_f32(
 
 template
 kernel void kernel_memset(
-        constant ggml_metal_kargs_fill & args,
+        constant ggml_metal_kargs_memset & args,
         device T * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = args.val;
diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt
index d8fa5310..1f825093 100644
--- a/ggml/src/ggml-opencl/CMakeLists.txt
+++ b/ggml/src/ggml-opencl/CMakeLists.txt
@@ -57,11 +57,13 @@ set(GGML_OPENCL_KERNELS
     add
     add_id
     argsort
+    tri
     fill
     clamp
     cpy
     cvt
     diag_mask_inf
+    diag
     div
     gelu
     gemv_noshuffle_general
@@ -69,6 +71,7 @@ set(GGML_OPENCL_KERNELS
     get_rows
     glu
     group_norm
+    solve_tri
     im2col_f32
     im2col_f16
     mean
@@ -83,7 +86,11 @@ set(GGML_OPENCL_KERNELS
     mul_mv_q4_0_f32_8x_flat
     mul_mv_q4_0_f32_1d_8x_flat
     mul_mv_q4_0_f32_1d_16x_flat
-    mul_mv_q6_k
+    mul_mv_q4_1_f32
+    mul_mv_q4_1_f32_flat
+    mul_mv_q4_k_f32
+    mul_mv_q6_k_f32
+    mul_mv_q6_k_f32_flat
     mul_mv_q8_0_f32
     mul_mv_q8_0_f32_flat
     mul_mv_mxfp4_f32
@@ -97,10 +104,19 @@ set(GGML_OPENCL_KERNELS
     gemv_moe_mxfp4_f32
     mul_mm_f32_f32_l4_lm
     mul_mm_f16_f32_l4_lm
+    mul_mm_q4_0_f32_l4_lm
+    mul_mm_q4_1_f32_l4_lm
     mul_mm_q8_0_f32_l4_lm
+    mul_mm_q6_k_f32_l4_lm
+    mul_mm_q8_0_f32_8x4
+    gemv_noshuffle_q4_1_f32
+    gemm_noshuffle_q4_1_f32
+    gemv_noshuffle_general_q8_0_f32
     mul
+    neg
     norm
     relu
+    l2_norm
     rms_norm
     rope
     scale
@@ -116,11 +132,13 @@ set(GGML_OPENCL_KERNELS
     ssm_conv
     sub
     sum_rows
+    cumsum
     transpose
     concat
     tsembd
     upscale
     tanh
+    exp
     expm1
     softplus
     pad
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index d925f67f..e1dca6b4 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -226,7 +226,8 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) {
         return ADRENO_GPU_GEN::A7X;
     }
 
-    if (strstr(device_name, "830")) {
+    if (strstr(device_name, "830") ||
+        strstr(device_name, "840")) {
         return ADRENO_GPU_GEN::A8X;
     }
 
@@ -312,7 +313,7 @@ struct ProfilingInfo {
     cl_ulong cmd_duration_ns;
     // The time for the kernel to complete - COMPLETE - END
     cl_ulong cmd_complete_duration_ns;
-    // Total time to finish the kernel - COMPELTE - QUEUED
+    // Total time to finish the kernel - COMPLETE - QUEUED
     cl_ulong cmd_total_duration_ns;
     // Global and local work sizes.
     size_t global_size[3];
@@ -398,6 +399,7 @@ struct ggml_backend_opencl_context {
     int adreno_wave_size;
 
     cl_bool non_uniform_workgroups;
+    size_t  image_max_buffer_size;
 
     cl_context context;
     cl_command_queue queue;
@@ -407,10 +409,13 @@ struct ggml_backend_opencl_context {
     ggml_cl_buffer prealloc_scales_trans;
     ggml_cl_buffer prealloc_act_trans;
 
+    // prealloc buffers for src0 and src1
+    ggml_cl_buffer prealloc_src0;
+    ggml_cl_buffer prealloc_src1;
+
     cl_program program_add;
     cl_program program_add_id;
     cl_program program_clamp;
-    cl_program program_cpy;
     cl_program program_cvt;
     cl_program program_diag_mask_inf;
     cl_program program_gelu;
@@ -447,7 +452,6 @@ struct ggml_backend_opencl_context {
     cl_program program_rms_norm;
     cl_program program_group_norm;
     cl_program program_rope;
-    cl_program program_scale;
     cl_program program_silu;
     cl_program program_sigmoid;
     cl_program program_softmax_f32;
@@ -456,11 +460,8 @@ struct ggml_backend_opencl_context {
     cl_program program_softmax_4_f16;
     cl_program program_argsort_f32_i32;
     cl_program program_sum_rows_f32;
-    cl_program program_repeat;
     cl_program program_pad;
-    cl_program program_tanh;
     cl_program program_upscale;
-    cl_program program_concat;
     cl_program program_conv_2d_f16;
     cl_program program_conv_2d_f32;
     cl_program program_conv_2d_f16_f32;
@@ -479,24 +480,27 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
     cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
     cl_kernel kernel_add_id;
-    cl_kernel kernel_scale;
+    cl_kernel kernel_scale_f32, kernel_scale_f32_4;
     cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
     cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
-    cl_kernel kernel_mean_f32;
+    cl_kernel kernel_mean_f32, kernel_mean_f32_4;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
     cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
     cl_kernel kernel_gelu_quick, kernel_gelu_quick_4;
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
+    cl_kernel kernel_tri;
     cl_kernel kernel_fill;
     cl_kernel kernel_clamp;
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm, kernel_norm_mul_add;
     cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
+    cl_kernel kernel_l2_norm_f32;
     cl_kernel kernel_group_norm, kernel_group_norm_mul_add;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
+    cl_kernel kernel_diag_f32;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
     cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
     std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
@@ -511,7 +515,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32;
     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
     cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
-    cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
+    cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_i32_i32;
     cl_kernel kernel_mul_mat_f32_f32;
     cl_kernel kernel_mul_mat_f16_f16;
     cl_kernel kernel_mul_mat_f16_f32_1row;
@@ -522,30 +526,43 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mm_f16_f32_kq;
     cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
     cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
+    cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
     cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
-    cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
+    cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
     cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
     cl_kernel kernel_convert_block_q4_0_noshuffle;
     cl_kernel kernel_restore_block_q4_0_noshuffle;
+    cl_kernel kernel_convert_block_q4_1_noshuffle;
+    cl_kernel kernel_restore_block_q4_1_noshuffle;
+    cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
     cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
+    cl_kernel kernel_mul_mv_q4_1_f32;
+    cl_kernel kernel_mul_mv_q4_1_f32_flat;
+    cl_kernel kernel_mul_mv_q4_K_f32;
     cl_kernel kernel_mul_mv_q6_K_f32;
+    cl_kernel kernel_mul_mv_q6_K_f32_flat;
     cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
     cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat;
+    cl_kernel kernel_solve_tri_f32;
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
-    cl_kernel kernel_sum_rows_f32;
-    cl_kernel kernel_repeat;
+    cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4;
+    cl_kernel kernel_cumsum_blk, kernel_cumsum_add;
+    cl_kernel kernel_repeat_f32;
     cl_kernel kernel_pad;
-    cl_kernel kernel_tanh_f32_nd;
-    cl_kernel kernel_tanh_f16_nd;
-    cl_kernel kernel_expm1_f32_nd;
-    cl_kernel kernel_expm1_f16_nd;
-    cl_kernel kernel_softplus_f32_nd;
-    cl_kernel kernel_softplus_f16_nd;
+    cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
+    cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc;
+    cl_kernel kernel_neg_f32, kernel_neg_f32_4, kernel_neg_f32_nc;
+    cl_kernel kernel_neg_f16, kernel_neg_f16_4, kernel_neg_f16_nc;
+    cl_kernel kernel_exp_f32, kernel_exp_f32_4, kernel_exp_f32_nc;
+    cl_kernel kernel_exp_f16, kernel_exp_f16_4, kernel_exp_f16_nc;
+    cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc;
+    cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc;
+    cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc;
+    cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc;
     cl_kernel kernel_upscale;
     cl_kernel kernel_upscale_bilinear;
-    cl_kernel kernel_concat_f32_contiguous;
-    cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_concat_f32;
     cl_kernel kernel_conv_2d_f16;
     cl_kernel kernel_conv_2d_f32;
     cl_kernel kernel_conv_2d_f16_f32;
@@ -558,7 +575,10 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
     cl_kernel kernel_mul_mm_f32_f32_l4_lm;
     cl_kernel kernel_mul_mm_f16_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;
     cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
+    cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
 
     std::vector<ProfilingInfo> profiling_info;
 
@@ -671,7 +691,9 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_transpose_32;
     cl_kernel kernel_transpose_32_16;
     cl_kernel kernel_transpose_16;
+    cl_kernel kernel_transpose_8_buf;
     cl_kernel kernel_transpose_16_buf;
+    cl_kernel kernel_transpose_32_buf;
     cl_kernel kernel_transpose_16_4x1;
 
     // Gemm and Gemv related programs, kernels, etc
@@ -687,6 +709,10 @@ struct ggml_backend_opencl_context {
     cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;
     cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;
     cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;
+    cl_kernel kernel_gemv_noshuffle_q4_1_f32;
+    cl_kernel kernel_gemm_noshuffle_q4_1_f32;
+    cl_kernel kernel_mul_mm_q8_0_f32_8x4;
+    cl_kernel CL_mul_mat_vec_q8_0_f32;
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
     void free() {
@@ -792,6 +818,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // tri
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "tri.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("tri.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err));
+        GGML_LOG_CONT(".");
+
+        CL_CHECK(clReleaseProgram(prog));
+    }
+
     // fill
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -835,13 +879,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("cpy.cl");
 #endif
-        backend_ctx->program_cpy =
+        cl_program prog =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err));
-        CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err));
-        CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err));
-        CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(prog, "kernel_cpy_f16_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, "kernel_cpy_f16_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, "kernel_cpy_f32_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, "kernel_cpy_f32_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, "kernel_cpy_i32_i32", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -861,12 +906,19 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_q4_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q4_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
+        CL_CHECK((backend_ctx->kernel_convert_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_noshuffle", &err), err));
+        CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err));
+        CL_CHECK((backend_ctx->kernel_convert_block_q4_1  = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err));
+        CL_CHECK((backend_ctx->kernel_restore_block_q4_1  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
         CL_CHECK((backend_ctx->kernel_convert_block_q8_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
         CL_CHECK((backend_ctx->kernel_restore_block_q8_0  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
+        CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err));
+        CL_CHECK((backend_ctx->kernel_convert_block_q6_K  = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err));
+        CL_CHECK((backend_ctx->kernel_restore_block_q6_K  = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -887,6 +939,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // diag
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "diag.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("diag.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_diag_f32 = clCreateKernel(prog, "kernel_diag_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // gelu
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -952,6 +1021,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // solve_tri_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "solve_tri.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("solve_tri.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_solve_tri_f32 = clCreateKernel(prog, "kernel_solve_tri_f32", &err), err));
+        GGML_LOG_CONT(".");
+        CL_CHECK(clReleaseProgram(prog));
+    }
+
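
GGML_OP_SOLVE_TRI, which the new kernel_solve_tri_f32 backs, solves a triangular linear system. As a rough illustration of the arithmetic involved, here is plain forward substitution for a lower-triangular system A x = b; the op's actual conventions (which triangle, unit diagonal or not, batched right-hand sides) are defined by ggml and not reproduced here:

```cpp
#include <vector>

// Forward substitution for a lower-triangular A x = b, A is n x n, row-major.
// Illustrative only: the GGML_OP_SOLVE_TRI conventions are defined by the op itself.
static std::vector<float> solve_lower_tri_ref(const std::vector<float> & A,
                                              const std::vector<float> & b, int n) {
    std::vector<float> x(n);
    for (int i = 0; i < n; ++i) {
        float s = b[i];
        for (int j = 0; j < i; ++j) {
            s -= A[i * n + j] * x[j]; // subtract contributions of already-solved unknowns
        }
        x[i] = s / A[i * n + i];
    }
    return x;
}
```
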
     // im2col_f32
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1072,14 +1158,65 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
-    // mul_mv_q6_k
+    // mul_mv_q4_1_f32
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
         const std::string kernel_src {
-            #include "mul_mv_q6_k.cl.h"
+            #include "mul_mv_q4_1_f32.cl.h"
         };
 #else
-        const std::string kernel_src = read_file("mul_mv_q6_k.cl");
+        const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_1_f32_flat
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_1_f32_flat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_1_f32_flat.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_k_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_k_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q6_k_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q6_k_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q6_k_f32.cl");
 #endif
         backend_ctx->program_mul_mv_q6_K =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
@@ -1088,6 +1225,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mv_q6_k_f32_flat
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q6_k_f32_flat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q6_k_f32_flat.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q6_K_f32_flat", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mv_q8_0_f32
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1280,6 +1434,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_q4_0_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q4_0_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mm_q4_1_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q4_1_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mm_q8_0_f32_l4_lm
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1296,6 +1482,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // mul_mm_q6_k_f32_l4_lm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mm_q6_k_f32_l4_lm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     // mul_mm_f16_f32_kq_kqv
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1384,6 +1587,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // l2_norm
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "l2_norm.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("l2_norm.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_l2_norm_f32     = clCreateKernel(prog, "kernel_l2_norm_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
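
For context, GGML_OP_L2_NORM (served here by kernel_l2_norm_f32) normalizes each row by its L2 norm, with an eps to guard the zero-vector case. A minimal per-row sketch, assuming that convention:

```cpp
#include <algorithm>
#include <cmath>

// Per-row L2 normalization sketch: y = x / max(||x||_2, eps).
// Assumes this is the convention of GGML_OP_L2_NORM; eps only guards against division by zero.
static void l2_norm_row_ref(const float * x, float * y, int n, float eps) {
    float sumsq = 0.0f;
    for (int i = 0; i < n; ++i) {
        sumsq += x[i] * x[i];
    }
    const float scale = 1.0f / std::max(std::sqrt(sumsq), eps);
    for (int i = 0; i < n; ++i) {
        y[i] = x[i] * scale;
    }
}
```
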
     // rope
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1416,10 +1636,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("scale.cl");
 #endif
-        backend_ctx->program_scale =
+        cl_program prog =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err));
+        CL_CHECK((backend_ctx->kernel_scale_f32   = clCreateKernel(prog, "kernel_scale_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_scale_f32_4 = clCreateKernel(prog, "kernel_scale_f32_4", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
         GGML_LOG_CONT(".");
     }
 
@@ -1664,6 +1886,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
         CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err));
 
         CL_CHECK(clReleaseProgram(prog));
         GGML_LOG_CONT(".");
@@ -1701,9 +1924,28 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
         CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err));
         GGML_LOG_CONT(".");
     }
 
+    // cumsum
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "cumsum.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("cumsum.cl");
+#endif
+        cl_program prog;
+        prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_cumsum_blk = clCreateKernel(prog, "kernel_cumsum_blk", &err), err));
+        CL_CHECK((backend_ctx->kernel_cumsum_add = clCreateKernel(prog, "kernel_cumsum_add", &err), err));
+        GGML_LOG_CONT(".");
+        CL_CHECK(clReleaseProgram(prog));
+    }
+
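
Cumsum is loaded as two kernels, kernel_cumsum_blk and kernel_cumsum_add, which suggests the usual two-pass scan: an inclusive prefix sum within each work-group, followed by a pass that adds the running total of all preceding blocks. A CPU sketch of that pattern; the block size and the exact decomposition used by the device kernels are assumptions:

```cpp
#include <algorithm>
#include <vector>

// Two-pass prefix-sum sketch: per-block inclusive scan, then add per-block offsets.
// Mirrors the split into kernel_cumsum_blk + kernel_cumsum_add with an assumed block size.
static void cumsum_ref(const std::vector<float> & src, std::vector<float> & dst, int blk = 256) {
    const int n = (int) src.size();
    dst.resize(n);

    std::vector<float> block_sums;

    // pass 1: inclusive scan within each block, remembering each block's total
    for (int b0 = 0; b0 < n; b0 += blk) {
        float acc = 0.0f;
        for (int i = b0; i < std::min(b0 + blk, n); ++i) {
            acc += src[i];
            dst[i] = acc;
        }
        block_sums.push_back(acc);
    }

    // pass 2: add the sum of all preceding blocks to every element of each block
    float offset = 0.0f;
    for (size_t b = 0; b < block_sums.size(); ++b) {
        const int b0 = (int) b * blk;
        for (int i = b0; i < std::min(b0 + blk, n); ++i) {
            dst[i] += offset;
        }
        offset += block_sums[b];
    }
}
```
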
     // sigmoid
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -1747,16 +1989,11 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("repeat.cl");
 #endif
-        if (!kernel_src.empty()) {
-            backend_ctx->program_repeat =
-                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
-            CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err));
-            GGML_LOG_CONT(".");
-        } else {
-            GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n");
-            backend_ctx->program_repeat = nullptr;
-            backend_ctx->kernel_repeat = nullptr;
-        }
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_repeat_f32 = clCreateKernel(prog, "kernel_repeat_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
     }
 
     // pad
@@ -1789,18 +2026,58 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("tanh.cl");
 #endif
-        if (!kernel_src.empty()) {
-            backend_ctx->program_tanh =
-                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
-            CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err));
-            CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err));
-            GGML_LOG_CONT(".");
-        } else {
-            GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n");
-            backend_ctx->program_tanh = nullptr;
-            backend_ctx->kernel_tanh_f32_nd = nullptr;
-            backend_ctx->kernel_tanh_f16_nd = nullptr;
-        }
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_tanh_f32    = clCreateKernel(prog, "kernel_tanh_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f32_4  = clCreateKernel(prog, "kernel_tanh_f32_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f32_nc = clCreateKernel(prog, "kernel_tanh_f32_nc", &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f16    = clCreateKernel(prog, "kernel_tanh_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f16_4  = clCreateKernel(prog, "kernel_tanh_f16_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f16_nc = clCreateKernel(prog, "kernel_tanh_f16_nc", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // neg
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "neg.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("neg.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_neg_f32    = clCreateKernel(prog, "kernel_neg_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_neg_f32_4  = clCreateKernel(prog, "kernel_neg_f32_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_neg_f32_nc = clCreateKernel(prog, "kernel_neg_f32_nc", &err), err));
+        CL_CHECK((backend_ctx->kernel_neg_f16    = clCreateKernel(prog, "kernel_neg_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_neg_f16_4  = clCreateKernel(prog, "kernel_neg_f16_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_neg_f16_nc = clCreateKernel(prog, "kernel_neg_f16_nc", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // exp
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "exp.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("exp.cl");
+#endif
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_exp_f32    = clCreateKernel(prog, "kernel_exp_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_exp_f32_4  = clCreateKernel(prog, "kernel_exp_f32_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_exp_f32_nc = clCreateKernel(prog, "kernel_exp_f32_nc", &err), err));
+        CL_CHECK((backend_ctx->kernel_exp_f16    = clCreateKernel(prog, "kernel_exp_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_exp_f16_4  = clCreateKernel(prog, "kernel_exp_f16_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_exp_f16_nc = clCreateKernel(prog, "kernel_exp_f16_nc", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
     }
 
     // expm1
@@ -1812,20 +2089,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("expm1.cl");
 #endif
-        cl_program prog;
-        if (!kernel_src.empty()) {
-            prog =
-                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
-            CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err));
-            CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err));
-            GGML_LOG_CONT(".");
-        } else {
-            GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n");
-            prog = nullptr;
-            backend_ctx->kernel_expm1_f32_nd = nullptr;
-            backend_ctx->kernel_expm1_f16_nd = nullptr;
-        }
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_expm1_f32    = clCreateKernel(prog, "kernel_expm1_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_expm1_f32_4  = clCreateKernel(prog, "kernel_expm1_f32_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_expm1_f32_nc = clCreateKernel(prog, "kernel_expm1_f32_nc", &err), err));
+        CL_CHECK((backend_ctx->kernel_expm1_f16    = clCreateKernel(prog, "kernel_expm1_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_expm1_f16_4  = clCreateKernel(prog, "kernel_expm1_f16_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_expm1_f16_nc = clCreateKernel(prog, "kernel_expm1_f16_nc", &err), err));
         CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
     }
 
     // softplus
@@ -1837,20 +2110,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("softplus.cl");
 #endif
-        cl_program prog;
-        if (!kernel_src.empty()) {
-            prog =
-                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
-            CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err));
-            CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err));
-            GGML_LOG_CONT(".");
-        } else {
-            GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
-            prog = nullptr;
-            backend_ctx->kernel_softplus_f32_nd = nullptr;
-            backend_ctx->kernel_softplus_f16_nd = nullptr;
-        }
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_softplus_f32    = clCreateKernel(prog, "kernel_softplus_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_softplus_f32_4  = clCreateKernel(prog, "kernel_softplus_f32_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, "kernel_softplus_f32_nc", &err), err));
+        CL_CHECK((backend_ctx->kernel_softplus_f16    = clCreateKernel(prog, "kernel_softplus_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_softplus_f16_4  = clCreateKernel(prog, "kernel_softplus_f16_4", &err), err));
+        CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, "kernel_softplus_f16_nc", &err), err));
         CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
     }
 
     // upscale
@@ -1892,22 +2161,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
             #include "concat.cl.h"
         };
 #else
-
         const std::string kernel_src = read_file("concat.cl");
 #endif
-        if (!kernel_src.empty()) {
-            backend_ctx->program_concat =
-                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
-
-            CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err));
-            CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err));
-            GGML_LOG_CONT(".");
-        } else {
-            GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n");
-            backend_ctx->program_concat = nullptr;
-            backend_ctx->kernel_concat_f32_contiguous = nullptr;
-            backend_ctx->kernel_concat_f32_non_contiguous = nullptr;
-        }
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
     }
 
     // timestep_embedding
@@ -2107,7 +2367,9 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err));
         CL_CHECK((backend_ctx->kernel_transpose_32    = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err));
         CL_CHECK((backend_ctx->kernel_transpose_16    = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err));
+        CL_CHECK((backend_ctx->kernel_transpose_8_buf  = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_8_buf", &err), err));
         CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_buf", &err), err));
+        CL_CHECK((backend_ctx->kernel_transpose_32_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_buf", &err), err));
         CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err));
         GGML_LOG_CONT(".");
     }
@@ -2227,6 +2489,85 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // gemm_noshuffle_q4_1_f32
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemm_noshuffle_q4_1_f32.cl.h"
+       };
+#else
+        const std::string kernel_src = read_file("gemm_noshuffle_q4_1_f32.cl");
+#endif
+        cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_1_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // gemv_noshuffle_q4_1_f32
+    {
+        std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
+                                       " -cl-mad-enable ";
+        if (backend_ctx->has_vector_subgroup_broadcast) {
+            CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+        }
+
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "gemv_noshuffle_q4_1_f32.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("gemv_noshuffle_q4_1_f32.cl");
+#endif
+
+        cl_program prog = build_program_from_source(
+            backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_1_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mm_q8_0_f32_8x4
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src_q8_8x4_gemm {
+            #include "mul_mm_q8_0_f32_8x4.cl.h"
+       };
+#else
+        const std::string kernel_src_q8_8x4_gemm = read_file("mul_mm_q8_0_f32_8x4.cl");
+#endif
+        backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_q8_8x4_gemm.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mm_q8_0_f32_8x4", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // gemv_noshuffle_general_q8_0_f32
+    {
+        std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std +
+                                       " -cl-mad-enable "
+                                       " -DSIMDGROUP_WIDTH=" +
+                                       std::to_string(backend_ctx->adreno_wave_size);
+        if (backend_ctx->has_vector_subgroup_broadcast) {
+            CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+        }
+
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src_CL_gemv_general {
+            #include "gemv_noshuffle_general_q8_0_f32.cl.h"
+        };
+#else
+        const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general_q8_0_f32.cl");
+#endif
+
+        cl_program prog = build_program_from_source(
+            backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts);
+
+        CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
+    }
+
     std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
             " -cl-mad-enable "
             " -cl-fast-relaxed-math";
@@ -2315,7 +2656,7 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r
 
     cl_platform_id platform_ids[NPLAT];
     if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) {
-        GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n");
+        GGML_LOG_ERROR("ggml_opencl: platform IDs not available.\n");
         return found_devices;
     }
 
@@ -2621,6 +2962,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
     GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
 
+    clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL);
+    GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size);
+
     clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
     GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
 
@@ -2729,6 +3073,82 @@ static void ggml_cl2_free(ggml_backend_t backend) {
     }
 }
 
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+static void transpose_2d(
+    ggml_backend_opencl_context * backend_ctx,
+    cl_kernel kernel,
+    cl_mem src, cl_mem dst, size_t size,
+    cl_int stride, cl_int rows,
+    bool blocking = true
+) {
+    static ggml_cl_buffer buf;
+
+    cl_event evt;
+    cl_int err;
+
+    buf.allocate(backend_ctx->context, size);
+
+    cl_mem trans;
+    cl_buffer_region region;
+
+    region.origin = 0;
+    region.size = size;
+    CL_CHECK((trans = clCreateSubBuffer(
+        buf.buffer, CL_MEM_READ_WRITE,
+        CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &src));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &trans));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &stride));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &rows));
+
+    size_t local_size[3] = {64, 1, 1};
+    size_t global_size[3] = {(size_t)stride, (size_t)rows, 1};
+    CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL,
+        global_size, local_size, 0, NULL, NULL));
+
+    if (blocking) {
+        CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clReleaseEvent(evt));
+    } else {
+        CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, NULL));
+    }
+
+    CL_CHECK(clReleaseMemObject(trans));
+}
+
+static void transpose_2d_as_8b(
+    ggml_backend_opencl_context * backend_ctx,
+    cl_mem src, cl_mem dst, size_t size,
+    cl_int stride, cl_int rows,
+    bool blocking = true
+) {
+    transpose_2d(backend_ctx, backend_ctx->kernel_transpose_8_buf,
+        src, dst, size, stride, rows, blocking);
+}
+
+static void transpose_2d_as_16b(
+    ggml_backend_opencl_context * backend_ctx,
+    cl_mem src, cl_mem dst, size_t size,
+    cl_int stride, cl_int rows,
+    bool blocking = true
+) {
+    transpose_2d(backend_ctx, backend_ctx->kernel_transpose_16_buf,
+        src, dst, size, stride, rows, blocking);
+}
+
+static void transpose_2d_as_32b(
+    ggml_backend_opencl_context * backend_ctx,
+    cl_mem src, cl_mem dst, size_t size,
+    cl_int stride, cl_int rows,
+    bool blocking = true
+) {
+    transpose_2d(backend_ctx, backend_ctx->kernel_transpose_32_buf,
+        src, dst, size, stride, rows, blocking);
+}
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
 //------------------------------------------------------------------------------
 // Tensor extra management
 //------------------------------------------------------------------------------
@@ -2796,6 +3216,59 @@ struct ggml_tensor_extra_cl_q4_0 {
     }
 };
 
+struct ggml_tensor_extra_cl_q4_1 {
+    // Quantized values.
+    cl_mem q = nullptr;
+    // Quantized values in image1d_buffer_t.
+    cl_mem q_img = nullptr;
+    // Scales.
+    cl_mem d = nullptr;
+    // Scales in image1d_buffer_t.
+    cl_mem d_img = nullptr;
+    // Min
+    cl_mem m = nullptr;
+    // Min in image1d_buffer_t.
+    cl_mem m_img = nullptr;
+    // Size of quantized values.
+    size_t size_q = 0;
+    // Size of scales.
+    size_t size_d = 0;
+    // Size of min values.
+    size_t size_m = 0;
+
+    ~ggml_tensor_extra_cl_q4_1() {
+        reset();
+    }
+
+    void reset() {
+        // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
+        // They must be properly released so that the original buffer can be
+        // properly released to avoid memory leak.
+        if (q != nullptr) {
+            CL_CHECK(clReleaseMemObject(q));
+            q = nullptr;
+        }
+        if (d != nullptr) {
+            CL_CHECK(clReleaseMemObject(d));
+            d = nullptr;
+        }
+        if (m != nullptr) {
+            CL_CHECK(clReleaseMemObject(m));
+            m = nullptr;
+        }
+        // Currently, q_img and d_img are only initialized when SMALL_ALLOC is
+        // enabled. They point to the images in ggml_backend_opencl_buffer_context.
+        // So, there is no need to release them here.
+        // TODO: initialize them for non SMALL_PATH path, or remove them.
+        q_img = nullptr;
+        d_img = nullptr;
+        m_img = nullptr;
+        size_q = 0;
+        size_d = 0;
+        size_m = 0;
+    }
+};
+
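
This extra keeps the three components of a Q4_1 block (4-bit quants, per-block scale d, per-block minimum m) in separate "flat" buffers so kernels can bind them individually (and optionally view them as image1d_buffer_t). For reference, dequantizing one 32-weight Q4_1 block, assuming the standard ggml layout where the low nibbles hold the first 16 weights and the high nibbles the last 16:

```cpp
#include <cstdint>

// Dequantize one Q4_1 block (32 weights): x = d * q + m, with q in [0, 15].
// qs holds 16 bytes; low nibble -> weights 0..15, high nibble -> weights 16..31.
static void dequant_q4_1_block(const uint8_t qs[16], float d, float m, float out[32]) {
    for (int j = 0; j < 16; ++j) {
        out[j]      = d * (float)(qs[j] & 0x0F) + m;
        out[j + 16] = d * (float)(qs[j] >>   4) + m;
    }
}
```
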
 struct ggml_tensor_extra_cl_mxfp4 {
     // Quantized values.
     cl_mem q = nullptr;
@@ -2874,6 +3347,50 @@ struct ggml_tensor_extra_cl_q8_0 {
     }
 };
 
+struct ggml_tensor_extra_cl_q6_K {
+    // Lower 4 bits of quantized weights.
+    cl_mem ql = nullptr;
+    // Upper 2 bits of quantized weights.
+    cl_mem qh = nullptr;
+    // Scales for each block.
+    cl_mem s  = nullptr;
+    // Scales for each super block.
+    cl_mem d  = nullptr;
+
+    size_t size_ql = 0;
+    size_t size_qh = 0;
+    size_t size_s  = 0;
+    size_t size_d  = 0;
+
+    ~ggml_tensor_extra_cl_q6_K() {
+        reset();
+    }
+
+    void reset() {
+        if (ql != nullptr) {
+            CL_CHECK(clReleaseMemObject(ql));
+            ql = nullptr;
+        }
+        if (qh != nullptr) {
+            CL_CHECK(clReleaseMemObject(qh));
+            qh = nullptr;
+        }
+        if (s != nullptr) {
+            CL_CHECK(clReleaseMemObject(s));
+            s = nullptr;
+        }
+        if (d != nullptr) {
+            CL_CHECK(clReleaseMemObject(d));
+            d = nullptr;
+        }
+
+        size_ql = 0;
+        size_qh = 0;
+        size_s  = 0;
+        size_d  = 0;
+    }
+};
+
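
As with the other flat extras, the Q6_K components are split into separate buffers: ql (lower 4 bits of each 6-bit quant), qh (upper 2 bits), s (int8 sub-block scales) and d (the fp16 super-block scale). Reconstructing a single weight combines all four; a simplified sketch of that arithmetic follows, leaving out the exact byte interleaving inside a 256-weight super-block, which is defined in ggml-quants:

```cpp
#include <cstdint>

// Reconstruct one Q6_K weight from its pieces:
//   ql_nibble - the 4-bit low part of the quant
//   qh_2bits  - the 2-bit high part of the quant
//   sc        - int8 scale of the 16-weight sub-block the weight belongs to
//   d         - super-block scale (fp16 in storage, already converted to float here)
// The 6-bit quant is centered by subtracting 32; the index-to-byte mapping is omitted.
static float dequant_q6_K_value(uint8_t ql_nibble, uint8_t qh_2bits, int8_t sc, float d) {
    const int q = (int)(ql_nibble & 0x0F) | ((int)(qh_2bits & 0x03) << 4); // 0..63
    return d * (float)sc * (float)(q - 32);
}
```
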
 //------------------------------------------------------------------------------
 // Backend API
 //------------------------------------------------------------------------------
@@ -2923,7 +3440,7 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) {
     CL_CHECK(clReleaseEvent(evt));
 }
 
-// Syncronizes the 'backend_ctx's device with others so that commands
+// Synchronizes the 'backend_ctx's device with others so that commands
 // enqueued to it won't start until commands in the other devices have
 // completed.
 static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) {
@@ -3040,6 +3557,10 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
             ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
             i += 2;
@@ -3124,9 +3645,21 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                         default:
                             return false;
                     }
+                case GGML_TYPE_I32:
+                    switch (op->type) {
+                        case GGML_TYPE_I32:
+                            return true;
+                        default:
+                            return false;
+                    }
                 default:
                     return false;
             }
+        case GGML_OP_SET: {
+            return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) &&
+                    op->type == op->src[0]->type &&
+                    op->type == op->src[1]->type;
+        }
         case GGML_OP_SCALE:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
@@ -3160,14 +3693,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 case GGML_UNARY_OP_SIGMOID:
                     return ggml_is_contiguous(op->src[0]);
                 case GGML_UNARY_OP_TANH:
-                   return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
-                          (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
+                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_EXP:
+                   return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
                 case GGML_UNARY_OP_EXPM1:
-                   return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
-                          (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
+                   return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
                 case GGML_UNARY_OP_SOFTPLUS:
-                   return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
-                          (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
+                   return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
                 default:
                     return false;
             }
@@ -3183,6 +3715,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 default:
                     return false;
             }
+        case GGML_OP_TRI:
+            return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
         case GGML_OP_FILL:
             return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op);
         case GGML_OP_CLAMP:
@@ -3192,6 +3726,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return true;
         case GGML_OP_RMS_NORM:
             return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]);
+        case GGML_OP_L2_NORM:
+            return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_REPEAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
         case GGML_OP_PAD:
@@ -3223,7 +3759,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 return true;
             } else if (op->src[0]->type == GGML_TYPE_F32) {
                 return op->src[1]->type == GGML_TYPE_F32;
-            } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
+            } else if (op->src[0]->type == GGML_TYPE_Q4_0  || op->src[0]->type == GGML_TYPE_Q4_1 ||
+                       op->src[0]->type == GGML_TYPE_MXFP4 ||
+                       op->src[0]->type == GGML_TYPE_Q4_K  ||
                        op->src[0]->type == GGML_TYPE_Q6_K) {
                 return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
             } else if (op->src[0]->type == GGML_TYPE_Q8_0) {
@@ -3244,6 +3782,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
             return true;
+        case GGML_OP_DIAG:
+            return true;
         case GGML_OP_DIAG_MASK_INF:
             return op->ne[3] == 1;
         case GGML_OP_ROPE: {
@@ -3266,6 +3806,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             }
             return true;
         }
+        case GGML_OP_SOLVE_TRI:
+            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
             return true;
         case GGML_OP_ARGSORT: {
@@ -3280,8 +3822,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
         }
         case GGML_OP_SUM_ROWS:
-        case GGML_OP_MEAN:
+        case GGML_OP_CUMSUM:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
+        case GGML_OP_MEAN:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 const ggml_tensor * q = op->src[0];
@@ -3412,6 +3956,12 @@ struct ggml_backend_opencl_buffer_context {
         for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) {
             delete e;
         }
+        for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
+            delete e;
+        }
     }
 
     ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
@@ -3444,6 +3994,21 @@ struct ggml_backend_opencl_buffer_context {
         return extra;
     }
 
+    ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() {
+        ggml_tensor_extra_cl_q4_1 * extra;
+        if (temp_tensor_extras_q4_1.empty()) {
+            extra = new ggml_tensor_extra_cl_q4_1();
+        } else {
+            extra = temp_tensor_extras_q4_1.back();
+            temp_tensor_extras_q4_1.pop_back();
+        }
+
+        temp_tensor_extras_q4_1_in_use.push_back(extra);
+
+        extra->reset();
+        return extra;
+    }
+
     ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() {
         ggml_tensor_extra_cl_mxfp4 * extra;
         if (temp_tensor_extras_mxfp4.empty()) {
@@ -3474,6 +4039,21 @@ struct ggml_backend_opencl_buffer_context {
         return extra;
     }
 
+    ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() {
+        ggml_tensor_extra_cl_q6_K * extra;
+        if (temp_tensor_extras_q6_K.empty()) {
+            extra = new ggml_tensor_extra_cl_q6_K();
+        } else {
+            extra = temp_tensor_extras_q6_K.back();
+            temp_tensor_extras_q6_K.pop_back();
+        }
+
+        temp_tensor_extras_q6_K_in_use.push_back(extra);
+
+        extra->reset();
+        return extra;
+    }
+
     void reset() {
         for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
             temp_tensor_extras.push_back(e);
@@ -3485,6 +4065,11 @@ struct ggml_backend_opencl_buffer_context {
         }
         temp_tensor_extras_q4_0_in_use.clear();
 
+        for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) {
+            temp_tensor_extras_q4_1.push_back(e);
+        }
+        temp_tensor_extras_q4_1_in_use.clear();
+
         for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
             temp_tensor_extras_mxfp4.push_back(e);
         }
@@ -3494,6 +4079,11 @@ struct ggml_backend_opencl_buffer_context {
             temp_tensor_extras_q8_0.push_back(e);
         }
         temp_tensor_extras_q8_0_in_use.clear();
+
+        for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) {
+            temp_tensor_extras_q6_K.push_back(e);
+        }
+        temp_tensor_extras_q6_K_in_use.clear();
     }
 
     // Pools for extras. Available extras are in `temp_tensor_extras`. Extras
@@ -3505,14 +4095,18 @@ struct ggml_backend_opencl_buffer_context {
     std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
     std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
     std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
+    std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1;
+    std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1_in_use;
     std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4;
     std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
     std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
     std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use;
+    std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K;
+    std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K_in_use;
 
     // The buffer_context is initially created by ggml_backend_buft_alloc_buffer
     // before any tensor is initialized (at the beginning of alloc_tensor_range).
-    // Hence, there is alway a buffer object in this vector. When each tensor is
+    // Hence, there is always a buffer object in this vector. When each tensor is
     // being initialized, this original buffer object will be released if both
     // flattening and small allocation are enabled, and additional buffer
     // objects will be created in init_tensor to represent flattened quantized
@@ -3550,7 +4144,7 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff
         // Reuse extra of the parent tensor. The offset of this view tensor
         // becomes `extra->offset + view_offs` and needs to be calculated when
         // it is used. This changes is needed because of the change to
-        // ggml_alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640.
+        // ggml_alloc.c in https://github.com/ggml-org/llama.cpp/pull/7640.
         // `buffer` passed in here will always be `tensor->buffer`. It is OK
         // to allocate extras from the same buffer context for ordinary
         // intermediate tensors. But for views into kv cache tensors, doing so
@@ -3599,6 +4193,15 @@ inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ct
     return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
 }
 
+inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
+
+    bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor);
+
+    size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
+
+    return ((elem_num < 128 * 1024 * 1024) && adreno_kernel);  // max element num: 2**27
+}
+
 static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
 
@@ -3638,7 +4241,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
         //GGML_ASSERT(offset == 0);
 
         // We create subbuffers from the original tensor buffer for scales and
-        // quants - i.e., scales and quants are aliases into the buffer obejct
+        // quants - i.e., scales and quants are aliases into the buffer object
         // that backs the original tensor. This is a cleaner way to adapt to the
         // new memory management.
         // In the old code, we allocate new buffers for scales and quants
@@ -3863,6 +4466,99 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
         return;
 
     }
+    if (tensor->type == GGML_TYPE_Q4_1) {
+        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
+        GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized");
+
+        // Allocate the new extra and create aliases from the original.
+        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+        ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1();
+
+        size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
+        size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
+        size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
+        GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+        CL_CHECK(clEnqueueWriteBuffer(
+            queue, data_device, CL_TRUE, 0,
+            ggml_nbytes(tensor), data, 0, NULL, NULL));
+
+        cl_buffer_region region;
+
+        // The original tensor memory is divided into scales, mins and quants,
+        // i.e., we first store scales, then mins, then quants.
+        // Create subbuffer for scales.
+        region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
+        region.size = size_d;
+        extra->d = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+        auto previous_origin = region.origin;
+
+        // Create subbuffer for mins.
+        region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
+        region.size = size_m;
+        extra->m = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+        previous_origin = region.origin;
+
+        // Create subbuffer for quants.
+        region.origin = align_to(previous_origin + size_m, backend_ctx->alignment);
+        region.size = size_q;
+        extra->q = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+
+    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
+
+        if (use_adreno_kernels(backend_ctx, tensor)) {
+            kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle;
+        }
+    #else
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
+    #endif // GGML_OPENCL_USE_ADRENO_KERNELS
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clReleaseMemObject(data_device));
+
+        tensor->extra = extra;
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (use_adreno_kernels(backend_ctx, tensor)) {
+
+            int M = tensor->ne[1];
+            int K = tensor->ne[0];
+
+            GGML_ASSERT(K % 32 == 0);
+
+            // Transpose q as ushort
+            transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M);
+            // Transpose d as ushort
+            transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M);
+            // Transpose m as ushort
+            transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M);
+        }
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+        return;
+    }
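For reference, the three subbuffer sizes computed above for Q4_1 follow from the upstream ggml block layout, assuming QK4_1 == 32: each 32-element block stores one fp16 scale, one fp16 min and 16 bytes of packed 4-bit quants, 20 bytes in total. A minimal host-side sketch of that arithmetic (not part of the patch):

#include <cassert>
#include <cstddef>

// Hypothetical helper, assuming 32-element Q4_1 blocks with 2-byte fp16 values.
static void check_q4_1_split(size_t nelements) {
    const size_t nblocks = nelements / 32;
    const size_t size_d  = nblocks * 2;        // fp16 scales
    const size_t size_m  = nblocks * 2;        // fp16 mins
    const size_t size_q  = nblocks * (32 / 2); // two 4-bit quants per byte
    assert(size_d + size_m + size_q == nblocks * 20); // matches ggml_nbytes() for Q4_1
}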
     if (tensor->type == GGML_TYPE_MXFP4) {
         ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
         GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
@@ -4013,6 +4709,216 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
 
         tensor->extra = extra;
 
+        // Transpose the weights and scales
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (enable_adreno_trans_weight(backend_ctx, tensor)) {
+
+            int M = tensor->ne[1];   // ne01
+            int K = tensor->ne[0];   // ne00
+
+            GGML_ASSERT(K % 32 == 0);
+            GGML_ASSERT(M % 4 == 0);
+            GGML_ASSERT(tensor->ne[2] == 1);
+            GGML_ASSERT(tensor->ne[3] == 1);
+
+            // Transpose weights
+            size_t q_size_bytes = K * M / 4 * sizeof(float);
+            cl_buffer_region region;
+            region.origin = 0;
+            region.size = q_size_bytes;
+            cl_mem qT_d = clCreateSubBuffer(
+                backend_ctx->prealloc_quant_trans.buffer,
+                0,
+                CL_BUFFER_CREATE_TYPE_REGION,
+                &region,
+                &err);
+            CL_CHECK(err);
+
+            cl_mem q_d_image1D;
+            cl_mem qT_d_image1D;
+
+            cl_image_format img_fmt_1d;
+            cl_image_desc img_desc_1d;
+
+            img_fmt_1d = { CL_RGBA, CL_FLOAT };
+            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+            img_desc_1d.image_width = M * K / 4 / 4;
+            img_desc_1d.buffer = extra->q;
+            q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+            CL_CHECK(err);
+
+            img_fmt_1d = { CL_RGBA, CL_FLOAT };
+            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+            img_desc_1d.image_width = M * K / 4 / 4;
+            img_desc_1d.buffer = qT_d;
+            qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+            CL_CHECK(err);
+
+            int height_q = M / 4;
+            int width_q = K / 4 / 4;
+            kernel = backend_ctx->kernel_transpose_32;
+
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_q));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_q));
+
+            size_t local_size_q[3] = {4, 16, 1};
+            size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1};
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+
+            // Transpose scales
+            size_t d_size_bytes = M * (K / 32) * 2;
+            region.origin = 0;
+            region.size = d_size_bytes;
+            cl_mem dT_d = clCreateSubBuffer(
+                backend_ctx->prealloc_scales_trans.buffer,
+                0,
+                CL_BUFFER_CREATE_TYPE_REGION,
+                &region,
+                &err);
+            CL_CHECK(err);
+
+            cl_mem d_d_image1D;
+            cl_mem dT_d_image1D;
+
+            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+            img_fmt_1d = { CL_R, CL_HALF_FLOAT };
+            img_desc_1d.image_width = M * K / 32;
+            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+            img_desc_1d.buffer = extra->d;
+            d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+            CL_CHECK(err);
+
+            img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT };
+            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+            img_desc_1d.image_width = M * K / 32 / 4;
+            img_desc_1d.buffer = dT_d;
+            dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+            CL_CHECK(err);
+
+            int height_s = M / 4;
+            int width_s = K / 32;
+
+            kernel = backend_ctx->kernel_transpose_16_4x1;
+
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s));
+
+            size_t local_size_s[3] = {4, 16, 1};
+            size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1};
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+
+            // copy transposed buffer contents to original buffers
+            CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+
+            CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+
+            CL_CHECK(clReleaseMemObject(qT_d));
+            CL_CHECK(clReleaseMemObject(dT_d));
+
+            CL_CHECK(clReleaseMemObject(q_d_image1D));
+            CL_CHECK(clReleaseMemObject(d_d_image1D));
+            CL_CHECK(clReleaseMemObject(qT_d_image1D));
+            CL_CHECK(clReleaseMemObject(dT_d_image1D));
+        } // end transpose
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+        return;
+    }
+    if (tensor->type == GGML_TYPE_Q6_K) {
+        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
+        GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized");
+
+        // Allocate the new extra and create aliases from the original.
+        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+        ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K();
+
+        size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
+        size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4;
+        size_t size_s  = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16;
+        size_t size_d  = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
+        GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) &&
+            "Incorrect tensor size");
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+        CL_CHECK(clEnqueueWriteBuffer(
+            queue, data_device, CL_TRUE, 0,
+            ggml_nbytes(tensor), data, 0, NULL, NULL));
+
+        cl_buffer_region region;
+
+        // Subbuffer for ql
+        region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
+        region.size = size_ql;
+        extra->ql = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+        auto previous_origin = region.origin;
+
+        // Subbuffer for qh
+        region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment);
+        region.size = size_qh;
+        extra->qh = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+        previous_origin = region.origin;
+
+        // Subbuffer for scales
+        region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment);
+        region.size = size_s;
+        extra->s = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+        previous_origin = region.origin;
+
+        // Create subbuffer for d.
+        region.origin = align_to(previous_origin + size_s, backend_ctx->alignment);
+        region.size = size_d;
+        extra->d = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+        previous_origin = region.origin;
+
+        // Flatten the weights
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clReleaseMemObject(data_device));
+
+        extra->size_ql = size_ql;
+        extra->size_qh = size_qh;
+        extra->size_s  = size_s;
+        extra->size_d  = size_d;
+
+        tensor->extra  = extra;
         return;
     }
 #endif // GGML_OPENCL_SOA_Q
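For reference, the Q6_K split above matches the upstream ggml super-block layout, assuming QK_K == 256: 128 bytes of low 4-bit quants (ql), 64 bytes of high 2-bit quants (qh), 16 int8 scales and one fp16 d per 256 elements, 210 bytes per block. A minimal sketch of the same arithmetic (not part of the patch):

#include <cassert>
#include <cstddef>

// Hypothetical helper, assuming 256-element Q6_K super-blocks.
static void check_q6_K_split(size_t nelements) {
    const size_t nblocks = nelements / 256;
    const size_t size_ql = nblocks * (256 / 2);  // low 4 bits of each quant
    const size_t size_qh = nblocks * (256 / 4);  // high 2 bits of each quant
    const size_t size_s  = nblocks * (256 / 16); // int8 scales
    const size_t size_d  = nblocks * 2;          // fp16 super-block scale
    assert(size_ql + size_qh + size_s + size_d == nblocks * 210);
}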
@@ -4155,7 +5061,82 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
             size, data, 0, NULL, NULL));
         CL_CHECK(clReleaseMemObject(data_device));
         return;
-    } else if (tensor->type == GGML_TYPE_MXFP4) {
+    }
+    if (tensor->type == GGML_TYPE_Q4_1) {
+        ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra;
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (use_adreno_kernels(backend_ctx, tensor)) {
+            static ggml_cl_buffer buf_trans_q;
+            static ggml_cl_buffer buf_trans_m;
+            static ggml_cl_buffer buf_trans_d;
+            static ggml_cl_buffer buf_unpacked;
+
+            cl_int M = tensor->ne[1];
+            cl_int K = tensor->ne[0];
+
+            GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0);
+
+            size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2;
+            size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);
+            size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t);
+            GGML_ASSERT(size_d + size_q + size_m == ggml_nbytes(tensor) && "Incorrect tensor size");
+
+            buf_trans_q.allocate(backend_ctx->context, size_q);
+            buf_trans_m.allocate(backend_ctx->context, size_m);
+            buf_trans_d.allocate(backend_ctx->context, size_d);
+            buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor));
+
+            // transpose q, d, m back
+            transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4);
+            transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32);
+            transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32);
+
+            cl_uchar mask_0F = 0x0F;
+            cl_uchar mask_F0 = 0xF0;
+
+            size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+            size_t local_work_size[] = {1, 1, 1};
+
+            cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_noshuffle;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &buf_trans_q.buffer));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &buf_trans_d.buffer));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &buf_trans_m.buffer));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &buf_unpacked.buffer));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0));
+
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+            CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL));
+            return;
+        }
+#endif
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+
+        cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {1, 1, 1};
+
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+            global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clEnqueueReadBuffer(
+            queue, data_device, CL_TRUE, offset,
+            size, data, 0, NULL, NULL));
+        CL_CHECK(clReleaseMemObject(data_device));
+        return;
+    }
+    if (tensor->type == GGML_TYPE_MXFP4) {
         ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
 
         cl_int err;
@@ -4216,6 +5197,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
             ggml_nbytes(tensor), NULL, &err);
         CL_CHECK(err);
 
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        if (enable_adreno_trans_weight(backend_ctx, tensor)) {
+            cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0_trans;
+
+            int ne00 = tensor->ne[0];
+            int ne01 = tensor->ne[1];
+            GGML_ASSERT(tensor->ne[2] == 1);
+            GGML_ASSERT(tensor->ne[3] == 1);
+
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
+
+            size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), 1, 1};
+            size_t local_work_size[3] = {64, 1, 1};
+
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+                global_work_size, local_work_size, 0, NULL, &evt));
+            CL_CHECK(clWaitForEvents(1, &evt));
+
+            CL_CHECK(clEnqueueReadBuffer(
+                queue, data_device, CL_TRUE, offset,
+                size, data, 0, NULL, NULL));
+            CL_CHECK(clReleaseMemObject(data_device));
+            return;
+        }
+#endif
         cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0;
         CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
         CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
@@ -4224,6 +5235,34 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
         size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
         size_t local_work_size[] = {1, 1, 1};
 
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+            global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clEnqueueReadBuffer(
+            queue, data_device, CL_TRUE, offset,
+            size, data, 0, NULL, NULL));
+        CL_CHECK(clReleaseMemObject(data_device));
+        return;
+    }
+    if (tensor->type == GGML_TYPE_Q6_K) {
+        ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra;
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+
+        cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {1, 1, 1};
+
         cl_event evt;
         CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
             global_work_size, local_work_size, 0, NULL, &evt));
@@ -4347,7 +5386,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_
 }
 
 static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
-    *free = 0;
+    // no memory to report
+    *free  = 0;
     *total = 0;
 
     GGML_UNUSED(dev);
@@ -4666,6 +5706,81 @@ static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct gg
             (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
 }
 
+// Copy a noncontiguous tensor to a contiguous buffer. ne[] remains the same,
+// but nb[] is recalculated so that the destination is contiguous.
+static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst,
+                                       cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    const int tensor_type_size = ggml_type_size(src->type);
+
+    const int ne00 = src->ne[0];
+    const int ne01 = src->ne[1];
+    const int ne02 = src->ne[2];
+    const int ne03 = src->ne[3];
+
+    const cl_ulong nb00 = src->nb[0];
+    const cl_ulong nb01 = src->nb[1];
+    const cl_ulong nb02 = src->nb[2];
+    const cl_ulong nb03 = src->nb[3];
+
+    const int ne0 = src->ne[0];
+    const int ne1 = src->ne[1];
+    const int ne2 = src->ne[2];
+    const int ne3 = src->ne[3];
+
+    nb0 = tensor_type_size;
+    nb1 = tensor_type_size*ne00;
+    nb2 = tensor_type_size*ne00*ne01;
+    nb3 = tensor_type_size*ne00*ne01*ne02;
+
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra;
+
+    cl_ulong offset0 = extra->offset + src->view_offs;
+    cl_ulong offsetd = 0;
+
+    cl_kernel kernel;
+
+    switch (src->type) {
+        case GGML_TYPE_F32:
+            kernel = backend_ctx->kernel_cpy_f32_f32;
+            break;
+        case GGML_TYPE_F16:
+            kernel = backend_ctx->kernel_cpy_f16_f16;
+            break;
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &dst));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne1));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne2));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne3));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
+
+    const int nth = MIN(64, ne00);
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src);
+}
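The helper above only handles F32/F16 sources, so the recomputed strides are the plain row-major ggml layout. A minimal sketch of that stride calculation, assuming a non-quantized element type (not part of the patch):

#include <cstddef>
#include <cstdint>

// Hypothetical helper: contiguous byte strides for a 4-D ggml tensor.
static void contiguous_strides(const int64_t ne[4], size_t type_size, size_t nb[4]) {
    nb[0] = type_size;
    nb[1] = nb[0] * (size_t) ne[0];
    nb[2] = nb[1] * (size_t) ne[1];
    nb[3] = nb[2] * (size_t) ne[2];
}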
+
 static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     UNUSED(backend);
     UNUSED(src0);
@@ -4681,19 +5796,12 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
 
-    const int      ne00 = src0->ne[0];
-    const cl_ulong nb01 = src0->nb[1];
-    const cl_ulong nb02 = src0->nb[2];
-    const cl_ulong nb03 = src0->nb[3];
-    const int      ne10 = src1->ne[0];
-    const cl_ulong nb10 = src1->nb[0];
-    const int      ne11 = src1->ne[1];
-    const int      ne12 = src1->ne[2];
-    const cl_ulong nb11 = src1->nb[1];
-    const cl_ulong nb12 = src1->nb[2];
-    const cl_ulong nb1  = dst->nb[1];
-    const cl_ulong nb2  = dst->nb[2];
-    const cl_ulong nb3  = dst->nb[3];
+    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int,      ne1, src1, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb);
+    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -4739,8 +5847,14 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
     CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
 
-    size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
-    size_t local_work_size[] = {64, 1, 1};
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= max_workgroup_size) {
+        nth *= 2;
+    }
+
+    size_t global_work_size[] = {(size_t)ne10*nth, (size_t)ne11, (size_t)ne12};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
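The get_rows launch above now derives the workgroup width from the kernel's reported limit instead of a hard-coded 64. A standalone sketch of the same selection (not part of the patch):

// Hypothetical helper mirroring the loop above: returns the largest power of
// two that does not exceed the workgroup limit while (roughly) covering ne00.
static int pick_row_workgroup(int ne00, int max_workgroup_size) {
    int nth = 1;
    while (nth < ne00 && 2 * nth <= max_workgroup_size) {
        nth *= 2;
    }
    return nth;
}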
@@ -5595,7 +6709,6 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_UNUSED(src1);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -5618,7 +6731,14 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     const cl_ulong nb2  = dst->nb[2];
     const cl_ulong nb3  = dst->nb[3];
 
-    cl_kernel kernel = backend_ctx->kernel_mean_f32;
+    cl_kernel kernel;
+
+    const bool is_c4 = ne00 % 4 == 0;
+    if (is_c4) {
+        kernel = backend_ctx->kernel_mean_f32_4;
+    } else {
+        kernel = backend_ctx->kernel_mean_f32;
+    }
 
     CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
@@ -5635,7 +6755,7 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const
     CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
     CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
 
-    size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+    size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)64, 1, 1};
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
@@ -5941,6 +7061,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
 }
 
+static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int tri_type = ggml_get_op_params_i32(dst, 0);
+    const int64_t n = ggml_nelements(dst);
+    const int     ne0  = dst->ne[0];
+    const int     ne1  = dst->ne[1];
+
+    cl_kernel kernel = backend_ctx->kernel_tri;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &n));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne1));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &tri_type));
+
+    size_t local_work_size[1] = { 256 };
+    size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
+}
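The TRI launch pads the 1-D global size up to a multiple of the 256-wide workgroup; the elementwise kernels later in this patch do the same with CEIL_DIV(n, 64)*64. As a standalone sketch (not part of the patch):

#include <cstddef>

// Hypothetical helper: round a global work size up to a multiple of the
// local size, as required when non-uniform workgroups are unavailable.
static size_t round_up_global(size_t n, size_t local_size) {
    return (n + local_size - 1) / local_size * local_size;
}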
+
 static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
@@ -6436,6 +7594,64 @@ static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0,
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_cl_l2_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_l2_norm_f32;
+
+    int nth = sgs;
+    while (nth < ne00 && nth < (int)backend_ctx->get_kernel_workgroup_size(kernel)) {
+        nth *= 2;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),    &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),  &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),    &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),  &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),       &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),       &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),       &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),       &ne03));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),  &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),  &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),  &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float),     &eps));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs,  NULL));
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
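In the L2-norm launch, nth starts at the subgroup size and only doubles, so sizeof(float)*nth/sgs allocates one float of local scratch per subgroup, assuming (as on the Adreno/Intel paths above) that the workgroup limit is a power-of-two multiple of the subgroup size. A sketch of that sizing (not part of the patch):

#include <cstddef>

// Hypothetical helper mirroring the loop above: nth stays a power-of-two
// multiple of sgs, and the scratch buffer holds one partial sum per subgroup.
static size_t l2_norm_scratch_bytes(int ne00, int max_workgroup_size, int sgs) {
    int nth = sgs;
    while (nth < ne00 && nth < max_workgroup_size) {
        nth *= 2;
    }
    return sizeof(float) * (size_t)(nth / sgs);
}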
+
 static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -6449,79 +7665,251 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
-    cl_ulong offset0_abs = extra0->offset + src0->view_offs;
-    cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const cl_ulong nb0  = dst->nb[0];
+    const cl_ulong nb1  = dst->nb[1];
+    const cl_ulong nb2  = dst->nb[2];
+    const cl_ulong nb3  = dst->nb[3];
 
     cl_kernel kernel;
-    if (dst->type == GGML_TYPE_F32) {
-        kernel = backend_ctx->kernel_tanh_f32_nd;
-    } else if (dst->type == GGML_TYPE_F16) {
-        kernel = backend_ctx->kernel_tanh_f16_nd;
-    } else {
-        GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh");
-    }
-    GGML_ASSERT(kernel != nullptr);
 
-    const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3];
-    const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3];
-
-    const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3];
-    const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3];
-
-    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
-    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
-
-    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));
-    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));
-    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne02));
-    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &ne03));
-    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
-    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
-
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),     &ne10));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),     &ne11));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),     &ne12));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),     &ne13));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
-
-    size_t global_work_size[3];
-    if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
-        return;
-    }
-    global_work_size[0] = (size_t)ne10;
-    global_work_size[1] = (size_t)ne11;
-    global_work_size[2] = (size_t)ne12;
-
-    size_t lws0 = 16, lws1 = 4, lws2 = 1;
-    if (ne10 < 16) lws0 = ne10;
-    if (ne11 < 4) lws1 = ne11;
-    if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
-
-    while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
-    while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
-    while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
-
-
-    size_t local_work_size[] = {lws0, lws1, lws2};
-
-    size_t* local_work_size_ptr = local_work_size;
-    if (!backend_ctx->non_uniform_workgroups) {
-        if (global_work_size[0] % local_work_size[0] != 0 ||
-            global_work_size[1] % local_work_size[1] != 0 ||
-            global_work_size[2] % local_work_size[2] != 0) {
-            local_work_size_ptr = NULL;
+    if (ggml_is_contiguous(src0)) {
+        // Handle contiguous input
+        int n = ggml_nelements(dst);
+        if (n % 4 == 0) {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_tanh_f32_4;
+            } else {
+                kernel = backend_ctx->kernel_tanh_f16_4;
+            }
+            n /= 4;
+        } else {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_tanh_f32;
+            } else {
+                kernel = backend_ctx->kernel_tanh_f16;
+            }
         }
-    }
-    if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
 
-    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        size_t * local_work_size_ptr = local_work_size;
+        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
+            local_work_size_ptr = nullptr;
+        }
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+    } else {
+        // Handle non-contiguous input
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_tanh_f32_nc;
+        } else {
+            kernel = backend_ctx->kernel_tanh_f16_nc;
+        }
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
+
+        int nth = 64;
+
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    }
+}
+
+static void ggml_cl_neg(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);
+
+    cl_kernel kernel;
+
+    if (ggml_is_contiguous(src0)) {
+        // Handle contiguous input
+        int n = ggml_nelements(dst);
+        if (n % 4 == 0) {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_neg_f32_4;
+            } else {
+                kernel = backend_ctx->kernel_neg_f16_4;
+            }
+            n /= 4;
+        } else {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_neg_f32;
+            } else {
+                kernel = backend_ctx->kernel_neg_f16;
+            }
+        }
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int),   &n));
+
+        size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    } else {
+        // Handle non-contiguous input
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_neg_f32_nc;
+        } else {
+            kernel = backend_ctx->kernel_neg_f16_nc;
+        }
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
+
+        int nth = 64;
+
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    }
+}
+
+static void ggml_cl_exp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);
+
+    cl_kernel kernel;
+
+    if (ggml_is_contiguous(src0)) {
+        // Handle contiguous input
+        int n = ggml_nelements(dst);
+        if (n % 4 == 0) {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_exp_f32_4;
+            } else {
+                kernel = backend_ctx->kernel_exp_f16_4;
+            }
+            n /= 4;
+        } else {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_exp_f32;
+            } else {
+                kernel = backend_ctx->kernel_exp_f16;
+            }
+        }
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int),   &n));
+
+        size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    } else {
+        // Handle non-contiguous input
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_exp_f32_nc;
+        } else {
+            kernel = backend_ctx->kernel_exp_f16_nc;
+        }
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
+
+        int nth = 64;
+
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    }
 }
 
 static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -6537,18 +7925,8 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
-    cl_ulong offset0_abs = extra0->offset + src0->view_offs;
-    cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
-
-    cl_kernel kernel;
-    if (dst->type == GGML_TYPE_F32) {
-        kernel = backend_ctx->kernel_expm1_f32_nd;
-    } else if (dst->type == GGML_TYPE_F16) {
-        kernel = backend_ctx->kernel_expm1_f16_nd;
-    } else {
-        GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1");
-    }
-    GGML_ASSERT(kernel != nullptr);
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
 
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
@@ -6560,70 +7938,74 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
     const cl_ulong nb02 = src0->nb[2];
     const cl_ulong nb03 = src0->nb[3];
 
-    const int ne10 = dst->ne[0];
-    const int ne11 = dst->ne[1];
-    const int ne12 = dst->ne[2];
-    const int ne13 = dst->ne[3];
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
 
-    const cl_ulong nb10 = dst->nb[0];
-    const cl_ulong nb11 = dst->nb[1];
-    const cl_ulong nb12 = dst->nb[2];
-    const cl_ulong nb13 = dst->nb[3];
+    cl_kernel kernel;
 
-    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
-    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
-
-    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));
-    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));
-    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne02));
-    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &ne03));
-    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
-    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
-
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),     &ne10));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),     &ne11));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),     &ne12));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),     &ne13));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
-
-    size_t global_work_size[3];
-    if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
-        return;
-    }
-    global_work_size[0] = (size_t)ne10;
-    global_work_size[1] = (size_t)ne11;
-    global_work_size[2] = (size_t)ne12;
-
-    size_t lws0 = 16, lws1 = 4, lws2 = 1;
-    if (ne10 < 16) lws0 = ne10;
-    if (ne11 < 4) lws1 = ne11;
-    if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
-
-    while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
-    while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
-    while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
-
-
-    size_t local_work_size[] = {lws0, lws1, lws2};
-
-    size_t* local_work_size_ptr = local_work_size;
-    if (!backend_ctx->non_uniform_workgroups) {
-        if (global_work_size[0] % local_work_size[0] != 0 ||
-            global_work_size[1] % local_work_size[1] != 0 ||
-            global_work_size[2] % local_work_size[2] != 0) {
-            local_work_size_ptr = NULL;
+    if (ggml_is_contiguous(src0)) {
+        // Handle contiguous input
+        int n = ggml_nelements(dst);
+        if (n % 4 == 0) {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_expm1_f32_4;
+            } else {
+                kernel = backend_ctx->kernel_expm1_f16_4;
+            }
+            n /= 4;
+        } else {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_expm1_f32;
+            } else {
+                kernel = backend_ctx->kernel_expm1_f16;
+            }
         }
-    }
-    if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
 
-    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        size_t * local_work_size_ptr = local_work_size;
+        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
+            local_work_size_ptr = nullptr;
+        }
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+    } else {
+        // Handle non-contiguous input
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_expm1_f32_nc;
+        } else {
+            kernel = backend_ctx->kernel_expm1_f16_nc;
+        }
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
+
+        int nth = 64;
+
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    }
 }
 
 static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -6639,18 +8021,8 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
-    cl_ulong offset0_abs = extra0->offset + src0->view_offs;
-    cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
-
-    cl_kernel kernel;
-    if (dst->type == GGML_TYPE_F32) {
-        kernel = backend_ctx->kernel_softplus_f32_nd;
-    } else if (dst->type == GGML_TYPE_F16) {
-        kernel = backend_ctx->kernel_softplus_f16_nd;
-    } else {
-        GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
-    }
-    GGML_ASSERT(kernel != nullptr);
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
 
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
@@ -6662,70 +8034,74 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
     const cl_ulong nb02 = src0->nb[2];
     const cl_ulong nb03 = src0->nb[3];
 
-    const int ne10 = dst->ne[0];
-    const int ne11 = dst->ne[1];
-    const int ne12 = dst->ne[2];
-    const int ne13 = dst->ne[3];
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
 
-    const cl_ulong nb10 = dst->nb[0];
-    const cl_ulong nb11 = dst->nb[1];
-    const cl_ulong nb12 = dst->nb[2];
-    const cl_ulong nb13 = dst->nb[3];
+    cl_kernel kernel;
 
-    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
-    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
-
-    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));
-    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));
-    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne02));
-    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &ne03));
-    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
-    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
-
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),     &ne10));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),     &ne11));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),     &ne12));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),     &ne13));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
-
-    size_t global_work_size[3];
-    if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
-        return;
-    }
-    global_work_size[0] = (size_t)ne10;
-    global_work_size[1] = (size_t)ne11;
-    global_work_size[2] = (size_t)ne12;
-
-    size_t lws0 = 16, lws1 = 4, lws2 = 1;
-    if (ne10 < 16) lws0 = ne10;
-    if (ne11 < 4) lws1 = ne11;
-    if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
-
-    while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
-    while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
-    while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
-
-
-    size_t local_work_size[] = {lws0, lws1, lws2};
-
-    size_t* local_work_size_ptr = local_work_size;
-    if (!backend_ctx->non_uniform_workgroups) {
-        if (global_work_size[0] % local_work_size[0] != 0 ||
-            global_work_size[1] % local_work_size[1] != 0 ||
-            global_work_size[2] % local_work_size[2] != 0) {
-            local_work_size_ptr = NULL;
+    if (ggml_is_contiguous(src0)) {
+        // Handle contiguous input
+        int n = ggml_nelements(dst);
+        if (n % 4 == 0) {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_softplus_f32_4;
+            } else {
+                kernel = backend_ctx->kernel_softplus_f16_4;
+            }
+            n /= 4;
+        } else {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_softplus_f32;
+            } else {
+                kernel = backend_ctx->kernel_softplus_f16;
+            }
         }
-    }
-    if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
 
-    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        size_t * local_work_size_ptr = local_work_size;
+        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
+            local_work_size_ptr = nullptr;
+        }
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+    } else {
+        // Handle non-contiguous input
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_softplus_f32_nc;
+        } else {
+            kernel = backend_ctx->kernel_softplus_f16_nc;
+        }
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
+
+        int nth = 64;
+
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    }
 }
 
 static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
@@ -6739,53 +8115,58 @@ static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, con
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
-    if (backend_ctx->kernel_repeat == nullptr) {
-        GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__);
-        return;
-    }
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad  = (ggml_tensor_extra_cl *)dst->extra;
 
-    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
-    ggml_tensor_extra_cl * extra_dst  = (ggml_tensor_extra_cl *)dst->extra;
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd  = extrad->offset + dst->view_offs;
 
-    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
-    cl_ulong off_dst  = extra_dst->offset  + dst->view_offs;
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
 
-    const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3];
-    const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3];
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
 
-    const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3];
-    const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3];
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
 
-    cl_kernel kernel = backend_ctx->kernel_repeat;
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
 
-    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra_src0->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),    &extra_dst->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong),  &off_src0));
-    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &off_dst));
-    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),       &src0_ne0));
-    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),       &src0_ne1));
-    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),       &src0_ne2));
-    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),       &src0_ne3));
-    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong),  &src0_nb0));
-    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong),  &src0_nb1));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &dst_ne0));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &dst_ne1));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &dst_ne2));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &dst_ne3));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3));
+    cl_kernel kernel = backend_ctx->kernel_repeat_f32;
 
-    size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1;
-    size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1;
-    size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1;
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
 
-    size_t global_work_size[] = { gws0, gws1, gws2 };
+    int nth = 64;
 
-    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
+    size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
 static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -7009,121 +8390,76 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
-    cl_command_queue queue = backend_ctx->queue;
 
-    if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) {
-        GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__);
-        return;
-    }
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
-    ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra;
-    ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra;
-    ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra;
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd  = extrad->offset + dst->view_offs;
 
-    cl_ulong off_src0 = extra0_cl->offset + src0->view_offs;
-    cl_ulong off_src1 = extra1_cl->offset + src1->view_offs;
-    cl_ulong off_dst  = extrad_cl->offset + dst->view_offs;
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
 
-    const int32_t dim = ((const int32_t *) dst->op_params)[0];
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
+
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    const cl_int dim = ((const int32_t *) dst->op_params)[0];
     GGML_ASSERT(dim >= 0 && dim <= 3);
 
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-        if (dim == 3) {
+    int nth = MIN(64, ne0);
 
-            size_t nbytes_src0 = ggml_nbytes(src0);
-            size_t nbytes_src1 = ggml_nbytes(src1);
+    cl_kernel kernel = backend_ctx->kernel_concat_f32;
 
-            CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device,
-                                         off_src0, off_dst, nbytes_src0, 0, NULL, NULL));
-            CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device,
-                                         off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL));
-        } else {
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int),   &dim));
 
-            cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous;
-            size_t global_work_size[3];
+    size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
 
-            for (int i3 = 0; i3 < dst->ne[3]; ++i3) {
-                cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]);
-                cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]);
-                cl_ulong current_off_dst  = off_dst  + (i3 * dst->nb[3]);
-
-                int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2];
-                int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2];
-                int d_ne0  = dst->ne[0];  int d_ne1  = dst->ne[1];  int d_ne2  = dst->ne[2];
-
-                CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra0_cl->data_device));
-                CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &current_off_src0));
-                CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extra1_cl->data_device));
-                CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &current_off_src1));
-                CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),    &extrad_cl->data_device));
-                CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong),  &current_off_dst));
-                CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),       &d_ne00));
-                CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),       &d_ne01));
-                CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int),       &d_ne02));
-                CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int),       &d_ne10));
-                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &d_ne11));
-                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &d_ne12));
-                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &d_ne0));
-                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &d_ne1));
-                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &d_ne2));
-                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &dim));
-
-                global_work_size[0] = d_ne0;
-                global_work_size[1] = d_ne1;
-                global_work_size[2] = d_ne2;
-
-                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
-            }
-        }
-    } else {
-        cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous;
-
-        cl_long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
-        cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
-
-        cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
-
-        cl_long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3];
-        cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3];
-
-
-        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra0_cl->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &off_src0));
-        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extra1_cl->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &off_src1));
-        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),    &extrad_cl->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong),  &off_dst));
-
-        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long),      &ne00));
-        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long),      &ne01));
-        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long),      &ne02));
-        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long),      &ne03));
-        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),    &nb00));
-        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),    &nb01));
-        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),    &nb02));
-        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong),    &nb03));
-
-        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong),    &nb10));
-        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong),    &nb11));
-        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),    &nb12));
-        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),    &nb13));
-
-        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_long),     &d_ne0));
-        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_long),     &d_ne1));
-        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_long),     &d_ne2));
-        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_long),     &d_ne3));
-        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong),    &d_nb0));
-        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong),    &d_nb1));
-        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong),    &d_nb2));
-        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong),    &d_nb3));
-        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int),      &dim));
-
-        size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1,
-                                         d_ne2 > 0 ? (size_t)d_ne2 : 1,
-                                         d_ne3 > 0 ? (size_t)d_ne3 : 1 };
-
-        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_nc, NULL, dst);
-    }
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
 static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -7574,6 +8910,427 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten
     CL_CHECK(clReleaseMemObject(D_sub_buffer));
 }
 
+static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+    ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;
+
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int  ne00 = src0->ne[0];
+    const int  ne01 = src0->ne[1];
+
+    const int  ne1 = dst->ne[1];
+
+    GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+    cl_context context = backend_ctx->context;
+    cl_kernel kernel;
+
+    cl_int              err;
+    cl_image_format     img_fmt;
+    cl_image_desc       img_desc;
+    cl_buffer_region    region;
+
+    int M = ne01;
+    int N = ne1;
+    int K = ne00;
+
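+    // ne1 == 1: use the GEMV kernel; otherwise transpose the activations and run the GEMM kernel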
+    if (ne1 == 1) {
+        cl_mem q_img = nullptr;
+        cl_mem b_sub_buf = nullptr;
+        cl_mem b_img = nullptr;
+
+        // image for q
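+        // q4_1 packs two 4-bit quants per byte (hence the /2); each CL_UNSIGNED_INT32 texel covers 4 bytes (hence the /4)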
+        img_fmt = { CL_R, CL_UNSIGNED_INT32};
+        memset(&img_desc, 0, sizeof(img_desc));
+        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc.image_width = M * K / 2 / 4;
+        img_desc.buffer = extra0_q4_1->q;
+        CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
+
+        // subbuffer for activations
+        region.origin = offset1;
+        region.size = K * N * sizeof(float);
+        CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
+
+        // image for activations
+        img_fmt = {CL_RGBA, CL_FLOAT};
+        memset(&img_desc, 0, sizeof(img_desc));
+        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc.image_width = K * N / 4;
+        img_desc.buffer = b_sub_buf;
+        CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
+
+        kernel = backend_ctx->kernel_gemv_noshuffle_q4_1_f32;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &q_img));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &extra0_q4_1->d));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra0_q4_1->m));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &b_img));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int),   &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int),   &ne01));
+
+        size_t local_work_size[3] = {64, 4, 1};
+        size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+
+        CL_CHECK(clReleaseMemObject(q_img));
+        CL_CHECK(clReleaseMemObject(b_sub_buf));
+        CL_CHECK(clReleaseMemObject(b_img));
+    } else {
+        cl_mem b_sub_buf = nullptr;
+        cl_mem b_sub_buf_trans = nullptr;
+        cl_mem b_img = nullptr;
+        cl_mem b_img_trans = nullptr;
+
+        // subbuffer for activations
+        region.origin = offset1;
+        region.size = K * N * sizeof(float);
+        CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
+
+        // image for activations
+        img_fmt = {CL_RGBA, CL_FLOAT};
+        memset(&img_desc, 0, sizeof(img_desc));
+        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc.image_width = K * N / 4;
+        img_desc.buffer = b_sub_buf;
+        CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err));
+
+        // pad N to multiple of 8
+        int extra_elements = N % 8;
+        int padding = 0;
+        if (extra_elements > 0){
+            padding = 8 - extra_elements;
+        }
+
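+        // the transposed activations are stored as half precision (hence the /2 below), padded so N covers whole 8-column GEMM tiles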
+        // subbuffer for transposed activations
+        region.origin = 0;
+        region.size = K * (N + padding) * sizeof(float)/2;
+        backend_ctx->prealloc_act_trans.allocate(context, region.size);
+        CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &err), err));
+
+        // image for transposed activations
+        img_fmt = {CL_RGBA, CL_HALF_FLOAT};
+        memset(&img_desc, 0, sizeof(img_desc));
+        img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc.image_width = K * (N + padding) / 4;
+        img_desc.buffer = b_sub_buf_trans;
+        CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err));
+
+        // transpose activations
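+        // dimensions below are expressed in 4-element (RGBA) texels, hence the divisions by 4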
+        int height_B = N/4;
+        if (height_B == 0) {
+            height_B = 1;
+        }
+        int width_B = K/4;
+        int padded_height_B = (N + padding)/4;
+
+        kernel = backend_ctx->kernel_transpose_32_16;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_B));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_B));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &padded_height_B));
+
+        size_t local_work_size_t[2] = { 1, 16 };
+        size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B };
+        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst);
+
+        // gemm
+        kernel = backend_ctx->kernel_gemm_noshuffle_q4_1_f32;
+        int padded_N = N + padding;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0_q4_1->q));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem),   &extra0_q4_1->d));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra0_q4_1->m));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &b_img_trans));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int),   &ne01));
+        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int),   &padded_N));
+        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int),   &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int),   &ne1));
+
+        size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1};
+        size_t local_work_size[3] = {1, 128, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+
+        CL_CHECK(clReleaseMemObject(b_sub_buf));
+        CL_CHECK(clReleaseMemObject(b_sub_buf_trans));
+        CL_CHECK(clReleaseMemObject(b_img));
+        CL_CHECK(clReleaseMemObject(b_img_trans));
+    }
+#else
+    GGML_UNUSED(backend);
+    GGML_UNUSED(src0);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+#endif
+}
+
+static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const enum ggml_type src0t = src0->type;
+    const enum ggml_type src1t = src1->type;
+
+    GGML_ASSERT(src0t == GGML_TYPE_Q8_0);
+    GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
+
+    GGML_ASSERT(src1->view_offs == 0);
+    GGML_ASSERT(dst->view_offs == 0);
+
+    const int  ne00 = src0->ne[0];
+    const int  ne01 = src0->ne[1];
+    const int  ne02 = src0->ne[2];
+
+    const int  ne10 = src1->ne[0];
+    const int  ne12 = src1->ne[2];
+
+    const int  ne0 = dst->ne[0];
+    const int  ne1 = dst->ne[1];
+
+    GGML_ASSERT(ne00 == ne10);
+    GGML_ASSERT((ne00 % 32) == 0);
+    GGML_ASSERT(ne0 == ne01);
+
+    cl_context context = backend_ctx->context;
+    cl_kernel kernel;
+
+    // init CL objects
+    cl_int              status;
+    cl_image_format     img_fmt_1d;
+    cl_image_desc       img_desc_1d;
+    cl_buffer_region    region;
+    cl_mem              A_image1d;
+    cl_mem              B_image1d;
+    cl_mem              B_sub_buffer;
+    cl_mem              S_image1d;
+
+    cl_mem              D_image1d;
+    cl_mem              D_sub_buffer;
+
+    int M = ne01;
+    int N = ne1;
+    int K = ne00;
+
+    // create an image for A
+    img_fmt_1d = { CL_R, CL_FLOAT};
+    memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+    img_desc_1d.image_width = M * K / 4;    // q8_0 quants are stored as bytes; each CL_R/CL_FLOAT texel covers 4 of them
+    img_desc_1d.buffer = extra0_q8_0->q;
+    A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
+    CL_CHECK(status);
+
+    // create an image for Scale
+    img_fmt_1d = { CL_R, CL_HALF_FLOAT};
+    memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+    img_desc_1d.image_width = M * K / 32;    // Block size is 32
+    img_desc_1d.buffer = extra0_q8_0->d;
+    S_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
+    CL_CHECK(status);
+
+    // create a sub_buffer for B
+    region.origin = (extra1->offset); // + src1->view_offs
+    region.size = K * N * sizeof(float);
+    B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+    CL_CHECK(status);
+
+    // create an image for B from sub_buffer: RGBA (OCL)
+    img_fmt_1d = {CL_RGBA, CL_FLOAT};
+    memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+    img_desc_1d.image_width = K * N / 4;
+    img_desc_1d.buffer = B_sub_buffer;
+    B_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
+    CL_CHECK(status);
+
+    // Create subbuffer and image1d_buffer for dst
+    region.origin = (extrad->offset); // + dst->view_offs
+    region.size = M * N * sizeof(float);
+    D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+    CL_CHECK(status);
+
+    img_fmt_1d = {CL_R, CL_FLOAT};
+    memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+    img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+    img_desc_1d.image_width = M * N;
+    img_desc_1d.buffer = D_sub_buffer;
+    D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status);
+    CL_CHECK(status);
+
+    size_t local_work_size[3] = {1, 1, 1};
+    size_t global_work_size[3] = {1, 1, 1};
+
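+    // N == 1: use the mat-vec kernel; otherwise transpose the activations and run the 8x4 GEMM kernel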
+    if (N == 1) {
+        kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32;
+
+        int r2 = 1;
+        int r3 = 1;
+        cl_uint k_arg = 0;
+
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &A_image1d));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extra0_q8_0->d));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &B_image1d));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extra1->offset));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extrad->offset));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne02));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne10));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne12));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne0));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne1));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r2));
+        CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r3));
+
+        size_t wavesize = backend_ctx->adreno_wave_size;
+        local_work_size[0] = wavesize;
+        local_work_size[1] = 4; // reduce factor
+        local_work_size[2] = 1;
+
+        global_work_size[0] = ((M + wavesize - 1) / wavesize) * wavesize;
+        global_work_size[1] = 4; // reduce factor
+        global_work_size[2] = 1;
+    } else {
+        cl_ulong offsetd = extrad->offset + dst->view_offs;
+        cl_mem              B_image1d_trans = nullptr;
+        // for B transpose
+        cl_mem B_d = nullptr;
+        int padding;
+
+        // how many extra elements beyond a multiple of 8
+        int extra_elements = N % 8;
+
+        // how much padding to add
+        padding = 0;
+        if (extra_elements > 0){
+            padding = 8 - extra_elements;
+        }
+
+        // Specify the starting offset (in bytes)
+        region.origin = 0;
+        // Specify the size of the sub-buffer (divide by 2 for FP16)
+        region.size = K * (N + padding) * sizeof(float)/2;
+        backend_ctx->prealloc_act_trans.allocate(context, region.size);
+        B_d = clCreateSubBuffer(
+            backend_ctx->prealloc_act_trans.buffer,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &status);
+        CL_CHECK(status);
+
+        cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16)
+        cl_image_desc image_desc_B_d_output = {
+            CL_MEM_OBJECT_IMAGE1D_BUFFER,
+            static_cast<size_t>(K * (N + padding)/4),
+            0, 0, 0, 0, 0, 0, 0, { B_d }
+        };
+        B_image1d_trans = clCreateImage(
+            context,
+            0,
+            &image_format_B_d_output,
+            &image_desc_B_d_output,
+            NULL,
+            &status);
+        CL_CHECK(status);
+
+        int height_B = N/4;
+        if (height_B == 0) {
+            height_B = 1;
+        }
+        int width_B = K/4;
+        int padded_height_B = (N + padding)/4;
+
+        kernel = backend_ctx->kernel_transpose_32_16;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_image1d));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d_trans));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_B));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_B));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &padded_height_B));
+
+        size_t local_size_t[2] = { 1, 16 };
+        size_t global_size_t[2] = {
+            static_cast<size_t>(width_B),
+            static_cast<size_t>(padded_height_B)
+        };
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst);
+
+        kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4;
+
+        int N_with_padding = N + padding;
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q8_0->q));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q8_0->d));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &B_image1d_trans));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &K));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &M));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &N_with_padding));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &N));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &offsetd));
+
+        global_work_size[0] = (size_t)(N + 7) / 8;
+        global_work_size[1] = (size_t)(M + 3) / 4;
+        global_work_size[2] = 1;
+
+        local_work_size[0] = 2;
+        local_work_size[1] = 128;
+        local_work_size[2] = 1;
+    }
+
+    // enqueue kernel with profiling
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+
+    // deallocate sub buffers and images
+    CL_CHECK(clReleaseMemObject(A_image1d));
+    CL_CHECK(clReleaseMemObject(B_sub_buffer));
+    CL_CHECK(clReleaseMemObject(B_image1d));
+    CL_CHECK(clReleaseMemObject(S_image1d));
+    CL_CHECK(clReleaseMemObject(D_sub_buffer));
+    CL_CHECK(clReleaseMemObject(D_image1d));
+#else
+    GGML_UNUSED(backend);
+    GGML_UNUSED(src0);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+#endif
+}
+
 static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -7597,8 +9354,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 
 #ifdef GGML_OPENCL_SOA_Q
     ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
+    ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;
     ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
     ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
+    ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
 #endif
 
     const int  ne00 = src0 ? src0->ne[0] : 0;
@@ -7641,9 +9400,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
     cl_context context = backend_ctx->context;
 
     if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){
-        if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) {
+        if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0  &&
+            // dst is wrapped in an image1d_buffer, so the image buffer size limit applies; the same goes for src0
+            (ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) {
             // For KQ
             if (ggml_is_permuted(src0) && ggml_is_permuted(src1) &&
+                ((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) &&
                 nb00 <= nb02 &&
                 nb02 <= nb01 &&
                 nb01 <= nb03 &&
@@ -7654,7 +9416,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
                 return;
             }
             // For KQV
-            if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+            if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
+                ((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) {
                 ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst);
                 return;
             }
@@ -7686,6 +9449,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
     int padding;
     // <--------------------------------------------> //
 
+    // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require
+    // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that
+    // limit, so the check is omitted.
+
+    // q4_1 x fp32
+    if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) {
+        ggml_cl_mul_mat_q4_1_f32_adreno(backend, src0, src1, dst);
+        return;
+    }
+
+    // q8_0 x fp32
+    if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 &&
+        enable_adreno_trans_weight(backend_ctx, src0)) {
+        ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst);
+        return;
+    }
+
     // q4_0 x fp32
     if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) {
         // TODO: remove duplicate definitions of image description + format -- move to top
@@ -7960,9 +9740,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 
     // GEMM using local memory
     // Current BK = 16, so ne00 % 16 == 0
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
-        src1t == GGML_TYPE_F32 &&
+    if (src1t == GGML_TYPE_F32 &&
         ne00 % 16 == 0 &&
         ne11 > 1) {
         switch(src0t) {
@@ -7974,10 +9752,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
                 int batch_stride_b = ne10*ne11;
                 int batch_stride_d = ne0*ne1;
 
-                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
-                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
-                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
-                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+                cl_mem mem_src0 = extra0->data_device;
+                cl_mem mem_src1 = extra1->data_device;
+
+                cl_ulong nb00_cont = nb00;
+                cl_ulong nb01_cont = nb01;
+                cl_ulong nb02_cont = nb02;
+                cl_ulong nb03_cont = nb03;
+
+                cl_ulong nb10_cont = nb10;
+                cl_ulong nb11_cont = nb11;
+                cl_ulong nb12_cont = nb12;
+                cl_ulong nb13_cont = nb13;
+
+                cl_ulong offset0_cont = offset0;
+                cl_ulong offset1_cont = offset1;
+
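+                // non-contiguous inputs are first packed into preallocated contiguous scratch buffers before launching the GEMM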
+                if (!ggml_is_contiguous(src0)) {
+                    backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
+                    ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
+                        nb00_cont, nb01_cont, nb02_cont, nb03_cont);
+                    mem_src0 = backend_ctx->prealloc_src0.buffer;
+                    offset0_cont = 0;
+                }
+
+                if (!ggml_is_contiguous(src1)) {
+                    backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
+                    ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
+                        nb10_cont, nb11_cont, nb12_cont, nb13_cont);
+                    mem_src1 = backend_ctx->prealloc_src1.buffer;
+                    offset1_cont = 0;
+                }
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &mem_src0));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0_cont));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &mem_src1));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1_cont));
                 CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
                 CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
                 CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
@@ -8009,8 +9819,82 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
                 int batch_stride_b = ne10*ne11;
                 int batch_stride_d = ne0*ne1;
 
-                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
-                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+                cl_mem mem_src0 = extra0->data_device;
+                cl_mem mem_src1 = extra1->data_device;
+
+                cl_ulong nb00_cont = nb00;
+                cl_ulong nb01_cont = nb01;
+                cl_ulong nb02_cont = nb02;
+                cl_ulong nb03_cont = nb03;
+
+                cl_ulong nb10_cont = nb10;
+                cl_ulong nb11_cont = nb11;
+                cl_ulong nb12_cont = nb12;
+                cl_ulong nb13_cont = nb13;
+
+                cl_ulong offset0_cont = offset0;
+                cl_ulong offset1_cont = offset1;
+
+                if (!ggml_is_contiguous(src0)) {
+                    backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0));
+                    ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer,
+                        nb00_cont, nb01_cont, nb02_cont, nb03_cont);
+                    mem_src0 = backend_ctx->prealloc_src0.buffer;
+                    offset0_cont = 0;
+                }
+
+                if (!ggml_is_contiguous(src1)) {
+                    backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1));
+                    ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer,
+                            nb10_cont, nb11_cont, nb12_cont, nb13_cont);
+                    mem_src1 = backend_ctx->prealloc_src1.buffer;
+                    offset1_cont = 0;
+                }
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &mem_src0));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0_cont));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &mem_src1));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1_cont));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r3));
+
+                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
+            case GGML_TYPE_Q4_0: {
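+                // small batches (ne11 < 32) and non-contiguous inputs fall through
+                // to the general mul_mv path further below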
+                if (ne11 < 32) {
+                    break;
+                }
+                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+                    break;
+                }
+
+                kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
                 CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
                 CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
                 CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
@@ -8036,10 +9920,57 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
                 backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
                 return;
             }
+            case GGML_TYPE_Q4_1: {
+                if (ne11 < 32) {
+                    break;
+                }
+                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+                    break;
+                }
+
+                kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_1->q));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_1->d));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q4_1->m));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &r3));
+
+                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
             case GGML_TYPE_Q8_0: {
                 if (ne11 < 32) {
                     break;
                 }
+                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+                    break;
+                }
+
                 kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm;
                 nth0 = 128; // calculated as (BM*BN)/(TM*TN)
 
@@ -8074,6 +10005,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
                 backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
                 return;
             }
+            case GGML_TYPE_Q6_K: {
+                if (ne11 < 32) {
+                    break;
+                }
+                if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
+                    break;
+                }
+
+                kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm;
+                nth0 = 128; // calculated as (BM*BN)/(TM*TN)
+
+                int batch_stride_a = ne00*ne01;
+                int batch_stride_b = ne10*ne11;
+                int batch_stride_d = ne0*ne1;
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q6_K->ql));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q6_K->qh));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q6_K->s));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra0_q6_K->d));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne11));
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10)); // stride_a
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10)); // stride_b
+                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne01)); // stride_d
+                CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &batch_stride_a));
+                CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &batch_stride_b));
+                CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &batch_stride_d));
+                CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &r3));
+
+                // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
+                size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
+                size_t local_work_size[] = {(size_t)nth0, 1, 1};
+
+                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+                return;
+            }
             default:
                 break;
         }
@@ -8328,7 +10303,71 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
             CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
 #endif // GGML_OPENCL_SOA_Q
             break;
-        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_1: {
+#ifdef GGML_OPENCL_SOA_Q
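+            // flat (SOA) variant: quants, deltas and mins are passed as three
+            // separate device buffers instead of interleaved block_q4_1 structs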
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 1;
+                ndst = 4;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+                ndst = 4;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat;
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_1->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_1->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q4_1->m));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &r3));
+#else
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 1;
+                ndst = 4;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+                ndst = 4;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            kernel = backend_ctx->kernel_mul_mv_q4_1_f32;
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+#endif // GGML_OPENCL_SOA_Q
+            break;
+        }
         case GGML_TYPE_Q8_0: {
 #ifdef GGML_OPENCL_SOA_Q
             kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat;
@@ -8409,17 +10448,87 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
         }
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q4_K: {
+            kernel = backend_ctx->kernel_mul_mv_q4_K_f32;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 1;
+                ndst = 4;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+                ndst = 4;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),     &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(int),        &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),     &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(int),        &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),     &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),        &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),        &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),        &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong),   &nb01));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong),   &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),   &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),        &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),   &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong),   &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong),   &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),        &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),        &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),        &r2));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),        &r3));
+            break;
+        }
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+#ifdef GGML_OPENCL_SOA_Q
+            kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 2;
+                ndst = 4;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 2;
+                ndst = 4;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q6_K->ql));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q6_K->qh));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra0_q6_K->s));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_mem),   &extra0_q6_K->d));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &r3));
+#else
             kernel = backend_ctx->kernel_mul_mv_q6_K_f32;
 
             if (backend_ctx->gpu_family == INTEL) {
-                nth0 = 2;
-                nth1 = 16;
+                nth0 = 16;
+                nth1 = 2;
+                ndst = 1;
             } else if (backend_ctx->gpu_family == ADRENO) {
-                nth0 = 2;
-                nth1 = 64;
+                nth0 = 64;
+                nth1 = 2;
+                ndst = 1;
             } else {
                 GGML_ASSERT(false && "TODO: Unknown GPU");
             }
@@ -8439,6 +10548,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
             CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
             CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
             CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+#endif // GGML_OPENCL_SOA_Q
             break;
         case GGML_TYPE_MXFP4: {
 #ifdef GGML_OPENCL_SOA_Q
@@ -8535,13 +10645,16 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
 
         backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
     } else if (src0t == GGML_TYPE_Q4_K) {
-        GGML_ASSERT(false && "not implemented");
+        size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
     } else if (src0t == GGML_TYPE_Q3_K) {
         GGML_ASSERT(false && "not implemented");
     } else if (src0t == GGML_TYPE_Q5_K) {
         GGML_ASSERT(false && "not implemented");
     } else if (src0t == GGML_TYPE_Q6_K) {
-        size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+        size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
         size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
 
         backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
@@ -8973,7 +11086,16 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
 
-    cl_kernel kernel = backend_ctx->kernel_scale;
+    cl_kernel kernel;
+
+    int n = ggml_nelements(dst);
+
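+    // use the float4 variant when the element count is divisible by 4, so each
+    // work-item scales four elements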
+    if (n % 4 == 0) {
+        kernel = backend_ctx->kernel_scale_f32_4;
+        n /= 4;
+    } else {
+        kernel = backend_ctx->kernel_scale_f32;
+    }
 
     CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -8982,8 +11104,6 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float),    &scale));
     CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float),    &bias));
 
-    int n = ggml_nelements(dst)/4;
-
     size_t global_work_size[] = {(size_t)n, 1, 1};
     size_t local_work_size[] = {64, 1, 1};
 
@@ -9005,28 +11125,13 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const
     // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst.
     UNUSED(dst);
 
-    const int ne00 = src0 ? src0->ne[0] : 0;
-    const int ne01 = src0 ? src0->ne[1] : 0;
-    const int ne02 = src0 ? src0->ne[2] : 0;
-    const int ne03 = src0 ? src0->ne[3] : 0;
+    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int,      ne1, src1, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb);
 
-    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
-    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
-    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
-    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
-
-    const int ne10 = src1 ? src1->ne[0] : 0;
-    const int ne11 = src1 ? src1->ne[1] : 0;
-    const int ne12 = src1 ? src1->ne[2] : 0;
-    const int ne13 = src1 ? src1->ne[3] : 0;
-
-    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
-    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
-    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
-    const cl_ulong nb13 = src1 ? src1->nb[3] : 0;
-
-    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
-    const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+    const enum ggml_type src0t = src0->type;
+    const enum ggml_type src1t = src1->type;
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -9063,6 +11168,15 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const
                     GGML_ASSERT(false && "not implemented");
             }
             break;
+        case GGML_TYPE_I32:
+            switch (src1t) {
+                case GGML_TYPE_I32:
+                    kernel = backend_ctx->kernel_cpy_i32_i32;
+                    break;
+                default:
+                    GGML_ASSERT(false && "not implemented");
+            }
+            break;
         default:
             GGML_ASSERT(false && "not implemented");
     }
@@ -9101,6 +11215,89 @@ static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const
     UNUSED(src1);
 }
 
+static void ggml_cl_set(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32) &&
+        src1->type == src0->type && dst->type == src0->type);
+
+    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int,      ne1, src1, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb);
+    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
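+    // GGML_OP_SET op_params: byte strides nb1/nb2/nb3 of the destination view,
+    // the byte offset into dst, and an inplace flag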
+    const cl_ulong pnb1    = ((const int32_t *)dst->op_params)[0];
+    const cl_ulong pnb2    = ((const int32_t *)dst->op_params)[1];
+    const cl_ulong pnb3    = ((const int32_t *)dst->op_params)[2];
+    const cl_ulong offs    = ((const int32_t *)dst->op_params)[3];
+    const bool     inplace = (bool)((const int32_t *)dst->op_params)[4];
+
+    cl_kernel kernel = nullptr;
+
+    // for the inplace case, dst is a view of src0 and is updated on top of it;
+    // for the non-inplace case, copy src0 to dst first
+    if (!inplace) {
+        ggml_cl_cpy(backend, src0, dst, nullptr);
+    }
+
+    // then copy src1 to dst with specified offset
+    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        kernel = backend_ctx->kernel_cpy_f32_f32;
+    } else if (src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+        kernel = backend_ctx->kernel_cpy_i32_i32;
+    } else {
+        GGML_ASSERT(false && "not implemented");
+    }
+
+    offsetd += offs;
+    cl_ulong nb = ggml_element_size(dst);
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &pnb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &pnb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &pnb3));
+
+    int max_local_size = backend_ctx->get_kernel_workgroup_size(kernel);
+
+    const int nth = MIN(max_local_size, ne00);
+
+    size_t global_work_size[] = {(size_t)ne11*nth, (size_t)ne12, (size_t)ne13};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -9163,6 +11360,49 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr
     }
 }
 
+static void ggml_cl_diag(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
+    GGML_TENSOR_LOCALS(int,      ne,  dst,  ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb,  dst,  nb);
+
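+    // GGML_OP_DIAG: launch one workgroup per output row (ne1*ne2*ne3 rows in total)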
+    cl_kernel kernel = backend_ctx->kernel_diag_f32;
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_int),   &ne0));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb3));
+
+    int nth = 64;
+
+    size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -9474,6 +11714,72 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_cl_solve_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel = backend_ctx->kernel_solve_tri_f32;
+    GGML_ASSERT(kernel != nullptr);
+
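+    // one work-item per right-hand-side column (k = src1->ne[0]) for each of the
+    // dst->ne[2] x dst->ne[3] batched triangular systems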
+    const int n = src0->ne[0];
+    const int k = src1->ne[0];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &n));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &k));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
+
+    size_t global_work_size[] = {(size_t)k, (size_t)dst->ne[2], (size_t)dst->ne[3]};
+    size_t local_work_size[] = {16, 4, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src1);
@@ -9611,7 +11917,6 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     GGML_UNUSED(src1);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(ggml_is_contiguous(src0));
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -9634,7 +11939,14 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     const cl_ulong nb2  = dst->nb[2];
     const cl_ulong nb3  = dst->nb[3];
 
-    cl_kernel kernel = backend_ctx->kernel_sum_rows_f32;
+    cl_kernel kernel;
+
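+    // rows whose length is a multiple of 4 can be reduced with the float4 variant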
+    const bool is_c4 = ne00 % 4 == 0;
+    if (is_c4) {
+        kernel = backend_ctx->kernel_sum_rows_f32_4;
+    } else {
+        kernel = backend_ctx->kernel_sum_rows_f32;
+    }
 
     CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offset0));
@@ -9651,12 +11963,124 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_ulong), &nb2));
     CL_CHECK(clSetKernelArg(kernel,  13, sizeof(cl_ulong), &nb3));
 
-    size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
+    size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)64, 1, 1};
 
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_UNUSED(src1);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    GGML_TENSOR_LOCALS(int,      ne0, src0, ne);
+    GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb);
+
+    cl_kernel kernel = backend_ctx->kernel_cumsum_blk;
+
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= max_workgroup_size) {
+        nth *= 2;
+    }
+
+    GGML_ASSERT(ne00 <= nth*nth);
+
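+    // two-level scan: each row of ne00 elements is split into net0 = ceil(ne00/nth)
+    // blocks; kernel_cumsum_blk scans each block and stores the per-block totals in
+    // a temporary buffer (e.g. ne00 = 3000 with nth = 1024 gives net0 = 3 blocks)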
+    const int net0 = CEIL_DIV(ne00, nth);
+    const int net1 = ne01;
+    const int net2 = ne02;
+    const int net3 = ne03;
+
+    const cl_ulong nbt0 = sizeof(float);
+    const cl_ulong nbt1 = net0*nbt0;
+    const cl_ulong nbt2 = net1*nbt1;
+    const cl_ulong nbt3 = net2*nbt2;
+
+    static ggml_cl_buffer tmp_buffer;
+    tmp_buffer.allocate(backend_ctx->context, net0*ne01*ne02*ne03*sizeof(float));
+
+    CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),   &tmp_buffer.buffer));
+    CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,   4, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,   5, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,   6, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,   7, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,   8, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel,   9, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel,  10, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  11, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel,  13, sizeof(int),      &net0));
+    CL_CHECK(clSetKernelArg(kernel,  14, sizeof(int),      &net1));
+    CL_CHECK(clSetKernelArg(kernel,  15, sizeof(int),      &net2));
+
+    size_t global_work_size[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = { (size_t)nth, 1, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+
+    if (ne00 > nth) {
+        // if a single workgroup cannot handle an entire row, each workgroup
+        // computes a partial sum of its block and stores it to dst; tmp_buffer holds
+        // the total of each workgroup's block. Cumsum this buffer, then add the
+        // scanned totals to the partial sums already in dst.
+        cl_ulong offsett = 0;
+        kernel = backend_ctx->kernel_cumsum_blk;
+        CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &tmp_buffer.buffer));
+        CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_ulong), &offsett));
+        CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_mem),   &tmp_buffer.buffer));
+        CL_CHECK(clSetKernelArg(kernel,   3, sizeof(cl_mem),   &tmp_buffer.buffer));
+        CL_CHECK(clSetKernelArg(kernel,   4, sizeof(cl_ulong), &offsett));
+        CL_CHECK(clSetKernelArg(kernel,   5, sizeof(int),      &net0));
+        CL_CHECK(clSetKernelArg(kernel,   6, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel,   7, sizeof(int),      &ne02));
+        CL_CHECK(clSetKernelArg(kernel,   8, sizeof(int),      &ne03));
+        CL_CHECK(clSetKernelArg(kernel,   9, sizeof(cl_ulong), &nbt0));
+        CL_CHECK(clSetKernelArg(kernel,  10, sizeof(cl_ulong), &nbt1));
+        CL_CHECK(clSetKernelArg(kernel,  11, sizeof(cl_ulong), &nbt2));
+        CL_CHECK(clSetKernelArg(kernel,  12, sizeof(cl_ulong), &nbt3));
+        CL_CHECK(clSetKernelArg(kernel,  13, sizeof(int),      &net0));
+        CL_CHECK(clSetKernelArg(kernel,  14, sizeof(int),      &net1));
+        CL_CHECK(clSetKernelArg(kernel,  15, sizeof(int),      &net2));
+
+        size_t global_work_size_1[] = { (size_t)net1*nth, (size_t)net2, (size_t)net3};
+        size_t local_work_size_1[] = { (size_t)nth, 1, 1};
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_1, local_work_size_1, dst);
+
+        kernel = backend_ctx->kernel_cumsum_add;
+        CL_CHECK(clSetKernelArg(kernel,   0, sizeof(cl_mem),   &tmp_buffer.buffer));
+        CL_CHECK(clSetKernelArg(kernel,   1, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,   2, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,   3, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,   4, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel,   5, sizeof(int),      &ne02));
+        CL_CHECK(clSetKernelArg(kernel,   6, sizeof(int),      &ne03));
+        CL_CHECK(clSetKernelArg(kernel,   7, sizeof(int),      &nbt0));
+        CL_CHECK(clSetKernelArg(kernel,   8, sizeof(int),      &nbt1));
+        CL_CHECK(clSetKernelArg(kernel,   9, sizeof(int),      &nbt2));
+        CL_CHECK(clSetKernelArg(kernel,  10, sizeof(int),      &nbt3));
+
+        size_t global_work_size_2[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03};
+        size_t local_work_size_2[] = { (size_t)nth, 1, 1};
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_2, local_work_size_2, dst);
+    }
+}
+
 static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -9802,6 +12226,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_cpy;
             break;
+        case GGML_OP_SET:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_set;
+            break;
         case GGML_OP_DUP:
         case GGML_OP_CONT:
             if (!any_on_device) {
@@ -9901,6 +12331,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
                     }
                     func = ggml_cl_tanh;
                     break;
+                case GGML_UNARY_OP_NEG:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cl_neg;
+                    break;
+                case GGML_UNARY_OP_EXP:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cl_exp;
+                    break;
                 case GGML_UNARY_OP_EXPM1:
                     if (!any_on_device) {
                         return false;
@@ -9922,6 +12364,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_glu;
             break;
+        case GGML_OP_TRI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_tri;
+            break;
         case GGML_OP_FILL:
             if (!any_on_device) {
                 return false;
@@ -9946,6 +12394,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_rms_norm;
             break;
+        case GGML_OP_L2_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_l2_norm;
+            break;
         case GGML_OP_GROUP_NORM:
             if (!any_on_device) {
                 return false;
@@ -10021,6 +12475,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_nop;
             break;
+        case GGML_OP_DIAG:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_diag;
+            break;
         case GGML_OP_DIAG_MASK_INF:
             if (!any_on_device) {
                 return false;
@@ -10039,6 +12499,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_rope;
             break;
+        case GGML_OP_SOLVE_TRI:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_solve_tri;
+            break;
         case GGML_OP_IM2COL:
             if (!any_on_device) {
                 return false;
@@ -10057,6 +12523,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_sum_rows;
             break;
+        case GGML_OP_CUMSUM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_cumsum;
+            break;
         case GGML_OP_FLASH_ATTN_EXT:
             if (!any_on_device) {
                 return false;
diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl
index 13275846..0c1b3d78 100644
--- a/ggml/src/ggml-opencl/kernels/concat.cl
+++ b/ggml/src/ggml-opencl/kernels/concat.cl
@@ -1,109 +1,51 @@
-kernel void kernel_concat_f32_contiguous(
-    global const char * p_src0, ulong off_src0,
-    global const char * p_src1, ulong off_src1,
-    global char * p_dst, ulong off_dst,
-    int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice
-    int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
-    int d_ne0,  int d_ne1,  int d_ne2,  // dst->ne[0..2] for the slice
-    int dim
+kernel void kernel_concat_f32(
+    global  const char * src0,
+    ulong                offset0,
+    global  const char * src1,
+    ulong                offset1,
+    global        char * dst,
+    ulong                offsetd,
+    int             ne00,
+    int             ne01,
+    int             ne02,
+    int             ne03,
+    ulong           nb00,
+    ulong           nb01,
+    ulong           nb02,
+    ulong           nb03,
+    ulong           nb10,
+    ulong           nb11,
+    ulong           nb12,
+    ulong           nb13,
+    int             ne0,
+    ulong           nb0,
+    ulong           nb1,
+    ulong           nb2,
+    ulong           nb3,
+    int             dim
 ) {
-    global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);
-    global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);
-    global float * dst        = (global float*)((global char*)p_dst + off_dst);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
 
-    int i0 = get_global_id(0); // Index along dst's 0th dimension
-    int i1 = get_global_id(1); // Index along dst's 1st dimension
-    int i2 = get_global_id(2); // Index along dst's 2nd dimension
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
 
-    if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {
-        return;
-    }
+    int o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
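+    // o[dim] is src0's extent along the concat dimension; indices below it read
+    // from src0, indices at or beyond it read from src1 with o subtracted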
 
-    ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;
-    ulong src_idx;
+    global const float * x;
 
-    if (dim == 0) {
-        if (i0 < d_ne00) { // Data from src0
-            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
-            dst[dst_idx] = src0[src_idx];
-        } else { // Data from src1
-            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00);
-            dst[dst_idx] = src1[src_idx];
-        }
-    } else if (dim == 1) {
-        if (i1 < d_ne01) { // Data from src0
-            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
-            dst[dst_idx] = src0[src_idx];
-        } else { // Data from src1
-            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0;
-            dst[dst_idx] = src1[src_idx];
-        }
-    } else if (dim == 2) {
-        if (i2 < d_ne02) { // Data from src0
-            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
-            dst[dst_idx] = src0[src_idx];
-        } else { // Data from src1
-
-            src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0;
-            dst[dst_idx] = src1[src_idx];
-        }
-    }
-}
-
-kernel void kernel_concat_f32_non_contiguous(
-    global const char * p_src0, ulong off_src0,
-    global const char * p_src1, ulong off_src1,
-    global char * p_dst, ulong off_dst,
-
-    long ne00, long ne01, long ne02, long ne03,
-    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
-
-    ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1
-
-    long d_ne0, long d_ne1, long d_ne2, long d_ne3,
-    ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,
-    int dim
-) {
-    global const char * src0_base = p_src0 + off_src0;
-    global const char * src1_base = p_src1 + off_src1;
-    global char * dst_base        = p_dst + off_dst;
-
-    long current_i1 = get_global_id(0); // Index for dst_dim_1
-    long current_i2 = get_global_id(1); // Index for dst_dim_2
-    long current_i3 = get_global_id(2); // Index for dst_dim_3
-
-    if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) {
-        return;
-    }
-
-    global const float * x_val_ptr;
-    global float * y_val_ptr;
-
-    for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) {
-        bool use_src0;
-        long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3;
-
-        if (dim == 0) {
-            use_src0 = (current_i0 < ne00);
-            if (!use_src0) { s_i0 = current_i0 - ne00; }
-        } else if (dim == 1) {
-            use_src0 = (current_i1 < ne01);
-            if (!use_src0) { s_i1 = current_i1 - ne01; }
-        } else if (dim == 2) {
-            use_src0 = (current_i2 < ne02);
-            if (!use_src0) { s_i2 = current_i2 - ne02; }
-        } else { // dim == 3
-            use_src0 = (current_i3 < ne03);
-            if (!use_src0) { s_i3 = current_i3 - ne03; }
-        }
-
-        if (use_src0) {
-            x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00);
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (global const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
         } else {
-            x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10);
+            x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
         }
 
-        y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0);
-        *y_val_ptr = *x_val_ptr;
+        global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
     }
 }
diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl
index 9369351a..820aa538 100644
--- a/ggml/src/ggml-opencl/kernels/cpy.cl
+++ b/ggml/src/ggml-opencl/kernels/cpy.cl
@@ -182,3 +182,48 @@ kernel void kernel_cpy_f32_f32(
         dst_data[i00] = src[0];
     }
 }
+
+kernel void kernel_cpy_i32_i32(
+        global int * src0,
+        ulong offset0,
+        global int * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = (global int*)((global char*)src0 + offset0);
+    dst = (global int*)((global char*)dst + offsetd);
+
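+    // same indexing scheme as kernel_cpy_f32_f32: map this row's linear position in
+    // src0 onto dst's (possibly different) shape, then copy element by element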
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global int * dst_data = (global int *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global const int * src = (global int *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/cumsum.cl b/ggml/src/ggml-opencl/kernels/cumsum.cl
new file mode 100644
index 00000000..edfb74b7
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/cumsum.cl
@@ -0,0 +1,139 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+// max workgroup size is usually 1024; this covers various subgroup sizes
+#define MAX_SUBGROUPS 128
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_cumsum_blk(
+        global char * src0,
+        ulong offset0,
+        global char * tmp,
+        global char * dst,
+        ulong offsetd,
+        int   ne00,
+        int   ne01,
+        int   ne02,
+        int   ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        uint net0,
+        uint net1,
+        uint net2
+) {
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int nth = get_local_size(0);
+    const int tid = get_local_id(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    const int ib = i1 / ne01;
+    const int i00 = ib * nth;
+    const int i01 = i1 % ne01;
+    const int i02 = i2;
+    const int i03 = i3;
+
+    global const float * src0_row = (global const float *)(src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    global       float * tmp_row  = (global float *)tmp + net0 * i01 + net0 * net1 * i02 + net0 * net1 * net2 * i03;
+    global       float * dst_row  = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
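+    // workgroup-level inclusive scan: each subgroup scans its chunk, the last lane
+    // publishes the subgroup total, subgroup 0 exclusive-scans those totals, and
+    // every thread then adds its subgroup's prefix to its own scanned value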
+    __local float partial[MAX_SUBGROUPS];
+
+    float v = 0.0f;
+    if (i00 + tid < ne00) {
+        v = src0_row[i00 + tid];
+    }
+
+    float s = sub_group_scan_inclusive_add(v);
+    if (sg_lid == sg_size - 1) {
+        partial[sg_id] = s;
+    }
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    // NB: the subgroup size must be at least the number of subgroups so that
+    // subgroup 0 can scan all partial sums; assuming a max workgroup size of 1024,
+    // the subgroup size should be >= 32
+    if (sg_id == 0) {
+        float x = 0.0f;
+        if (sg_lid < get_num_sub_groups()) {
+            x = partial[sg_lid];
+        }
+        float ex = sub_group_scan_exclusive_add(x);
+        if (sg_lid < get_num_sub_groups()) {
+            partial[sg_lid] = ex;
+        }
+    }
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    s += partial[sg_id];
+
+    if (i00 + tid < ne00) {
+        dst_row[i00 + tid] = s;
+    }
+    if (ne00 > nth && tid == nth - 1) {
+        tmp_row[ib] = s;
+    }
+}
+
+kernel void kernel_cumsum_add(
+        global char * tmp,
+        global char * dst,
+        ulong offsetd,
+        int   ne00,
+        int   ne01,
+        int   ne02,
+        int   ne03,
+        uint nbt0,
+        uint nbt1,
+        uint nbt2,
+        uint nbt3
+) {
+    dst  = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int nth = get_local_size(0);
+    const int tid = get_local_id(0);
+
+    const int ib = i1 / ne01;
+    if (ib == 0) {
+        return;
+    }
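+    // blocks after the first add the running total of all preceding blocks, which
+    // after the second scan pass is stored in tmp_row[ib - 1]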
+    const int i00 = ib * nth;
+    const int i01 = i1 % ne01;
+    const int i02 = i2;
+    const int i03 = i3;
+
+    global float * tmp_row  = (global float *)(tmp + nbt1 * i01 + nbt2 * i02 + nbt3 * i03);
+    global float * dst_row  = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    if (i00 + tid < ne00) {
+        dst_row[i00 + tid] += tmp_row[ib - 1];
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl
index 513a4d3e..78ef9c17 100644
--- a/ggml/src/ggml-opencl/kernels/cvt.cl
+++ b/ggml/src/ggml-opencl/kernels/cvt.cl
@@ -46,6 +46,25 @@ struct block_q4_0
     uint8_t qs[QK4_0 / 2];
 };
 
+//------------------------------------------------------------------------------
+// block_q4_1
+//------------------------------------------------------------------------------
+struct block_q4_1 {
+    half d; // delta
+    half m; // min
+    uchar qs[QK4_1 / 2]; // nibbles / quants
+};
+
+//------------------------------------------------------------------------------
+// block_q6_K
+//------------------------------------------------------------------------------
+struct block_q6_K {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    half d;                  // super-block scale
+};
+
 //------------------------------------------------------------------------------
 // kernel_convert_block_q4_0
 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
@@ -138,6 +157,100 @@ kernel void kernel_restore_block_q4_0_noshuffle(
     }
 }
 
+//------------------------------------------------------------------------------
+// kernel_convert_block_q4_1
+// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
+// This kernel does not deshuffle the bits.
+//------------------------------------------------------------------------------
+kernel void kernel_convert_block_q4_1(
+    global struct block_q4_1 * src0,
+    global uchar * dst_q,
+    global half  * dst_d,
+    global half  * dst_m
+) {
+    global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
+    global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
+    global half  * d = (global half *) dst_d + get_global_id(0);
+    global half  * m = (global half *) dst_m + get_global_id(0);
+
+    *d = b->d;
+    *m = b->m;
+
+    for (int i = 0; i < QK4_1/2; ++i) {
+        q[i] = b->qs[i];
+    }
+}
+
+kernel void kernel_restore_block_q4_1(
+    global uchar * src_q,
+    global half  * src_d,
+    global half  * src_m,
+    global struct block_q4_1 * dst
+) {
+    global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
+    global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
+    global half  * d = (global half *) src_d + get_global_id(0);
+    global half  * m = (global half *) src_m + get_global_id(0);
+
+    b->d = *d;
+    b->m = *m;
+    for (int i = 0; i < QK4_1/2; ++i) {
+        b->qs[i] = q[i];
+    }
+}
+
+kernel void kernel_convert_block_q4_1_noshuffle(
+    global struct block_q4_1 * src0,
+    global uchar * dst_q,
+    global half  * dst_d,
+    global half  * dst_m
+) {
+    global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
+    global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
+    global half  * d = (global half *) dst_d + get_global_id(0);
+    global half  * m = (global half *) dst_m + get_global_id(0);
+
+    *d = b->d;
+    *m = b->m;
+    for (int i = 0; i < QK4_1/4; ++i) {
+        uchar x0 = b->qs[2*i + 0];
+        uchar x1 = b->qs[2*i + 1];
+
+        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
+        q[i + QK4_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
+
+#ifdef ADRENO_GPU
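+        // Never-taken branch: keeping this printf appears to be required as a
+        // compiler workaround on Adreno (mirrors the q4_0 noshuffle kernel).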
+        if (get_global_id(0) == 65536*4096) {
+            printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
+        }
+#endif
+    }
+}
+
+kernel void kernel_restore_block_q4_1_noshuffle(
+    global uchar * src_q,
+    global half  * src_d,
+    global half  * src_m,
+    global struct block_q4_1 * dst,
+    uchar mask_0F,
+    uchar mask_F0
+) {
+    global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
+    global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
+    global half  * d = (global half *) src_d + get_global_id(0);
+    global half  * m = (global half *) src_m + get_global_id(0);
+
+    b->d = *d;
+    b->m = *m;
+    for (int i = 0; i < QK4_1/4; ++i) {
+        uchar x0 = q[i + 0      ] ;
+        uchar x1 = q[i + QK4_1/4];
+
+        b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
+        b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
+    }
+}
+
 //------------------------------------------------------------------------------
 // block_mxfp4
 //------------------------------------------------------------------------------
@@ -263,3 +376,94 @@ kernel void kernel_restore_block_q8_0(
         b->qs[i] = q[i];
     }
 }
+
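+// Restore block_q8_0 rows from a transposed layout: each thread rebuilds the
+// num_blk_per_row blocks of one row, gathering 4 packed 8-bit quants at a time
+// with a stride of 4 * ne01 along K and a scale stride of ne01.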
+kernel void kernel_restore_block_q8_0_trans(
+    global uchar * src_q,
+    global half  * src_d,
+    global block_q8_0 * dst,
+    uint ne00,
+    uint ne01
+){
+    uint num_blk_per_row = ne00 / QK8_0;
+
+    global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0) * num_blk_per_row;
+    global uchar      * q = (global uchar *) src_q + get_global_id(0) * 4; // 4 packed 8-bit quants per thread
+    global half       * d = (global half *) src_d + get_global_id(0);
+
+    for (uint blk = 0; blk < num_blk_per_row; blk++) {
+        b->d = *d;
+
+        for (uint i = 0; i < QK8_0; i+=4) {
+            b->qs[i]   = q[0];
+            b->qs[i+1] = q[1];
+            b->qs[i+2] = q[2];
+            b->qs[i+3] = q[3];
+
+            q += 4 * ne01; // M stride
+        }
+
+        d += ne01;
+
+        b++;
+    }
+}
+
+//------------------------------------------------------------------------------
+// kernel_convert_block_q6_K
+// Convert the block_q6_K format to 4 separate arrays (AOS -> SOA).
+// This kernel does not deshuffle the bits.
+// Each thread processes a super block.
+//------------------------------------------------------------------------------
+kernel void kernel_convert_block_q6_K(
+    global struct block_q6_K * src0,
+    global uchar * dst_ql,
+    global uchar * dst_qh,
+    global char  * dst_s,
+    global half  * dst_d
+) {
+    global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0);
+    global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
+    global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
+    global char  * s  = (global char  *) dst_s  + QK_K/16*get_global_id(0);
+    global half  * d  = (global half  *) dst_d  + get_global_id(0);
+
+    *d = b->d;
+
+    for (int i = 0; i < QK_K/2; ++i) {
+        ql[i] = b->ql[i];
+    }
+    for (int i = 0; i < QK_K/4; ++i) {
+        qh[i] = b->qh[i];
+    }
+    for (int i = 0; i < QK_K/16; ++i) {
+        s[i] = b->scales[i];
+    }
+}
+
+// Restore block_q6_K from flattened arrays.
+// Each thread processes a super block.
+kernel void kernel_restore_block_q6_K(
+    global uchar * dst_ql,
+    global uchar * dst_qh,
+    global char  * dst_s,
+    global half  * dst_d,
+    global struct block_q6_K * dst
+) {
+    global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0);
+    global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0);
+    global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0);
+    global char  * s  = (global char  *) dst_s  + QK_K/16*get_global_id(0);
+    global half  * d  = (global half  *) dst_d  + get_global_id(0);
+
+    b->d = *d;
+
+    for (int i = 0; i < QK_K/2; ++i) {
+        b->ql[i] = ql[i];
+    }
+    for (int i = 0; i < QK_K/4; ++i) {
+        b->qh[i] = qh[i];
+    }
+    for (int i = 0; i < QK_K/16; ++i) {
+        b->scales[i] = s[i];
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/diag.cl b/ggml/src/ggml-opencl/kernels/diag.cl
new file mode 100644
index 00000000..884efa08
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/diag.cl
@@ -0,0 +1,27 @@
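+// Build a diagonal matrix from each source row:
+// dst[i1][i0] = src0[i0] when i0 == i1, 0 otherwise (per i2/i3 slice).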
+kernel void kernel_diag_f32(
+    global const char * src0,
+    ulong               offset0,
+    global       char * dst,
+    ulong               offsetd,
+    ulong               nb01,
+    ulong               nb02,
+    ulong               nb03,
+    int                 ne0,
+    ulong               nb0,
+    ulong               nb2,
+    ulong               nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    global const float * src0_ptr = (global const float *)(src0 +           i2*nb02 + i3*nb03);
+    global       float * dst_ptr  = (global       float *)(dst  + i1*nb01 + i2*nb2  + i3*nb3);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/exp.cl b/ggml/src/ggml-opencl/kernels/exp.cl
new file mode 100644
index 00000000..a2458b65
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/exp.cl
@@ -0,0 +1,125 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+kernel void kernel_exp_f32(
+        global const float * src0,
+        ulong                offset0,
+        global       float * dst,
+        ulong                offsetd,
+        int                  n
+) {
+    if (get_global_id(0) >= n) {
+        return;
+    }
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = exp(src0[get_global_id(0)]);
+}
+
+kernel void kernel_exp_f32_4(
+        global const float4 * src0,
+        ulong                 offset0,
+        global       float4 * dst,
+        ulong                 offsetd,
+        int                   n
+) {
+    if (get_global_id(0) >= n) {
+        return;
+    }
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst  = (global float4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = exp(src0[get_global_id(0)]);
+}
+
+kernel void kernel_exp_f16(
+        global const half * src0,
+        ulong               offset0,
+        global       half * dst,
+        ulong               offsetd,
+        int                 n
+) {
+    if (get_global_id(0) >= n) {
+        return;
+    }
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst  = (global half*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = exp(src0[get_global_id(0)]);
+}
+
+kernel void kernel_exp_f16_4(
+        global const half4 * src0,
+        ulong                offset0,
+        global       half4 * dst,
+        ulong                offsetd,
+        int                  n
+) {
+    if (get_global_id(0) >= n) {
+        return;
+    }
+    src0 = (global half4*)((global char*)src0 + offset0);
+    dst  = (global half4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = exp(src0[get_global_id(0)]);
+}
+
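+// Non-contiguous variants: one workgroup per row, with explicit byte strides
+// (nb*) for both source and destination so arbitrary tensor layouts are handled.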
+kernel void kernel_exp_f32_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+        *y = exp(*x);
+    }
+}
+
+kernel void kernel_exp_f16_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       half * y = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+        *y = exp(*x);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/expm1.cl b/ggml/src/ggml-opencl/kernels/expm1.cl
index 126298a2..05442ac2 100644
--- a/ggml/src/ggml-opencl/kernels/expm1.cl
+++ b/ggml/src/ggml-opencl/kernels/expm1.cl
@@ -3,80 +3,111 @@
 //------------------------------------------------------------------------------
 // expm1
 //------------------------------------------------------------------------------
-kernel void kernel_expm1_f32_nd(
-        global void * p_src0_base,
-        ulong off_src0_abs,
-        global void * p_dst_base,
-        ulong off_dst_abs,
-        int ne00,
-        int ne01,
-        int ne02,
-        int ne03,
+
+kernel void kernel_expm1_f32(
+        global const float * src0,
+        ulong                offset0,
+        global       float * dst,
+        ulong                offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
+}
+
+kernel void kernel_expm1_f32_4(
+        global const float4 * src0,
+        ulong                 offset0,
+        global       float4 * dst,
+        ulong                 offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst  = (global float4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f;
+}
+
+kernel void kernel_expm1_f16(
+        global const half * src0,
+        ulong               offset0,
+        global       half * dst,
+        ulong               offsetd
+) {
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst  = (global half*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
+}
+
+kernel void kernel_expm1_f16_4(
+        global const half4 * src0,
+        ulong                offset0,
+        global       half4 * dst,
+        ulong                offsetd
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    dst  = (global half4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h;
+}
+
+kernel void kernel_expm1_f32_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
         ulong nb00,
         ulong nb01,
         ulong nb02,
         ulong nb03,
-        int ne10,
-        int ne11,
-        int ne12,
-        int ne13,
-        ulong nb10,
-        ulong nb11,
-        ulong nb12,
-        ulong nb13
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
 ) {
-    int i0 = get_global_id(0);
-    int i1 = get_global_id(1);
-    int i2 = get_global_id(2);
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
 
-    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
-        for (int i3 = 0; i3 < ne13; ++i3) {
-            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
-            global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
 
-            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
-            global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-            *dst_val_ptr = exp(*src_val_ptr) - 1;
-        }
+        *y = exp(*x) - 1.0f;
     }
 }
 
-kernel void kernel_expm1_f16_nd(
-        global void * p_src0_base,
-        ulong off_src0_abs,
-        global void * p_dst_base,
-        ulong off_dst_abs,
-        int ne00,
-        int ne01,
-        int ne02,
-        int ne03,
+kernel void kernel_expm1_f16_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
         ulong nb00,
         ulong nb01,
         ulong nb02,
         ulong nb03,
-        int ne10,
-        int ne11,
-        int ne12,
-        int ne13,
-        ulong nb10,
-        ulong nb11,
-        ulong nb12,
-        ulong nb13
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
 ) {
-    int i0 = get_global_id(0);
-    int i1 = get_global_id(1);
-    int i2 = get_global_id(2);
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
 
-    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
-        for (int i3 = 0; i3 < ne13; ++i3) {
-            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
-            global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
 
-            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
-            global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       half * y = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-            *dst_val_ptr = exp(*src_val_ptr) - 1;
-        }
+        *y = exp(*x) - 1.0f;
     }
 }
diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl
new file mode 100644
index 00000000..5c4d5cc8
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl
@@ -0,0 +1,132 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_qcom_reqd_sub_group_size
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_128
+#endif
+
+kernel void kernel_gemm_noshuffle_q4_1_f32(
+    global const ushort * src0_q,
+    global const half  * src0_d,
+    global const half  * src0_m,
+    read_only image1d_buffer_t src1,
+    global float * dst,
+    ulong offsetd,
+    int m,
+    int n,
+    int k,
+    int n_no_padding
+) {
+    dst = (global float *)((global char *)dst + offsetd);
+
+    int m_4 = m >> 2;
+    int n_4 = n >> 2;
+
+    int gy = get_global_id(0);
+    int gx = get_global_id(1);
+    int gx_2 = gx << 2;
+
+    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
+    half8 B;
+    half4 dequantized_weights;
+
+    global const ushort* weight_ptr = src0_q + gx_2;
+    global const half*   scale_ptr  = src0_d + gx_2;
+    global const half*   min_ptr    = src0_m + gx_2;
+
+    for(int i = 0; i < k; i += 4) {
+        B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);
+
+        ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m));
+
+        half4 scale = vload4(0, scale_ptr + (i/32)*(m));
+        half4 minv  = vload4(0,   min_ptr + (i/32)*(m));
+
+        // j=0
+        dequantized_weights.s0 = (bits4.s0 & (0x000F)) * scale.s0 + minv.s0;
+        dequantized_weights.s1 = (bits4.s1 & (0x000F)) * scale.s1 + minv.s1;
+        dequantized_weights.s2 = (bits4.s2 & (0x000F)) * scale.s2 + minv.s2;
+        dequantized_weights.s3 = (bits4.s3 & (0x000F)) * scale.s3 + minv.s3;
+        c0 += B * dequantized_weights.s0;
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=1
+        B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
+        dequantized_weights.s0 = ((bits4.s0 & (0x00F0)) >> 4) * scale.s0 + minv.s0;
+        dequantized_weights.s1 = ((bits4.s1 & (0x00F0)) >> 4) * scale.s1 + minv.s1;
+        dequantized_weights.s2 = ((bits4.s2 & (0x00F0)) >> 4) * scale.s2 + minv.s2;
+        dequantized_weights.s3 = ((bits4.s3 & (0x00F0)) >> 4) * scale.s3 + minv.s3;
+        c0 += B * dequantized_weights.s0;
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=2
+        B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
+        dequantized_weights.s0 = ((bits4.s0 & (0x0F00)) >> 8) * scale.s0 + minv.s0;
+        dequantized_weights.s1 = ((bits4.s1 & (0x0F00)) >> 8) * scale.s1 + minv.s1;
+        dequantized_weights.s2 = ((bits4.s2 & (0x0F00)) >> 8) * scale.s2 + minv.s2;
+        dequantized_weights.s3 = ((bits4.s3 & (0x0F00)) >> 8) * scale.s3 + minv.s3;
+        c0 += B * dequantized_weights.s0;
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=3
+        B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
+        dequantized_weights.s0 = ((bits4.s0 & (0xF000)) >> 12) * scale.s0 + minv.s0;
+        dequantized_weights.s1 = ((bits4.s1 & (0xF000)) >> 12) * scale.s1 + minv.s1;
+        dequantized_weights.s2 = ((bits4.s2 & (0xF000)) >> 12) * scale.s2 + minv.s2;
+        dequantized_weights.s3 = ((bits4.s3 & (0xF000)) >> 12) * scale.s3 + minv.s3;
+        c0 += B * dequantized_weights.s0;
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+    }
+
+    int idx = (gy<<3)*m + (gx<<2);
+
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl
new file mode 100644
index 00000000..9703b693
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl
@@ -0,0 +1,195 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+
+#ifdef cl_qcom_reqd_sub_group_size
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#endif
+
+#define QK8_0 32
+#define N_SIMDGROUP 4
+
+#define dequantizeBlockAccum_ns_sgbroadcast_1(total_sums, bits8, scale, y) \
+    float shared_y; \
+    char elem; \
+                                             \
+    shared_y = sub_group_broadcast(y.s0, 0); \
+    elem = (char)(bits8.s0 & 0x000000FF); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 0); \
+    elem = (char)((bits8.s0 & 0x0000FF00) >> 8); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 0); \
+    elem = (char)((bits8.s0 & 0x00FF0000) >> 16); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 0); \
+    elem = (char)((bits8.s0 & 0xFF000000) >> 24); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+                                             \
+    shared_y = sub_group_broadcast(y.s4, 0); \
+    elem = (char)(bits8.s1 & 0x000000FF); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 0); \
+    elem = (char)((bits8.s1 & 0x0000FF00) >> 8); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 0); \
+    elem = (char)((bits8.s1 & 0x00FF0000) >> 16); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 0); \
+    elem = (char)((bits8.s1 & 0xFF000000) >> 24); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+                                             \
+    shared_y = sub_group_broadcast(y.s0, 1); \
+    elem = (char)(bits8.s2 & 0x000000FF); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 1); \
+    elem = (char)((bits8.s2 & 0x0000FF00) >> 8); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 1); \
+    elem = (char)((bits8.s2 & 0x00FF0000) >> 16); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 1); \
+    elem = (char)((bits8.s2 & 0xFF000000) >> 24); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+                                             \
+    shared_y = sub_group_broadcast(y.s4, 1); \
+    elem = (char)(bits8.s3 & 0x000000FF); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 1); \
+    elem = (char)((bits8.s3 & 0x0000FF00) >> 8); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 1); \
+    elem = (char)((bits8.s3 & 0x00FF0000) >> 16); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 1); \
+    elem = (char)((bits8.s3 & 0xFF000000) >> 24); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+                                             \
+    shared_y = sub_group_broadcast(y.s0, 2); \
+    elem = (char)(bits8.s4 & 0x000000FF); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 2); \
+    elem = (char)((bits8.s4 & 0x0000FF00) >> 8); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 2); \
+    elem = (char)((bits8.s4 & 0x00FF0000) >> 16); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 2); \
+    elem = (char)((bits8.s4 & 0xFF000000) >> 24); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+                                             \
+    shared_y = sub_group_broadcast(y.s4, 2); \
+    elem = (char)(bits8.s5 & 0x000000FF); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 2); \
+    elem = (char)((bits8.s5 & 0x0000FF00) >> 8); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 2); \
+    elem = (char)((bits8.s5 & 0x00FF0000) >> 16); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 2); \
+    elem = (char)((bits8.s5 & 0xFF000000) >> 24); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+                                             \
+    shared_y = sub_group_broadcast(y.s0, 3); \
+    elem = (char)(bits8.s6 & 0x000000FF); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 3); \
+    elem = (char)((bits8.s6 & 0x0000FF00) >> 8); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 3); \
+    elem = (char)((bits8.s6 & 0x00FF0000) >> 16); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 3); \
+    elem = (char)((bits8.s6 & 0xFF000000) >> 24); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+                                             \
+    shared_y = sub_group_broadcast(y.s4, 3); \
+    elem = (char)(bits8.s7 & 0x000000FF); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 3); \
+    elem = (char)((bits8.s7 & 0x0000FF00) >> 8); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 3); \
+    elem = (char)((bits8.s7 & 0x00FF0000) >> 16); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 3); \
+    elem = (char)((bits8.s7 & 0xFF000000) >> 24); \
+    total_sums += convert_int(elem) * scale * shared_y; \
+
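+// q8_0 GEMV: each fiber accumulates the dot product of one row of the quantized
+// matrix with the input vector. The K dimension is split across N_SIMDGROUP
+// subgroups; their partial sums are combined in local memory by wave 0.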
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+__kernel void kernel_gemv_noshuffle_q8_0_f32(
+        __read_only  image1d_buffer_t src0_q,  // quantized A
+        global half  * src0_d,  // A scales
+        __read_only  image1d_buffer_t src1,    // B
+        ulong offset1,            // offset to B (0)
+        global float * dst,     // C
+        ulong offsetd,            // offset to C
+        int ne00,               // K
+        int ne01,               // M
+        int ne02,               // 1
+        int ne10,               // K
+        int ne12,               // 1
+        int ne0,                // M
+        int ne1,                // N
+        int r2,                 // 1
+        int r3)
+{
+    uint groupId = get_local_id(1);
+    uint gid     = get_global_id(0);
+    ushort slid    = get_sub_group_local_id();
+
+    uint K = ne00;
+    uint M = ne01;
+
+    uint LINE_STRIDE_A = M;
+    uint BLOCK_STRIDE_A = 8 * M;   // 32 / 4 = 8
+
+    __private uint8     regA;
+    __private half      regS;
+    __private float8    regB;
+
+    __private float totalSum = (float)(0.0f);
+
+    // loop along K in block granularity, skip 4 blocks every iter
+    #pragma unroll 1 /* tell compiler not to unroll */
+    for (uint k = groupId; k < (K / QK8_0); k += N_SIMDGROUP) {
+        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads the scale of one row
+        // first 4 fibers in each wave load 8 B values to its private scope
+        if (slid < 4) {
+            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
+            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
+        }
+
+        // load weights for one block in consecutive rows
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
+        regA.s4 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
+        regA.s5 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
+        regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
+        regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
+
+        dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
+    }
+
+    // reduction in local memory, assumes #wave=4
+    __local float reduceLM[SIMDGROUP_WIDTH * 3];
+    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
+    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
+    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
+
+    // 1 output per fiber in wave 0
+    if (groupId == 0) {
+        dst = (global float*)((global char*)dst + offsetd);
+        dst[gid] = totalSum;
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl
new file mode 100644
index 00000000..fdc14724
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl
@@ -0,0 +1,283 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+
+#ifdef cl_qcom_reqd_sub_group_size
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
+#endif
+
+#define QK4_0 32
+#define NSUBGROUPS 4
+#define SUBGROUP_SIZE 64
+
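+// Dequantize-and-accumulate helpers: the "_1" variants broadcast one B value at
+// a time, the "_8" variants broadcast a full float8; "_hi"/"_lo" consume the B
+// values held by subgroup lanes 0-1 and 2-3 respectively. Each helper handles
+// two q4_1 rows per fiber (x = d * q + m).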
+#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \
+    float shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 0); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 0); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 0); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 1); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 1); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 1); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \
+    shared_y = sub_group_broadcast(y.s0, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 2); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 2); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 2); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 3); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 3); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 3); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \
+    float8 shared_y; \
+    shared_y = sub_group_broadcast(y, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \
+    shared_y = sub_group_broadcast(y, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F)         * scale.s0 + minv.s0) * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4)  * scale.s0 + minv.s0) * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8)  * scale.s0 + minv.s0) * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F)         * scale.s1 + minv.s1) * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4)  * scale.s1 + minv.s1) * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8)  * scale.s1 + minv.s1) * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_gemv_noshuffle_q4_1_f32(
+        read_only  image1d_buffer_t src0_q,
+        global half2  * src0_d,
+        global half2  * src0_m,
+        read_only  image1d_buffer_t src1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01)
+{
+    uint groupId = get_local_id(1);
+    uint gid     = get_global_id(0);
+    ushort slid    = get_sub_group_local_id();
+
+    uint K = ne00;
+    uint M = ne01;
+
+    uint LINE_STRIDE_A = M / 2;
+    uint BLOCK_STRIDE_A = NSUBGROUPS * M;
+
+    private uint4     regA;
+    private half2     regS;
+    private half2     regM;
+    private float8    regB;
+
+    private float2 totalSum = (float2)(0.0f);
+
+    // loop along K in block granularity, skip 4 blocks every iter
+    for (uint k = groupId; k < (K / QK4_0); k += NSUBGROUPS) {
+        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
+        regM = src0_m[gid + k * LINE_STRIDE_A]; // each fiber loads min of two rows
+        // first 4 fibers in each wave load 8 B values to its private scope
+        if (slid < 4) {
+            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
+            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
+        }
+
+        // load half weights for two blocks in consecutive rows
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+    }
+
+    // reduction in local memory, assumes #wave=4
+    local float2 reduceLM[SUBGROUP_SIZE * 3];
+    if (groupId == 1) {
+        reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum;
+    }
+    if (groupId == 2) {
+        reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum;
+    }
+    if (groupId == 3) {
+        reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if (groupId == 0) {
+        totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid];
+    }
+    if (groupId == 0) {
+        totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid];
+    }
+    if (groupId == 0) {
+        totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid];
+    }
+
+    // 2 outputs per fiber in wave 0
+    if (groupId == 0) {
+        dst = (global float*)((global char*)dst + offsetd);
+        vstore2(totalSum, 0, &(dst[gid * 2]));
+    }
+
+}
diff --git a/ggml/src/ggml-opencl/kernels/l2_norm.cl b/ggml/src/ggml-opencl/kernels/l2_norm.cl
new file mode 100644
index 00000000..fb95355a
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/l2_norm.cl
@@ -0,0 +1,71 @@
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
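+// L2 normalization along ne00: y = x / max(||x||_2, eps), computed with a
+// subgroup reduction of the squared sums followed by a local-memory reduction.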
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_l2_norm_f32(
+        global void * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        float eps,
+        local float * sum
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    global float * y = (global float *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    float sumf = 0;
+
+    // parallel sum
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        sumf += x[i00] * x[i00];
+    }
+    sumf = sub_group_reduce_add(sumf);
+
+    if (get_sub_group_local_id() == 0) {
+        sum[get_sub_group_id()] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    // reduce the per-subgroup partial sums in local memory
+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+       if (get_local_id(0) < i) {
+           sum[get_local_id(0)] += sum[get_local_id(0) + i];
+       }
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    const float scale = 1.0f/max(sqrt(sum[0]), eps);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        y[i00] = x[i00] * scale;
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mean.cl b/ggml/src/ggml-opencl/kernels/mean.cl
index 5c3e8bcd..7c7e0a58 100644
--- a/ggml/src/ggml-opencl/kernels/mean.cl
+++ b/ggml/src/ggml-opencl/kernels/mean.cl
@@ -1,8 +1,13 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
 
+// Most devices have max workgroup size of 1024, so this is enough for subgroup
+// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroup sizes.
+#define MAX_SUBGROUPS 64
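+// Row mean via a two-level reduction: each subgroup reduces its partial sum,
+// the per-subgroup results are combined through local memory, and lane 0 of
+// the workgroup writes sum / ne00.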
 kernel void kernel_mean_f32(
-    global float *  src0,
+    global char *  src0,
     ulong           offset0,
-    global float *  dst,
+    global char *  dst,
     ulong           offsetd,
     int             ne00,
     int             ne01,
@@ -15,25 +20,121 @@ kernel void kernel_mean_f32(
     ulong           nb2,
     ulong           nb3
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    dst  = (global float *)((global char *)dst  + offsetd);
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
 
-    int i3 = get_global_id(2);
-    int i2 = get_global_id(1);
-    int i1 = get_global_id(0);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
 
     if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
         return;
     }
 
-    global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    global float * dst_row = (global float *) ((global char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
-
-    float row_sum = 0;
-
-    for (int i0 = 0; i0 < ne00; i0++) {
-        row_sum += src_row[i0];
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
     }
 
-    dst_row[0] = row_sum / ne00;
+    global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float sumf = 0.0f;
+
+    for (int i0 = lid; i0 < ne00; i0 += lsize) {
+        sumf += src_row[i0];
+    }
+
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf / ne00;
+    }
+}
+
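+// Vectorized variant: accumulates float4 loads and reduces them with dot();
+// presumably dispatched by the host only when ne00 is a multiple of 4.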
+kernel void kernel_mean_f32_4(
+    global char *  src0,
+    ulong           offset0,
+    global char *  dst,
+    ulong           offsetd,
+    int             ne00,
+    int             ne01,
+    int             ne02,
+    int             ne03,
+    ulong           nb01,
+    ulong           nb02,
+    ulong           nb03,
+    ulong           nb1,
+    ulong           nb2,
+    ulong           nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
+    }
+
+    global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float  * dst_row = (global float  *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float4 sum_vec = (float4)0.0f;
+
+    for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
+        sum_vec += src_row[i0];
+    }
+
+    float sumf = dot(sum_vec, (float4)(1.0f));
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf / ne00;
+    }
 }
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl
new file mode 100644
index 00000000..4100e308
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl
@@ -0,0 +1,163 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 8
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
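+// Tiling scheme: each workgroup computes a BM x BN tile of the output, staging
+// dequantized A and B tiles in local memory in BK-wide slices along K; each
+// thread accumulates a TM x TN register sub-tile.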
+kernel void kernel_mul_mm_q4_0_f32_l4_lm(
+    global uchar4 * src0_q,
+    global half   * src0_d,
+    global float4 * src1,
+    ulong offset1,
+    global float  * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst  = (global float *)((global char*)dst  + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            if (ir*BM + loadc_a + l < ne01) {
+                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+                int ib  = idx / 4;
+                int iqs = idx % 4;
+
+                float d = (float)src0_d[ib];
+                global uchar4 * qs = src0_q + ib*4 + iqs;
+                uchar4 q = *qs;
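+                // q4_0 block: 32 weights in 16 bytes; low nibbles hold weights 0..15 and
+                // high nibbles weights 16..31, hence the +16 row offset used for v2 below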
+                float4 v1 = (convert_float4((uchar4)((q.s0   )&0x0F, (q.s1   )&0x0F, (q.s2   )&0x0F, (q.s3   )&0x0F)) - 8.0f)*d;
+                float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
+
+                buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = v1.s0;
+                buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = v1.s1;
+                buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = v1.s2;
+                buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = v1.s3;
+                buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
+                buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
+                buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
+                buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
+            } else {
+                buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
+            }
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            if (ic*BN + loadc_b + l < ne11) {
+                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl
new file mode 100644
index 00000000..d0d2f083
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl
@@ -0,0 +1,165 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 8
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_q4_1_f32_l4_lm(
+    global uchar4 * src0_q,
+    global half   * src0_d,
+    global half   * src0_m,
+    global float4 * src1,
+    ulong offset1,
+    global float  * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst  = (global float *)((global char*)dst  + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            if (ir*BM + loadc_a + l < ne01) {
+                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+                int ib  = idx / 4;
+                int iqs = idx % 4;
+
+                float d = (float)src0_d[ib];
+                float m = (float)src0_m[ib];
+                global uchar4 * qs = src0_q + ib*4 + iqs;
+                uchar4 q = *qs;
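+                // q4_1 dequant: w = d*q + m (a min offset instead of the -8 centering of
+                // q4_0); the nibble layout is the same as q4_0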
+                float4 v1 = (convert_float4((uchar4)((q.s0   )&0x0F, (q.s1   )&0x0F, (q.s2   )&0x0F, (q.s3   )&0x0F)))*d + m;
+                float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;
+
+                buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = v1.s0;
+                buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = v1.s1;
+                buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = v1.s2;
+                buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = v1.s3;
+                buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
+                buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
+                buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
+                buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
+            } else {
+                buf_a[(loadr_a * 4 +  0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 +  1) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 +  2) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 +  3) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
+            }
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            if (ic*BN + loadc_b + l < ne11) {
+                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl
new file mode 100644
index 00000000..3602c92f
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl
@@ -0,0 +1,158 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#define LOAD_VEC_A 2
+#define LOAD_VEC_B 4
+
+#define BM 64
+#define BN 64
+#define BK 32
+#define TM 4
+#define TN 8
+
+kernel void kernel_mul_mm_q6_k_f32_l4_lm(
+    global uchar * src0_ql,
+    global uchar * src0_qh,
+    global char  * src0_s,
+    global half  * src0_d,
+    global float4 * src1,
+    ulong offset1,
+    global float  * dst,
+    ulong offsetd,
+
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne11,
+    int ne12,
+
+    int stride_a,
+    int stride_b,
+    int stride_d,
+
+    int batch_stride_a,
+    int batch_stride_b,
+    int batch_stride_d,
+
+    int r2,
+    int r3
+) {
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst  = (global float *)((global char*)dst  + offsetd);
+
+    local float buf_a[BM * BK];
+    local float buf_b[BN * BK];
+
+    const int batch_idx = get_global_id(2);
+
+    const int i13 = batch_idx / ne12;
+    const int i12 = batch_idx % ne12;
+
+    const int i03 = i13 / r3;
+    const int i02 = i12 / r2;
+
+    const int batch_idx_a = i03 * ne02 + i02;
+
+    const int ir = get_group_id(0);
+    const int ic = get_group_id(1);
+
+    const int tid = get_local_id(0);
+    const int th_r  = tid % (BM / TM);
+    const int th_c  = tid / (BM / TM);
+
+    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
+    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
+    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
+    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
+
+    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
+    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
+
+    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
+    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
+
+    float sums[TM * TN];
+    float cache_a[TM];
+    float cache_b[TN];
+
+    for (int i = 0; i < TM * TN; i++) {
+        sums[i] = 0.0f;
+    }
+
+    for (int block = 0; block < ne00; block += BK) {
+        for (int l = 0; l < BM; l += loadstride_a) {
+            if (ir*BM + loadc_a + l < ne01) {
+                int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+
+                int ib = idx / 128;                  // 2 values per idx
+                int iqs = idx % 128;                 // 0..127
+
+                int n = iqs / 64;                    // 0,1
+                int b = (iqs % 64) / 32;             // 0,1
+                int is_b = (iqs % 16) / 8;           // 0,1
+                int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+                int is = 8 * n + qhshift + is_b;     // 0..15
+                int qsi = n * 64 + (iqs % 32) * 2;   // 0,2,4..126
+                int qhi = n * 32 + (iqs % 16) * 2;   // 0,2,4..62
+
+                float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
+
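+                // q6_K: each 6-bit quant combines 4 low bits from ql with 2 bits from qh,
+                // minus the 32 offset, scaled by the super-block d and the 8-bit block scale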
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
+            } else {
+                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
+                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
+            }
+        }
+
+        for (int l = 0; l < BN; l += loadstride_b) {
+            if (ic*BN + loadc_b + l < ne11) {
+                int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
+            } else {
+                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
+                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
+            }
+        }
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        pos_a += BK / LOAD_VEC_A;
+        pos_b += BK / LOAD_VEC_B;
+
+        for (int i = 0; i < BK; i++) {
+            for (int j = 0; j < TM; j++) {
+                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
+            }
+
+            for (int j = 0; j < TN; j++) {
+                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
+            }
+
+            for (int cc = 0; cc < TN; cc++) {
+                for (int cr = 0; cr < TM; cr++) {
+                    const int sums_idx = cc*TM + cr;
+                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
+                }
+            }
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+
+    const int dr = ir * BM + th_r * TM;
+    const int dc = ic * BN + th_c * TN;
+
+    const int offsets = batch_idx * batch_stride_d;
+
+    for (int cc = 0; cc < TN; cc++) {
+        for (int cr = 0; cr < TM; cr++) {
+            if (dr + cr < ne01 && dc + cc < ne11) {
+                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl
new file mode 100644
index 00000000..51ce2121
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl
@@ -0,0 +1,129 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+#ifdef cl_qcom_reqd_sub_group_size
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_128
+#endif
+
+kernel void kernel_mul_mm_q8_0_f32_8x4(
+        global const uint * src0_q,
+        global const half  * src0_d,
+        __read_only image1d_buffer_t src1,
+        global float * dst,
+        int k,
+        int m,
+        int n,
+        int n_no_padding,
+        ulong offsetd
+) {
+
+    int m_4 = m >> 2;
+    int n_4 = n >> 2;
+
+    int gy   = get_global_id(0);
+    int gx   = get_global_id(1);
+    int gx_2 = gx << 2;
+    dst  = (global float *)((global char*)dst  + offsetd);
+
+
+    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
+    half8 B;
+    half4 deq;
+
+    __global const uint* wptr = src0_q + gx_2;
+    __global const half* sptr = src0_d + gx_2;
+
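+    // each thread computes a 4 (rows of A) x 8 (columns of B) output tile: every uint of
+    // src0_q packs 4 consecutive int8 weights along k for one row, src0_d holds the
+    // per-32-element block scales, and B is read as fp16 through the image buffer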
+    for (int i = 0; i < k; i += 4) {
+        uint4 pack4 = vload4(0, wptr + (i / 4) * m);
+        half4 scale = vload4(0, sptr + (i / 32) * m);
+
+        char4 p0 = as_char4(pack4.s0);
+        char4 p1 = as_char4(pack4.s1);
+        char4 p2 = as_char4(pack4.s2);
+        char4 p3 = as_char4(pack4.s3);
+
+        // ------------------- j = 0 (k = i+0) -------------------
+        B.s0123 = read_imageh(src1, gy * 2 + (i + 0) * n_4);
+        B.s4567 = read_imageh(src1, gy * 2 + (i + 0) * n_4 + 1);
+
+        half4 wj0 = convert_half4((char4)(p0.s0, p1.s0, p2.s0, p3.s0)) * scale;
+
+        c0 += B * wj0.s0;
+        c1 += B * wj0.s1;
+        c2 += B * wj0.s2;
+        c3 += B * wj0.s3;
+
+        // ------------------- j = 1 (k = i+1) -------------------
+        B.s0123 = read_imageh(src1, gy * 2 + (i + 1) * n_4);
+        B.s4567 = read_imageh(src1, gy * 2 + (i + 1) * n_4 + 1);
+
+        half4 wj1 = convert_half4((char4)(p0.s1, p1.s1, p2.s1, p3.s1)) * scale;
+
+        c0 += B * wj1.s0;
+        c1 += B * wj1.s1;
+        c2 += B * wj1.s2;
+        c3 += B * wj1.s3;
+
+        // ------------------- j = 2 (k = i+2) -------------------
+        B.s0123 = read_imageh(src1, gy * 2 + (i + 2) * n_4);
+        B.s4567 = read_imageh(src1, gy * 2 + (i + 2) * n_4 + 1);
+
+        half4 wj2 = convert_half4((char4)(p0.s2, p1.s2, p2.s2, p3.s2)) * scale;
+
+        c0 += B * wj2.s0;
+        c1 += B * wj2.s1;
+        c2 += B * wj2.s2;
+        c3 += B * wj2.s3;
+
+        // ------------------- j = 3 (k = i+3) -------------------
+        B.s0123 = read_imageh(src1, gy * 2 + (i + 3) * n_4);
+        B.s4567 = read_imageh(src1, gy * 2 + (i + 3) * n_4 + 1);
+
+        half4 wj3 = convert_half4((char4)(p0.s3, p1.s3, p2.s3, p3.s3)) * scale;
+
+        c0 += B * wj3.s0;
+        c1 += B * wj3.s1;
+        c2 += B * wj3.s2;
+        c3 += B * wj3.s3;
+    }
+
+    int idx = (gy << 3) * m + (gx << 2);
+
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl
new file mode 100644
index 00000000..6fe828f2
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl
@@ -0,0 +1,219 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_subgroups
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+#define QK4_1                   32
+
+struct block_q4_1 {
+    half d; // delta
+    half m; // min
+    uchar qs[QK4_1 / 2]; // nibbles / quants
+};
+
+inline float block_q4_1_dot_y(
+    global const struct block_q4_1 * qb_curr,
+    float sumy,
+    float16 yl,
+    int il
+) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
+
+    global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);
+
+    acc.s0 += yl.s0 * (qs[0] & 0x000F);
+    acc.s0 += yl.s1 * (qs[0] & 0x0F00);
+    acc.s0 += yl.s8 * (qs[0] & 0x00F0);
+    acc.s3 += yl.s9 * (qs[0] & 0xF000);
+
+    acc.s0 += yl.s2 * (qs[1] & 0x000F);
+    acc.s1 += yl.s3 * (qs[1] & 0x0F00);
+    acc.s2 += yl.sa * (qs[1] & 0x00F0);
+    acc.s3 += yl.sb * (qs[1] & 0xF000);
+
+    acc.s0 += yl.s4 * (qs[2] & 0x000F);
+    acc.s1 += yl.s5 * (qs[2] & 0x0F00);
+    acc.s2 += yl.sc * (qs[2] & 0x00F0);
+    acc.s3 += yl.sd * (qs[2] & 0xF000);
+
+    acc.s0 += yl.s6 * (qs[3] & 0x000F);
+    acc.s1 += yl.s7 * (qs[3] & 0x0F00);
+    acc.s2 += yl.se * (qs[3] & 0x00F0);
+    acc.s3 += yl.sf * (qs[3] & 0xF000);
+
+    return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each subgroup works on 4 rows
+#define N_SIMDGROUP 1 // number of subgroups in a thread group
+#define N_SIMDWIDTH 16 // assuming subgroup size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32(
+        global void * src0,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const ulong nb = ne00/QK4_1;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;
+    global float             * y = (global float             *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix * QK4_1 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+
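+        // yl is pre-scaled by 1, 1/256, 1/16 and 1/4096 so that nibbles masked with
+        // 0x000F/0x0F00/0x00F0/0xF000 (but never shifted) multiply to the right magnitude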
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);
+        sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);
+        sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);
+        sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);
+
+        yb += QK4_1 * (N_SIMDWIDTH/2);
+    }
+
+    float4 tot = (float4)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q4_1_f32(
+        global void * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl
new file mode 100644
index 00000000..d7c4645d
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl
@@ -0,0 +1,229 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_subgroups
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+#define QK4_1                   32
+
+struct block_q4_1 {
+    half d; // delta
+    half m; // min
+    uchar qs[QK4_1 / 2]; // nibbles / quants
+};
+
+inline float block_q4_1_dot_y_flat(
+    global const uchar * x,
+    global const half  * dh,
+    global const half  * mh,
+    float sumy,
+    float16 yl,
+    int il
+) {
+    float                 d   = *dh;
+    float                 m   = *mh;
+    global const ushort * qs = ((global const ushort *) x + il/2);
+
+    float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
+
+    acc.s0 += yl.s0 * (qs[0] & 0x000F);
+    acc.s0 += yl.s1 * (qs[0] & 0x0F00);
+    acc.s0 += yl.s8 * (qs[0] & 0x00F0);
+    acc.s3 += yl.s9 * (qs[0] & 0xF000);
+
+    acc.s0 += yl.s2 * (qs[1] & 0x000F);
+    acc.s1 += yl.s3 * (qs[1] & 0x0F00);
+    acc.s2 += yl.sa * (qs[1] & 0x00F0);
+    acc.s3 += yl.sb * (qs[1] & 0xF000);
+
+    acc.s0 += yl.s4 * (qs[2] & 0x000F);
+    acc.s1 += yl.s5 * (qs[2] & 0x0F00);
+    acc.s2 += yl.sc * (qs[2] & 0x00F0);
+    acc.s3 += yl.sd * (qs[2] & 0xF000);
+
+    acc.s0 += yl.s6 * (qs[3] & 0x000F);
+    acc.s1 += yl.s7 * (qs[3] & 0x0F00);
+    acc.s2 += yl.se * (qs[3] & 0x00F0);
+    acc.s3 += yl.sf * (qs[3] & 0xF000);
+
+    return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each subgroup works on 4 rows
+#define N_SIMDGROUP 1 // number of subgroups in a thread group
+#define N_SIMDWIDTH 16 // assuming subgroup size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_flat(
+        global void * src0_q,
+        global void * src0_d,
+        global void * src0_m,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
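+    // "flat" layout: quant nibbles, per-block deltas and per-block mins live in three
+    // separate contiguous buffers (src0_q, src0_d, src0_m) instead of interleaved blocks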
+    const ulong nb = ne00/QK4_1;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    // The number of scales/mins is the same as the number of blocks.
+    ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
+    // Each block contains QK4_1/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q  = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_dm;
+    global half  * m = (global half  *) src0_m + offset0_dm;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix * QK4_1 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);
+        sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);
+        sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);
+        sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);
+
+        yb += QK4_1 * (N_SIMDWIDTH/2);
+    }
+
+    float4 tot = (float4)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q4_1_f32_flat(
+        global void * src0_q,
+        global void * src0_d,
+        global void * src0_m,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl
new file mode 100644
index 00000000..71ab9898
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl
@@ -0,0 +1,180 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_subgroups
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+//------------------------------------------------------------------------------
+// block_q4_K
+//------------------------------------------------------------------------------
+#define QK_K            256
+#define K_SCALE_SIZE    12
+
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+typedef struct {
+    half d;    // super-block scale for quantized scales
+    half dmin; // super-block scale for quantized mins
+
+    uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uchar qs[QK_K/2];           // 4-bit quants
+} block_q4_K;
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // number of rows each SIMD group works on
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // SIMD group size
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+#undef  BLOCK_STRIDE
+// number of (super) blocks each subgroup processes
+// each thread in a subgroup processes a block (32 weights)
+#define BLOCK_STRIDE (N_SIMDWIDTH/8)
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q4_K_f32(
+        global char * src0,
+        int offset0,
+        global char * src1,
+        int offset1,
+        global char * dst,
+        int offsetd,
+        int ne00,
+        int ne01,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    ushort kmask1 = 0x3f3f;
+    ushort kmask2 = 0x0f0f;
+    ushort kmask3 = 0xc0c0;
+
+    int ix = get_sub_group_local_id()/8;  // super block index
+    int it = get_sub_group_local_id()%8;  // block index (inside super block)
+    int iq = it/4;     // 0 or 1 - first or second half of the super block
+    int ir = it%4;     // 0...3 - block index in the half super block
+
+    int nb = ne00/QK_K;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    int offset_src1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
+    global float      * y = (global float      *) (src1 + offset_src1);
+
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST] = {0.f};
+    float all_sum;
+
+    global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+    ushort  sc16[4];
+    uchar * sc8 = (uchar *)sc16;
+
+    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+0];
+            sumy.s0 += yl[i+0];
+
+            yl[i+8] = y4[i+32];
+            sumy.s1 += yl[i+8];
+
+            yh[i+0] = y4[i+128];
+            sumy.s2 += yh[i+0];
+
+            yh[i+8] = y4[i+160];
+            sumy.s3 += yh[i+8];
+        }
+
+        global ushort * sc = (global ushort *)x[ib].scales + iq;
+        global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
+        global half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
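+            // unpack the packed 6-bit scales/mins: sc8[0,1] and sc8[4,5] become the scales,
+            // sc8[2,3] and sc8[6,7] the mins for the sub-blocks read into yl/yh above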
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            global ushort * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
+                                 (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
+                                 (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
+                                 (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
+
+            q1 += nb01/2;
+            sc += nb01/2;
+            dh += nb01/2;
+        }
+
+        y4 += BLOCK_STRIDE * QK_K;
+    }
+
+    global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = sub_group_reduce_add(sumf[row]);
+        if (first_row + row < ne01) {
+            if (get_sub_group_local_id() == 0) {
+                dst_f32[first_row + row] = all_sum;
+            }
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl
similarity index 99%
rename from ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl
rename to ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl
index 8a17b9aa..819e5192 100644
--- a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl
+++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl
@@ -111,6 +111,10 @@ kernel void kernel_mul_mv_q6_K_f32(
 
     int row = N_SIMDGROUP * r0 + get_sub_group_id();
 
+    if (row >= ne01) {
+        return;
+    }
+
     int i12 = im%ne12;
     int i13 = im/ne12;
 
diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl
new file mode 100644
index 00000000..86fe09c6
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl
@@ -0,0 +1,194 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+#ifdef cl_intel_subgroups
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#endif
+
+//------------------------------------------------------------------------------
+// kernel_mul_mv_q6_K_f32_flat
+//------------------------------------------------------------------------------
+#define Q6_K_MASK1 0x03
+#define Q6_K_MASK2 0x0C
+#define Q6_K_MASK3 0x30
+#define Q6_K_MASK4 0xC0
+
+#define QK_K       256
+
+inline float block_q_6_K_dot_y_flat(
+    global uchar * blk_ql,
+    global uchar * blk_qh,
+    global char  * blk_scales,
+    global half  * blk_d,
+    global float * yy,
+    int ib,
+    int ip,
+    int is,
+    int l0
+) {
+    int y_offset   = 128*ip + l0;
+    int q_offset_l =  64*ip + l0;
+    int q_offset_h =  32*ip + l0;
+
+    global uchar * q1 = blk_ql     + ib*128 + q_offset_l;
+    global uchar * q2 = q1         + QK_K/8;
+    global uchar * qh = blk_qh     + ib*64 + q_offset_h;
+    global char  * sc = blk_scales + ib*16 + is;
+
+    global float * y = yy + ib * QK_K + y_offset;
+
+    float dall = blk_d[ib];
+
+    float  sumf = 0;
+    float4 sums = {0.f, 0.f, 0.f, 0.f};
+
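+    // rebuild each 6-bit quant from 4 low bits (ql) and 2 high bits (qh, selected via the
+    // Q6_K masks), subtract the 32 offset, and accumulate against four interleaved y slices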
+    sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f);
+    sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f);
+    sums.s2 += y[0+64] * ((float)((q1[0]  >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f);
+    sums.s3 += y[0+96] * ((float)((q2[0]  >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f);
+
+    sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f);
+    sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f);
+    sums.s2 += y[1+64] * ((float)((q1[1]  >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f);
+    sums.s3 += y[1+96] * ((float)((q2[1]  >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f);
+
+    sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f);
+    sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f);
+    sums.s2 += y[2+64] * ((float)((q1[2]  >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f);
+    sums.s3 += y[2+96] * ((float)((q2[2]  >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f);
+
+    sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f);
+    sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f);
+    sums.s2 += y[3+64] * ((float)((q1[3]  >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f);
+    sums.s3 += y[3+96] * ((float)((q2[3]  >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f);
+
+    sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
+
+    return sumf;
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4
+#define N_SIMDGROUP 2
+#define N_SIMDWIDTH 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 2
+#define N_SIMDWIDTH 64
+#endif
+
+#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q6_K_f32_flat(
+        global uchar * src0_ql,
+        global uchar * src0_qh,
+        global char  * src0_s,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int nb = ne00/QK_K;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST;
+
+    ulong offset_src0    = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    ulong offset_src0_ql = offset_src0 * 128;
+    ulong offset_src0_qh = offset_src0 * 64;
+    ulong offset_src0_s  = offset_src0 * 16;
+    ulong offset_src0_d  = offset_src0;
+
+    global uchar * blk_ql     = (global uchar *) src0_ql + offset_src0_ql;
+    global uchar * blk_qh     = (global uchar *) src0_qh + offset_src0_qh;
+    global char  * blk_scales = (global char  *) src0_s  + offset_src0_s;
+    global half  * blk_d      = (global half  *) src0_d  + offset_src0_d;
+    global float * yy         = (global float *) src1    + r1*ne10 + im*ne00*ne1;
+
+    int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
+    int ix  = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
+    int ip  = tid/8;   // first or second half of (super) block (0 or 1)
+    int il  = tid%8;   // each half has 8 parts, one per scale
+    int n   = 4;       // 4 scales at a time (and 4 sums)
+    int l0  = n*il;    // offset into half-block, 0..28
+    int is  = 8*ip + l0/16; // 0, 1, 8, 9
+
+    float4 sumf = 0;
+
+    for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
+        if (first_row + 0 < ne01) {
+            sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0);
+        }
+        if (first_row + 1 < ne01) {
+            sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0);
+        }
+        if (first_row + 2 < ne01) {
+            sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0);
+        }
+        if (first_row + 3 < ne01) {
+            sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0);
+        }
+    }
+
+    float4 tot = (float4)(
+        sub_group_reduce_add(sumf.s0),
+        sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2),
+        sub_group_reduce_add(sumf.s3)
+    );
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/neg.cl b/ggml/src/ggml-opencl/kernels/neg.cl
new file mode 100644
index 00000000..a862d8bc
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/neg.cl
@@ -0,0 +1,125 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+kernel void kernel_neg_f32(
+        global const float * src0,
+        ulong                offset0,
+        global       float * dst,
+        ulong                offsetd,
+        int                  n
+) {
+    if (get_global_id(0) >= n) {
+        return;
+    }
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = -src0[get_global_id(0)];
+}
+
+kernel void kernel_neg_f32_4(
+        global const float4 * src0,
+        ulong                 offset0,
+        global       float4 * dst,
+        ulong                 offsetd,
+        int                   n
+) {
+    if (get_global_id(0) >= n) {
+        return;
+    }
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst  = (global float4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = -src0[get_global_id(0)];
+}
+
+kernel void kernel_neg_f16(
+        global const half * src0,
+        ulong               offset0,
+        global       half * dst,
+        ulong               offsetd,
+        int                 n
+) {
+    if (get_global_id(0) >= n) {
+        return;
+    }
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst  = (global half*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = -src0[get_global_id(0)];
+}
+
+kernel void kernel_neg_f16_4(
+        global const half4 * src0,
+        ulong                offset0,
+        global       half4 * dst,
+        ulong                offsetd,
+        int                  n
+) {
+    if (get_global_id(0) >= n) {
+        return;
+    }
+    src0 = (global half4*)((global char*)src0 + offset0);
+    dst  = (global half4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = -src0[get_global_id(0)];
+}
+
+kernel void kernel_neg_f32_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
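+    // non-contiguous variant: one workgroup per (i1, i2, i3) row, threads stride over ne00;
+    // all nb* strides are in bytes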
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+        *y = -*x;
+    }
+}
+
+kernel void kernel_neg_f16_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       half * y = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+        *y = -*x;
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/repeat.cl b/ggml/src/ggml-opencl/kernels/repeat.cl
index 079498f5..53951a55 100644
--- a/ggml/src/ggml-opencl/kernels/repeat.cl
+++ b/ggml/src/ggml-opencl/kernels/repeat.cl
@@ -1,39 +1,38 @@
-kernel void kernel_repeat(
-    global const char * src0_data_in,
-    global       char * dst_data_in,
-    ulong src0_offset,
-    ulong dst_offset,
-    int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3,
-    ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3,
-    int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3,
-    ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3
+kernel void kernel_repeat_f32(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int     ne00,
+        int     ne01,
+        int     ne02,
+        int     ne03,
+        ulong   nb00,
+        ulong   nb01,
+        ulong   nb02,
+        ulong   nb03,
+        int     ne0,
+        ulong   nb0,
+        ulong   nb1,
+        ulong   nb2,
+        ulong   nb3
 ) {
-    global const char * src0_data = src0_data_in + src0_offset;
-    global       char * dst_data  = dst_data_in + dst_offset;
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
 
-    const int d3 = get_global_id(2);
-    const int d2 = get_global_id(1);
-    const int d1 = get_global_id(0);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
 
-    if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) {
-        return;
-    }
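+    // broadcasting: fold each dst index back into the source extent with a modulo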
+    const int i03 = i3%ne03;
+    const int i02 = i2%ne02;
+    const int i01 = i1%ne01;
 
-    const int s3 = d3 % src0_ne3;
-    const int s2 = d2 % src0_ne2;
-    const int s1 = d1 % src0_ne1;
+    global const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1;
 
-    const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1;
-    global char * p_dst_slice  = dst_data  + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1;
-
-    for (int d0 = 0; d0 < dst_ne0; ++d0) {
-        // Determine source index for dimension 0 based on tiling/broadcasting.
-        const int s0 = d0 % src0_ne0;
-
-        const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0;
-        global char * restrict current_dst_el_ptr  = p_dst_slice  + (ulong)d0*dst_nb0;
-        for (int k = 0; k < src0_nb0; ++k) {
-            current_dst_el_ptr[k] = current_src_el_ptr[k];
-        }
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i00 = i0%ne00;
+        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i00*nb00));
     }
 }
diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl
index aeca8a45..17ed97f0 100644
--- a/ggml/src/ggml-opencl/kernels/scale.cl
+++ b/ggml/src/ggml-opencl/kernels/scale.cl
@@ -1,9 +1,19 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
-//------------------------------------------------------------------------------
-// scale
-//------------------------------------------------------------------------------
-kernel void kernel_scale(
+kernel void kernel_scale_f32(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        float scale,
+        float bias
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+    dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
+}
+
+kernel void kernel_scale_f32_4(
         global float4 * src0,
         ulong offset0,
         global float4 * dst,
diff --git a/ggml/src/ggml-opencl/kernels/softplus.cl b/ggml/src/ggml-opencl/kernels/softplus.cl
index 033766e2..6f8b7474 100644
--- a/ggml/src/ggml-opencl/kernels/softplus.cl
+++ b/ggml/src/ggml-opencl/kernels/softplus.cl
@@ -3,86 +3,114 @@
 //------------------------------------------------------------------------------
 // softplus
 //------------------------------------------------------------------------------
-inline float softplus_f32(float x){
-    float ax = fabs(x);
-    float m = fmax(x, 0.0f);
-    return log1p(exp(-ax)) + m;
+
+kernel void kernel_softplus_f32(
+        global const float * src0,
+        ulong                offset0,
+        global       float * dst,
+        ulong                offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst + offsetd);
+
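+    // softplus = log(1 + exp(x)); above the threshold the result is ~x, so pass x through
+    // to avoid overflowing exp()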
+    dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
 }
 
-kernel void kernel_softplus_f32_nd(
-        global void * p_src0_base,
-        ulong off_src0_abs,
-        global void * p_dst_base,
-        ulong off_dst_abs,
-        int ne00,
-        int ne01,
-        int ne02,
-        int ne03,
+kernel void kernel_softplus_f32_4(
+        global const float4 * src0,
+        ulong                 offset0,
+        global       float4 * dst,
+        ulong                 offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst  = (global float4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
+}
+
+kernel void kernel_softplus_f16(
+        global const half * src0,
+        ulong               offset0,
+        global       half * dst,
+        ulong               offsetd
+) {
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst  = (global half*)((global char*)dst + offsetd);
+
+    const float x = convert_float(src0[get_global_id(0)]);
+    dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
+}
+
+kernel void kernel_softplus_f16_4(
+        global const half4 * src0,
+        ulong                offset0,
+        global       half4 * dst,
+        ulong                offsetd
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    dst  = (global half4*)((global char*)dst + offsetd);
+
+    const float4 x = convert_float4(src0[get_global_id(0)]);
+    dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
+}
+
+kernel void kernel_softplus_f32_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
         ulong nb00,
         ulong nb01,
         ulong nb02,
         ulong nb03,
-        int ne10,
-        int ne11,
-        int ne12,
-        int ne13,
-        ulong nb10,
-        ulong nb11,
-        ulong nb12,
-        ulong nb13
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
 ) {
-    int i0 = get_global_id(0);
-    int i1 = get_global_id(1);
-    int i2 = get_global_id(2);
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
 
-    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
-        for (int i3 = 0; i3 < ne13; ++i3) {
-            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
-            global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
 
-            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
-            global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-            *dst_val_ptr = softplus_f32(*src_val_ptr);
-        }
+        *y = (*x > 20.0f) ? *x : log(1.0f + exp(*x));
     }
 }
 
-kernel void kernel_softplus_f16_nd(
-        global void * p_src0_base,
-        ulong off_src0_abs,
-        global void * p_dst_base,
-        ulong off_dst_abs,
-        int ne00,
-        int ne01,
-        int ne02,
-        int ne03,
+kernel void kernel_softplus_f16_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
         ulong nb00,
         ulong nb01,
         ulong nb02,
         ulong nb03,
-        int ne10,
-        int ne11,
-        int ne12,
-        int ne13,
-        ulong nb10,
-        ulong nb11,
-        ulong nb12,
-        ulong nb13
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
 ) {
-    int i0 = get_global_id(0);
-    int i1 = get_global_id(1);
-    int i2 = get_global_id(2);
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
 
-    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
-        for (int i3 = 0; i3 < ne13; ++i3) {
-            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
-            global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
 
-            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
-            global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       half * hy = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-            *dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr)));
-        }
+        const float x = convert_float(*hx);
+        *hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
     }
 }
diff --git a/ggml/src/ggml-opencl/kernels/solve_tri.cl b/ggml/src/ggml-opencl/kernels/solve_tri.cl
new file mode 100644
index 00000000..80745fc7
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/solve_tri.cl
@@ -0,0 +1,51 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+//------------------------------------------------------------------------------
+// solve_tri
+//------------------------------------------------------------------------------
+kernel void kernel_solve_tri_f32(
+        global uchar * src0,
+        ulong offset0,
+        global uchar * src1,
+        ulong offset1,
+        global uchar * dst,
+        ulong offsetd,
+        int n,
+        int k,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    int col = get_global_id(0);
+    int i2 = get_global_id(1);
+    int i3 = get_global_id(2);
+
+    global const uchar * Lb = src0 + offset0 + i2 * nb02 + i3 * nb03;
+    global const uchar * Bb = src1 + offset1 + i2 * nb12 + i3 * nb13;
+    global       uchar * Xb = dst + offsetd + i2 * nb2 + i3 * nb3;
+
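+    // Forward substitution down the rows: X[row][col] = (B[row][col] - sum_{j<row} L[row][j] * X[j][col]) / L[row][row].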
+    for(int row = 0; row < n; ++row){
+        global const float *pB = (global const float *)(Bb + row * nb11 + col * nb10);
+
+        float sum = 0.0f;
+        for(int j = 0; j < row; ++j){
+            global const float *pL = (global const float *)(Lb + row * nb01 + j * nb00);
+            global const float *pX = (global const float *)(Xb + j * nb1 + col * nb0);
+            sum += (*pL) * (*pX);
+        }
+
+        global const float * pDiag = (global const float *)(Lb + row * nb01 + row * nb00);
+        global float * pOut = (global float *)(Xb + row * nb1 + col * nb0);
+
+        *pOut = ((*pB) - sum) / (*pDiag);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/sum_rows.cl b/ggml/src/ggml-opencl/kernels/sum_rows.cl
index c5f7c570..84630aa8 100644
--- a/ggml/src/ggml-opencl/kernels/sum_rows.cl
+++ b/ggml/src/ggml-opencl/kernels/sum_rows.cl
@@ -1,8 +1,13 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
 
+// Most devices have a max workgroup size of 1024, so this is enough for subgroup
+// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroup sizes.
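+// e.g. a 1024-wide workgroup split into 16-wide subgroups produces 1024/16 = 64 partial sums.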
+#define MAX_SUBGROUPS 64
 kernel void kernel_sum_rows_f32(
-    global float *  src0,
+    global char *  src0,
     ulong           offset0,
-    global float *  dst,
+    global char *  dst,
     ulong           offsetd,
     int             ne00,
     int             ne01,
@@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32(
     ulong           nb2,
     ulong           nb3
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    dst  = (global float *)((global char *)dst  + offsetd);
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
 
-    int i3 = get_global_id(2);
-    int i2 = get_global_id(1);
-    int i1 = get_global_id(0);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
 
     if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
         return;
     }
 
-    global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    global float * dst_row = (global float *) ((global char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
-
-    float row_sum = 0;
-
-    for (int i0 = 0; i0 < ne00; i0++) {
-        row_sum += src_row[i0];
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
     }
 
-    dst_row[0] = row_sum;
+    global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float * dst_row = (global float *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float sumf = 0.0f;
+
+    for (int i0 = lid; i0 < ne00; i0 += lsize) {
+        sumf += src_row[i0];
+    }
+
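+    // Two-stage reduction: reduce within each subgroup, stage per-subgroup results in local memory, then reduce across subgroups.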
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf;
+    }
+}
+
+kernel void kernel_sum_rows_f32_4(
+    global char *  src0,
+    ulong           offset0,
+    global char *  dst,
+    ulong           offsetd,
+    int             ne00,
+    int             ne01,
+    int             ne02,
+    int             ne03,
+    ulong           nb01,
+    ulong           nb02,
+    ulong           nb03,
+    ulong           nb1,
+    ulong           nb2,
+    ulong           nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst  + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int lid = get_local_id(0);
+    const int lsize = get_local_size(0);
+
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
+
+    __local float lmem[MAX_SUBGROUPS];
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    if(sg_id == 0){
+        lmem[sg_lid] = 0.0f;
+    }
+
+    global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    global float  * dst_row = (global float  *) (dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float4 sum_vec = (float4)0.0f;
+
+    for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) {
+        sum_vec += src_row[i0];
+    }
+
+    float sumf = dot(sum_vec, (float4)(1.0f));
+    sumf = sub_group_reduce_add(sumf);
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    if(sg_lid == 0){
+        lmem[sg_id] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    sumf = lmem[sg_lid];
+    sumf = sub_group_reduce_add(sumf);
+
+    if (lid == 0) {
+        dst_row[0] = sumf;
+    }
 }
diff --git a/ggml/src/ggml-opencl/kernels/tanh.cl b/ggml/src/ggml-opencl/kernels/tanh.cl
index d9da86b1..2c4887ad 100644
--- a/ggml/src/ggml-opencl/kernels/tanh.cl
+++ b/ggml/src/ggml-opencl/kernels/tanh.cl
@@ -1,63 +1,109 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
-#ifdef cl_intel_required_subgroup_size
-#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
-#define INTEL_GPU 1
-#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
-#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
-#elif defined(cl_qcom_reqd_sub_group_size)
-#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
-#define ADRENO_GPU 1
-#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
-#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
-#endif
-
-kernel void kernel_tanh_f32_nd(
-    global void * p_src0_base, ulong off_src0_abs,
-    global void * p_dst_base,  ulong off_dst_abs,
-    int ne00, int ne01, int ne02, int ne03,
-    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
-    int ne10, int ne11, int ne12, int ne13,
-    ulong nb10, ulong nb11, ulong nb12, ulong nb13
+kernel void kernel_tanh_f32(
+        global const float * src0,
+        ulong                offset0,
+        global       float * dst,
+        ulong                offsetd
 ) {
-    int i0 = get_global_id(0);
-    int i1 = get_global_id(1);
-    int i2 = get_global_id(2);
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst  = (global float*)((global char*)dst + offsetd);
 
-    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
-        for (int i3 = 0; i3 < ne13; ++i3) {
-            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
-            global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
+    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);
+}
 
-            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
-            global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
+kernel void kernel_tanh_f32_4(
+        global const float4 * src0,
+        ulong                 offset0,
+        global       float4 * dst,
+        ulong                 offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst  = (global float4*)((global char*)dst + offsetd);
 
-            *dst_val_ptr = tanh(*src_val_ptr);
-        }
+    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);
+}
+
+kernel void kernel_tanh_f16(
+        global const half * src0,
+        ulong               offset0,
+        global       half * dst,
+        ulong               offsetd
+) {
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst  = (global half*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);
+}
+
+kernel void kernel_tanh_f16_4(
+        global const half4 * src0,
+        ulong                offset0,
+        global       half4 * dst,
+        ulong                offsetd
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    dst  = (global half4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);
+}
+
+kernel void kernel_tanh_f32_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       float * y = (global       float *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+        *y = tanh(*x);
     }
 }
 
-kernel void kernel_tanh_f16_nd(
-    global void * p_src0_base, ulong off_src0_abs,
-    global void * p_dst_base,  ulong off_dst_abs,
-    int ne00, int ne01, int ne02, int ne03,
-    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
-    int ne10, int ne11, int ne12, int ne13,
-    ulong nb10, ulong nb11, ulong nb12, ulong nb13
+kernel void kernel_tanh_f16_nc(
+        global const char * src0,
+        ulong               offset0,
+        global       char * dst,
+        ulong               offsetd,
+        int   ne00,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
 ) {
-    int i0 = get_global_id(0);
-    int i1 = get_global_id(1);
-    int i2 = get_global_id(2);
+    src0 = src0 + offset0;
+    dst  = dst + offsetd;
 
-    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
-        for (int i3 = 0; i3 < ne13; ++i3) {
-            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
-            global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
 
-            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
-            global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global       half * y = (global       half *)(dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
-            *dst_val_ptr = tanh(*src_val_ptr);
-        }
+        *y = tanh(*x);
     }
 }
diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl
index 1279b653..ad89bdcb 100644
--- a/ggml/src/ggml-opencl/kernels/transpose.cl
+++ b/ggml/src/ggml-opencl/kernels/transpose.cl
@@ -44,6 +44,19 @@ kernel void kernel_transpose_16_4x1(
     write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3));
 }
 
+// Transpose treating each element as 8-bit using buffer
+kernel void kernel_transpose_8_buf(
+    global const uchar * input,
+    global uchar * output,
+    const int ldi,
+    const int ldo
+) {
+    const int x = get_global_id(0);
+    const int y = get_global_id(1);
+
+    output[x*ldo + y] = input[y*ldi + x];
+}
+
 // Transpose treating each element as 16-bit using buffer
 kernel void kernel_transpose_16_buf(
     global const ushort * input,
@@ -57,6 +70,19 @@ kernel void kernel_transpose_16_buf(
     output[x*ldo + y] = input[y*ldi + x];
 }
 
+// Transpose treating each element as 32-bit using buffer
+kernel void kernel_transpose_32_buf(
+    global const uint * input,
+    global uint * output,
+    const int ldi,
+    const int ldo
+) {
+    const int x = get_global_id(0);
+    const int y = get_global_id(1);
+
+    output[x*ldo + y] = input[y*ldi + x];
+}
+
 // 32-bit transpose, loading/storing a 4x4 tile of elements
 kernel void kernel_transpose_32(
     __read_only image1d_buffer_t input,
diff --git a/ggml/src/ggml-opencl/kernels/tri.cl b/ggml/src/ggml-opencl/kernels/tri.cl
new file mode 100644
index 00000000..35cdd543
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/tri.cl
@@ -0,0 +1,32 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+//------------------------------------------------------------------------------
+// tri
+//------------------------------------------------------------------------------
+__kernel void kernel_tri_f32(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int n,
+        int ne0,
+        int ne1,
+        int tri_type
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int idx = get_global_id(0);
+    if (idx >= n) return;
+
+    int i0 = idx % ne0;
+    int i1 = (idx / ne0) % ne1;
+
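+    // tri_type selects the comparison between the column index i0 and the row index i1 (>=, >, <=, <); elements that fail it are zeroed.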
+    int keep = 0;
+    if (tri_type == 0) keep = (i0 >= i1);
+    else if (tri_type == 1) keep = (i0 >  i1);
+    else if (tri_type == 2) keep = (i0 <= i1);
+    else                    keep = (i0 <  i1);
+
+    dst[idx] = keep ? src0[idx] : 0.0f;
+}
diff --git a/ggml/src/ggml-openvino/.clang-format b/ggml/src/ggml-openvino/.clang-format
new file mode 100644
index 00000000..a2a24d7d
--- /dev/null
+++ b/ggml/src/ggml-openvino/.clang-format
@@ -0,0 +1,154 @@
+---
+# Override root .clang-format
+AlignConsecutiveAssignments: false
+AlignConsecutiveDeclarations: false
+Cpp11BracedListStyle: true
+SpacesInContainerLiterals: false
+BreakBeforeBraces: Attach
+AccessModifierOffset: -4
+IndentCaseBlocks: false
+IndentCaseLabels: false
+
+Language:        Cpp
+AlignAfterOpenBracket: Align
+AlignArrayOfStructures: Left
+AlignConsecutiveBitFields: AcrossComments
+AlignConsecutiveMacros: AcrossComments
+# AlignConsecutiveShortCaseStatements: AcrossComments
+AlignEscapedNewlines: Left # LeftWithLastLine
+AlignOperands:   Align
+AlignTrailingComments:
+  Kind: Always
+  OverEmptyLines: 1
+AllowAllArgumentsOnNextLine: true
+AllowAllParametersOfDeclarationOnNextLine: false
+# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen
+AllowShortBlocksOnASingleLine: Never
+AllowShortCaseLabelsOnASingleLine: false
+AllowShortFunctionsOnASingleLine: Inline
+AllowShortIfStatementsOnASingleLine: Never
+AllowShortLambdasOnASingleLine: Inline
+AllowShortLoopsOnASingleLine: false
+AlwaysBreakBeforeMultilineStrings: true
+# Treat CUDA keywords/attributes as "attribute macros" and avoid breaking lines inside them
+AttributeMacros:
+  - __host__
+  - __device__
+  - __global__
+  - __forceinline__
+  - __launch_bounds__
+BinPackArguments: true
+BinPackParameters: false # OnePerLine
+BitFieldColonSpacing: Both
+# BreakAdjacentStringLiterals: true
+BreakAfterAttributes: Never
+BreakBeforeBinaryOperators: None
+BreakBeforeInlineASMColon: OnlyMultiline
+BreakBeforeTernaryOperators: false
+# BreakBinaryOperations: Never
+BreakConstructorInitializers: AfterColon
+# BreakFunctionDefinitionParameters: false
+BreakInheritanceList: AfterComma
+BreakStringLiterals: true
+# BreakTemplateDeclarations: Yes
+ColumnLimit:     120
+CommentPragmas:  '^ IWYU pragma:'
+CompactNamespaces: false
+ConstructorInitializerIndentWidth: 4
+ContinuationIndentWidth: 4
+DerivePointerAlignment: false
+DisableFormat:   false
+EmptyLineBeforeAccessModifier: Leave
+EmptyLineAfterAccessModifier: Never
+ExperimentalAutoDetectBinPacking: false
+FixNamespaceComments: true
+IncludeBlocks:   Regroup
+IncludeCategories:
+  - Regex:           '".*"'
+    Priority:        1
+    SortPriority:    0
+  - Regex:           '^<.*\.h>'
+    Priority:        2
+    SortPriority:    0
+  - Regex:           '^<.*'
+    Priority:        3
+    SortPriority:    0
+  - Regex:           '.*'
+    Priority:        4
+    SortPriority:    0
+IncludeIsMainRegex: '([-_](test|unittest))?$'
+IncludeIsMainSourceRegex: ''
+IndentAccessModifiers: false
+IndentExternBlock: NoIndent
+IndentGotoLabels: false
+IndentPPDirectives: AfterHash
+IndentWidth:     4
+IndentWrappedFunctionNames: false
+InsertBraces:    true # NOTE: may lead to incorrect formatting
+InsertNewlineAtEOF: true
+JavaScriptQuotes: Leave
+JavaScriptWrapImports: true
+KeepEmptyLinesAtTheStartOfBlocks: false
+LambdaBodyIndentation: Signature
+LineEnding: LF
+MacroBlockBegin: ''
+MacroBlockEnd:   ''
+MaxEmptyLinesToKeep: 1
+NamespaceIndentation: None
+ObjCBinPackProtocolList: Auto
+ObjCBlockIndentWidth: 4
+ObjCSpaceAfterProperty: true
+ObjCSpaceBeforeProtocolList: true
+PPIndentWidth: -1
+PackConstructorInitializers: CurrentLine
+PenaltyBreakAssignment: 2
+PenaltyBreakBeforeFirstCallParameter: 1
+PenaltyBreakComment: 300
+PenaltyBreakFirstLessLess: 120
+PenaltyBreakString: 1000
+PenaltyBreakTemplateDeclaration: 10
+PenaltyExcessCharacter: 1000000
+PenaltyReturnTypeOnItsOwnLine: 200
+PointerAlignment: Middle
+QualifierAlignment: Left
+#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict']
+RawStringFormats:
+  - Language:        Cpp
+    Delimiters:
+      - cc
+      - CC
+      - cpp
+      - Cpp
+      - CPP
+      - 'c++'
+      - 'C++'
+    CanonicalDelimiter: ''
+ReferenceAlignment: Middle
+ReflowComments:  false # IndentOnly
+SeparateDefinitionBlocks: Always
+SortIncludes:    CaseInsensitive
+SortUsingDeclarations: LexicographicNumeric
+SpaceAfterCStyleCast: true
+SpaceAfterLogicalNot: false
+SpaceAfterTemplateKeyword: true
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeCpp11BracedList: false
+SpaceBeforeCtorInitializerColon: true
+SpaceBeforeInheritanceColon: true
+SpaceBeforeParens: ControlStatements
+SpaceBeforeRangeBasedForLoopColon: true
+SpaceInEmptyBlock: false
+SpaceInEmptyParentheses: false
+SpacesBeforeTrailingComments: 2
+SpacesInAngles:  Never
+SpacesInLineCommentPrefix:
+  Minimum: 1
+  Maximum: -1
+SpacesInParentheses: false
+SpacesInSquareBrackets: false
+SpaceBeforeSquareBrackets: false
+Standard:        c++17
+TabWidth:        4
+UseTab:          Never
+WhitespaceSensitiveMacros: ['STRINGIZE']
+...
diff --git a/ggml/src/ggml-openvino/CMakeLists.txt b/ggml/src/ggml-openvino/CMakeLists.txt
new file mode 100644
index 00000000..175b5856
--- /dev/null
+++ b/ggml/src/ggml-openvino/CMakeLists.txt
@@ -0,0 +1,22 @@
+find_package(OpenVINO REQUIRED)
+find_package(OpenCL REQUIRED)
+
+include("${OpenVINO_DIR}/../3rdparty/tbb/lib/cmake/TBB/TBBConfig.cmake")
+
+file(GLOB_RECURSE GGML_HEADERS_OPENVINO "*.h" "*.hpp")
+file(GLOB_RECURSE GGML_SOURCES_OPENVINO "*.cpp")
+
+ggml_add_backend_library(ggml-openvino
+    ${GGML_SOURCES_OPENVINO}
+    ${GGML_HEADERS_OPENVINO}
+)
+
+target_link_libraries(ggml-openvino PRIVATE openvino::runtime TBB::tbb OpenCL::OpenCL)
+
+if (GGML_OPENVINO)
+    if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
+    elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64")
+    else()
+        message(FATAL_ERROR "OpenVINO: OpenVINO toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}")
+    endif()
+endif()
diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp
new file mode 100644
index 00000000..0938d227
--- /dev/null
+++ b/ggml/src/ggml-openvino/ggml-decoder.cpp
@@ -0,0 +1,975 @@
+#include "ggml-decoder.h"
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-openvino-extra.h"
+#include "ggml-openvino.h"
+#include "ggml-quants.h"
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
+                             ModelParams & model_params,
+                             ComputeParams & compute_params,
+                             std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
+                             bool is_static,
+                             bool is_stateful,
+                             bool is_prefill,
+                             int prefill_chunk_size) :
+    m_is_static(is_static),
+    m_is_stateful(is_stateful),
+    m_is_prefill(is_prefill),
+    m_naive(false),
+    m_prefill_chunk_size(prefill_chunk_size),
+    m_cgraph(cgraph),
+    m_model_weights(model_weights),
+    m_model_params(model_params),
+    m_compute_params(compute_params) {
+    if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
+#ifdef _WIN32
+        _putenv_s("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS", "");
+#else
+        unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS");
+#endif
+        print_tensor_address_map(cgraph);
+    }
+
+    validate_cgraph();
+
+    set_input_output();
+    compute_model_inputs();
+    compute_model_outputs();
+
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+        m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node);
+        m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node);
+    }
+
+    add_extra_inputs();
+}
+
+void GgmlOvDecoder::update_io(ggml_cgraph * cgraph) {
+    m_cgraph = cgraph;
+    m_model_inputs.clear();
+    m_model_outputs.clear();
+    m_node_info_list.clear();
+    set_input_output();
+    compute_model_inputs();
+    compute_model_outputs();
+}
+
+GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights) {
+    m_cgraph = cgraph;
+    m_model_weights = model_weights;
+    m_naive = true;
+    set_input_output();
+    compute_model_inputs();
+    compute_model_outputs();
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+        m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node);
+        m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node);
+    }
+}
+
+void GgmlOvDecoder::set_input_output() {
+    for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) {
+        auto node = m_cgraph->nodes[node_n];
+
+        NodeInfo current_node_info;
+        auto node_name = std::string(node->name);
+        auto node_output_name = node_name;
+        auto * node_output = node;
+        if (node->op == GGML_OP_SET_ROWS) {
+            // SET_ROWS updates the tensor in place. For later OV ops that use the
+            // view_src of SET_ROWS, we need to make sure they get the updated tensor
+            // by putting the view_src name in the tensor_map in
+            // /src/frontends/ggml/src/translate_session.cpp
+            node_output_name = std::string(node->view_src->name);
+            node_output = node->view_src;
+        }
+
+        current_node_info.node = node;
+        current_node_info.node_name = node_name;
+        current_node_info.node_output = node_output;
+        current_node_info.node_output_name = node_output_name;
+        current_node_info.node_op_case = 0;
+        current_node_info.data_addr = node->data;
+
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            auto * src = node->src[i];
+            if (src == nullptr) {
+                continue;
+            }
+            auto src_name = std::string(src->name);
+            if (src->flags & GGML_TENSOR_FLAG_INPUT) {
+                src_name = get_graph_input_ov_name(src, node);
+            }
+            current_node_info.node_inputs[src_name] = src;
+            current_node_info.node_inputs_names.push_back(src_name);
+        }
+
+        m_node_info_list.push_back(current_node_info);
+    }
+}
+
+int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const {
+    int op_case = 0;
+    switch (node->op) {
+    case GGML_OP_RESHAPE: {
+        auto * src = node->src[0];
+        if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) {
+            op_case = 4;
+        } else if (node->ne[0] * node->ne[1] == src->ne[0]) {
+            op_case = 1;
+        } else if (src->ne[0] * src->ne[1] == node->ne[0]) {
+            op_case = 2;
+            if (src->ne[2] * src->ne[3] == node->ne[1]) {
+                op_case = 5;
+            }
+        } else if (src->ne[0] * src->ne[1] == node->ne[1]) {
+            op_case = 3;
+        } else if (src->ne[1] * src->ne[2] == node->ne[1]) {
+            op_case = 6;
+        }
+        break;
+    }
+    case GGML_OP_CONT: {
+        if (node->src[0]->op == GGML_OP_PERMUTE) {
+            op_case = 1;
+        } else if (node->src[0]->op == GGML_OP_TRANSPOSE) {
+            op_case = 2;
+        } else if (node->src[0]->op == GGML_OP_VIEW) {
+            op_case = 3;
+        }
+        break;
+    }
+    case GGML_OP_PERMUTE: {
+        if (node->src[0]->op != GGML_OP_VIEW) {
+            op_case = 1;
+        } else if (node->src[0]->src[0]->op == GGML_OP_NONE) {
+            // kv cache tensor
+            std::string src_name(node->view_src->name);
+            int layer = extract_layer_from_name(src_name);
+            if (!is_swa_layer(layer)) {
+                op_case = 2;
+            } else {
+                op_case = 3;
+            }
+        } else {
+            // rope'ed query tensor
+            op_case = 4;
+        }
+        break;
+    }
+    case GGML_OP_MUL_MAT: {
+        if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) {
+            op_case = 2;
+        } else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) {
+            op_case = 3;
+        }
+        break;
+    }
+    case GGML_OP_GET_ROWS: {
+        if (node->src[1]->op == GGML_OP_VIEW) {
+            op_case = 2;
+        }
+        break;
+    }
+    case GGML_OP_ROPE: {
+        if (node->src[0]->op == GGML_OP_VIEW) {
+            op_case = 2;
+        }
+        break;
+    }
+    case GGML_OP_VIEW: {
+        if (node->src[0]->op == GGML_OP_VIEW) {
+            auto * src = node->src[0];
+            if (ggml_nelements(node) != ggml_nelements(src)) {
+                throw std::runtime_error("Unsupported VIEW case");
+            }
+            op_case = 2;
+        }
+        {
+            auto * src = node->src[0];
+            if ((ggml_nelements(node) != ggml_nelements(src)) && m_naive) {
+                // Compare each dimension of node and src; if only one dimension differs, then op_case = 3
+                int diff_count = 0;
+                for (int i = 0; i < GGML_MAX_DIMS; i++) {
+                    if (node->ne[i] != src->ne[i]) {
+                        diff_count++;
+                    }
+                }
+                if (diff_count == 1) {
+                    op_case = 3;
+                }
+            }
+        }
+        break;
+    }
+    default:
+        break;
+    }
+    return op_case;
+}
+
+int extract_layer_from_name(const std::string & name) {
+    size_t pos1 = name.find("_l");
+    assert(pos1 != std::string::npos);
+    pos1 += 2;
+    size_t pos2 = name.find(' ', pos1);
+    if (pos2 == std::string::npos) {
+        pos2 = name.length();
+    }
+    std::string layer_str = name.substr(pos1, pos2 - pos1);
+    int layer = std::stoi(layer_str);
+    return layer;
+}
+
+std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgraph * cgraph, bool is_static) {
+    ModelParams model_params;
+    ComputeParams compute_params;
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        auto * node = cgraph->nodes[i];
+        std::string name = std::string(node->name);
+        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+            model_params.n_heads = node->src[0]->ne[2];
+            model_params.n_heads_kv = node->src[1]->ne[2];
+            model_params.head_size = node->src[0]->ne[0];
+            compute_params.input_len = node->src[0]->ne[1];
+
+            auto * cache_k_perm = node->src[1];
+            if (cache_k_perm->op == GGML_OP_CPY) {
+                cache_k_perm = cache_k_perm->src[0];
+            }
+            assert(cache_k_perm->op == GGML_OP_PERMUTE);
+            auto * cache_k_view = cache_k_perm->src[0];
+            assert(cache_k_view->op == GGML_OP_VIEW);
+
+            auto * cache_k = cache_k_view->src[0];
+            int layer = extract_layer_from_name(cache_k->name);
+            auto * mask = node->src[3];
+            std::string mask_name(mask->name);
+
+            model_params.kv_buffer_ctx_id = ggml_backend_openvino_buffer_get_ctx_id(cache_k->buffer);
+            if (mask_name.find("swa") != std::string::npos) {
+                model_params.swa_layers.push_back(layer);
+                model_params.ctx_per_seq_swa = cache_k->ne[1];
+            } else {
+                model_params.ctx_per_seq = cache_k->ne[1];
+                model_params.n_seq = cache_k->ne[2];
+            }
+
+            compute_params.n_seq_active = mask->ne[3];
+            auto seq_size = cache_k->ne[0] * cache_k->ne[1] * ggml_type_size(cache_k->type);
+            size_t offset;
+            memcpy(&offset, cache_k_view->op_params, sizeof(size_t));
+            compute_params.seq_active_start = offset / seq_size;
+            compute_params.token_len_per_seq = node->ne[2];
+
+            if (mask_name.find("swa") != std::string::npos) {
+                compute_params.attention_size_swa = mask->ne[0];
+            } else {
+                compute_params.attention_size = mask->ne[0];
+            }
+            if (is_static) {
+                compute_params.attention_size = model_params.ctx_per_seq;
+                compute_params.attention_size_swa = model_params.ctx_per_seq_swa;
+                compute_params.token_len_per_seq = 1;
+            }
+            break;
+        }
+        if (node->op == GGML_OP_ROPE) {
+            memcpy(model_params.rope_params, node->op_params, sizeof(int32_t) * 15);
+        }
+    }
+    auto * output_tensor = cgraph->nodes[cgraph->n_nodes - 1];
+    compute_params.output_len = output_tensor->ne[1];
+    // for NPU, output_len is always 1 except for llama-perplexity
+    if (is_static && compute_params.output_len == 0) {
+        compute_params.output_len = 1;
+    }
+    model_params.ctx = model_params.ctx_per_seq * model_params.n_seq;
+    model_params.ctx_swa = model_params.ctx_per_seq_swa * model_params.n_seq;
+    return {model_params, compute_params};
+}
+
+void GgmlOvDecoder::validate_cgraph() const {
+    if (m_model_params.n_seq > 1 && m_is_static == true) {
+        throw std::runtime_error("n_seq > 1 is not supported on NPU. Try setting -np 1.");
+    }
+}
+
+ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const {
+    if (m_naive) {
+        return input != nullptr ? ov::PartialShape{get_shape(input)} : ov::PartialShape{get_shape(op)};
+    }
+    auto name = std::string(input->name);
+    ov::PartialShape input_shape;
+
+    if (is_inp_tok(input, op) || is_inp_pos(input, op)) {
+        // tokens or positions
+        int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;
+        input_shape = ov::PartialShape{1, 1, 1, len};
+
+    } else if (is_output_idx(input, op)) {
+        // output index
+        input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1};
+
+    } else if (is_inp_mask(input, op)) {
+        // mask
+        if (m_is_static) {
+            input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx};
+        } else if (m_is_stateful) {
+            input_shape = ov::PartialShape{1, 1, -1, -1};
+        } else {
+            input_shape = ov::PartialShape{-1, 1, -1, -1};
+        }
+
+    } else if (is_kvcache(input, op)) {
+        // kvcache
+        input_shape = ov::PartialShape{get_shape(input)};
+        if (!m_is_static) {
+            // do not fix ctx size to make llama-bench work across test params
+            input_shape[2] = -1;
+        }
+        if (is_stateful()) {
+            // Convert stateless KV cache layout [1, 1, seq, n_heads_kv * head_size]
+            // to stateful layout [1, seq, n_heads_kv, head_size].
+            assert(input_shape.size() == 4 && input_shape[0] == 1 && input_shape[1] == 1 &&
+                   input_shape[2].is_dynamic() &&
+                   input_shape[3] == (m_model_params.n_heads_kv * m_model_params.head_size));
+            input_shape = {input_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv,
+                           m_model_params.head_size};
+        }
+
+    } else if (is_kv_idx(input, op)) {
+        // kv update index
+        int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;
+        input_shape = ov::PartialShape{1, 1, 1, len};
+
+    } else {
+        input_shape = ov::PartialShape{get_shape(input)};
+    }
+    return input_shape;
+}
+
+void GgmlOvDecoder::add_extra_inputs() {
+    // Extra inputs:
+    // 1. `attention_size`, used in FLASH_ATTN where the shapes of the matmuls are 256-aligned,
+    //     see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
+    // 2. `n_seq_active` and `seq_active_start`, used in FLASH_ATTN_EXT to indicate the active sequences in the batch
+
+    auto create_1d_input = [this](const std::string & name, int64_t value) {
+        if (m_is_static) {
+            auto constant =
+                std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{value});
+            constant->set_friendly_name(name);
+            m_model_extra_inputs[name] = constant;
+        } else {
+            auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
+            param_node->set_friendly_name(name);
+            param_node->output(0).get_tensor().set_names({name});
+            m_model_extra_inputs[name] = param_node;
+
+            auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
+            *tensor->data<int64_t>() = value;
+            m_model_extra_input_values[name] = tensor;
+        }
+    };
+
+    create_1d_input("attention_size", m_compute_params.attention_size);
+    if (m_compute_params.attention_size_swa != -1) {
+        create_1d_input("attention_size_swa", m_compute_params.attention_size_swa);
+    }
+    create_1d_input("n_seq_active", m_compute_params.n_seq_active);
+    create_1d_input("seq_active_start", m_compute_params.seq_active_start);
+    create_1d_input("seq_active_end", m_compute_params.seq_active_start + m_compute_params.n_seq_active);
+    create_1d_input("token_len_per_seq", m_compute_params.token_len_per_seq);
+    // create_1d_input("token_len", m_token_len_per_seq * m_n_seq_active);
+}
+
+bool GgmlOvDecoder::node_is_used_as_src(const int node_idx) {
+    ggml_tensor * node = m_cgraph->nodes[node_idx];
+    for (int i = node_idx; i < m_cgraph->n_nodes; i++) {
+        ggml_tensor * other_node = m_cgraph->nodes[i];
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            if (other_node->src[j] == node) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+void GgmlOvDecoder::compute_model_inputs() {
+    m_model_inputs.clear();
+    m_inputs.clear();
+    for (int i = 0; i < m_cgraph->n_nodes; i++) {
+        ggml_tensor * node = m_cgraph->nodes[i];
+        // A node with op NONE may be used as an input by later nodes; if so, add it to the model inputs here.
+        if (node->op == GGML_OP_NONE && node_is_used_as_src(i)) {
+            std::string node_name(node->name);
+            if (m_model_weights.find(node_name) == m_model_weights.end()) {
+                m_inputs[node_name] = node;
+                auto param_node =
+                    std::make_shared<ov::op::v0::Parameter>(get_ov_type(node), get_graph_input_shape(node, nullptr));
+                param_node->set_friendly_name(node_name);
+                param_node->output(0).get_tensor().set_names({node_name});
+                m_model_inputs[node_name] = param_node;
+            }
+            continue;
+        }
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            auto * src = node->src[i];
+            if (src == nullptr) {
+                continue;
+            }
+            std::string src_name = std::string(src->name);
+            if (src->flags & GGML_TENSOR_FLAG_INPUT) {
+                src_name = get_graph_input_ov_name(src, node);
+            }
+            if (m_model_weights.find(src_name) != m_model_weights.end()) {
+                continue;
+            }
+
+            bool is_intermediate_node = false;
+            for (const auto & node_info : m_node_info_list) {
+                if (node_info.node == src) {
+                    is_intermediate_node = true;
+                    break;
+                }
+            }
+            if (is_intermediate_node) {
+                continue;
+            }
+            if (m_model_inputs.find(src_name) != m_model_inputs.end()) {
+                continue;
+            }
+
+            m_inputs[src_name] = src;
+
+            ggml_backend_buffer * buffer = src->buffer;
+            // Buffers with GGML_BACKEND_BUFFER_USAGE_ANY are KV caches
+            if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) {
+                if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name);
+                    it == m_model_params.kv_names.end()) {
+                    m_model_params.kv_names.push_back(src_name);
+                }
+            }
+            ov::PartialShape param_shape = get_graph_input_shape(node, src);
+            auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), param_shape);
+            param_node->set_friendly_name(src_name);
+            param_node->output(0).get_tensor().set_names({src_name});
+            m_model_inputs[src_name] = param_node;
+        }
+    }
+}
+
+void GgmlOvDecoder::compute_model_outputs() {
+    m_model_outputs.clear();
+    m_model_output_names.clear();
+    for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) {
+        auto * cur_node = m_cgraph->nodes[node_n];
+        // If the node op is NONE, the node is not used at all, so skip it without adding it to the model outputs.
+        if (cur_node->op == GGML_OP_NONE) {
+            continue;
+        }
+        auto cur_node_use_count = m_cgraph->use_counts[ggml_hash_find(&m_cgraph->visited_hash_set, cur_node)];
+        if (cur_node_use_count == 0) {
+            // The output of SET_ROWS is the view_src tensor, which is updated in place. We should use the view_src name as the output name to make sure it can be correctly matched with the later ops that use the view_src.
+            if (cur_node != nullptr && cur_node->op == GGML_OP_SET_ROWS) {
+                cur_node = cur_node->view_src;
+            }
+        } else {
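+            // If every use of this node is as a source of another node in the graph, it is an intermediate result rather than a graph output.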
+            int input_use_count = 0;
+            for (int i = 0; i < m_cgraph->n_nodes; i++) {
+                ggml_tensor * node = m_cgraph->nodes[i];
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != NULL && node->src[j] == cur_node) {
+                        input_use_count++;
+                    }
+                }
+            }
+            if (input_use_count == cur_node_use_count) {
+                cur_node = nullptr;
+            }
+        }
+        if (cur_node != nullptr) {
+            std::string node_output_name(cur_node->name);
+            m_model_outputs[node_output_name] = cur_node;
+            m_model_output_names.push_back(node_output_name);
+        }
+    }
+}
+
+const ggml_tensor * GgmlOvDecoder::get_tensor_used_op(const ggml_tensor * tensor) const {
+    if (tensor == nullptr) {
+        return nullptr;
+    }
+    for (int i = 0; i < m_cgraph->n_nodes; i++) {
+        const auto * node = m_cgraph->nodes[i];
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            if (node->src[j] == tensor) {
+                return node;
+            }
+        }
+    }
+    return nullptr;
+}
+
+const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name) const {
+    for (int i = 0; i < m_cgraph->n_nodes; i++) {
+        const auto * node = m_cgraph->nodes[i];
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            const auto * src = node->src[j];
+            if (src == nullptr) {
+                break;
+            }
+            if (std::string(src->name) == name) {
+                return src;
+            }
+        }
+    }
+    return nullptr;
+}
+
+std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
+    std::map<std::string, std::string> kv_param_res_names;
+    for (const auto & name : m_model_params.kv_names) {
+        kv_param_res_names[name] = name;
+    }
+    return kv_param_res_names;
+}
+
+std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) {
+    static std::mutex weights_mutex;
+    std::lock_guard<std::mutex> lock(weights_mutex);
+
+    std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
+    auto * nodes = cgraph->nodes;
+    auto n_nodes = cgraph->n_nodes;
+    for (int node_i = 0; node_i < n_nodes; node_i++) {
+        auto * node = nodes[node_i];
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            auto * src = node->src[i];
+            if (src == nullptr) {
+                continue;
+            }
+
+            std::string src_name(src->name);
+            if (is_rope_freqs_weight(src, node)) {
+                src_name = "rope_freqs.weight";
+            }
+            if (!src->view_src) {
+                ggml_backend_buffer * buffer = src->buffer;
+                if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS || ggml_is_quantized(src->type)) {
+                    if (model_weights.find(src_name) == model_weights.end()) {
+                        auto weight_node = create_weight_node(src, naive);
+                        weight_node->set_friendly_name(src_name);
+                        model_weights[src_name] = weight_node;
+                    }
+                }
+            }
+        }
+    }
+    return model_weights;
+}
+
+std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor * tensor, bool naive) {
+    const bool is_ov_buffer = ggml_backend_buffer_is_openvino(tensor->buffer);
+
+    // Check if we have a pre-built constant from the OpenVINO backend buffer
+    // This is set during ggml_backend_openvino_buffer_set_tensor
+    if (tensor->extra) {
+        OPENVINO_ASSERT(is_ov_buffer, "Unsupported weight tensor: " + std::string(tensor->name) +
+                                          " Possibly this is a cpu backend repacked quantized weights");
+        // Cast to our extra base type and check the type
+        auto * extra_base = static_cast<ggml_openvino_extra_base *>(tensor->extra);
+
+        if (extra_base->type == ggml_openvino_extra_base::Type::WEIGHT) {
+            // F16/F32/BF16 weight with shared-memory constant
+            auto * weight_extra = static_cast<ggml_openvino_weight_extra *>(tensor->extra);
+            if (weight_extra->weight_node) {
+                // GGML_LOG_DEBUG("%s: using pre-built weight node for %s\n", __func__, tensor->name);
+                return weight_extra->weight_node;
+            }
+        } else if (extra_base->type == ggml_openvino_extra_base::Type::QUANTIZED_WEIGHT) {
+            // Quantized weight with pre-extracted data
+            auto * quant_extra = static_cast<ggml_openvino_quantized_weight_extra *>(tensor->extra);
+            if (quant_extra->weight_node) {
+                // GGML_LOG_DEBUG("%s: using pre-extracted quantized weight node for %s\n", __func__, tensor->name);
+                return quant_extra->weight_node;
+            }
+        }
+    }
+
+    // There are three cases where we need to create a new weight node:
+    // 1. weights are in openvino_host_buffer; loading weights into a host buffer does not trigger backend_buffer_set_tensor
+    // 2. weights are in a cpu/cpu_mapped buffer; token_embd.weight goes to case 1 or 2, depending on whether mmap or direct_io is used
+    // 3. test-backend-ops; its buffers do not set USAGE_WEIGHTS, so backend_buffer_set_tensor will not create a weight node
+
+    // GGML_LOG_DEBUG("%s: creating new weight node for %s\n", __func__, tensor->name);
+    static const std::set<ggml_type> weight_types = {GGML_TYPE_F32,  GGML_TYPE_F16,  GGML_TYPE_BF16,
+                                                      GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+                                                      GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K};
+    if (weight_types.find(tensor->type) == weight_types.end()) {
+        throw std::runtime_error("Unexpected weight tensor type: " + std::string(tensor->name) + " with type " +
+                                 ggml_type_name(tensor->type));
+    }
+
+    OvWeight ov_weight;
+    if (ggml_is_quantized(tensor->type)) {
+        auto use_bias = naive;
+        if (is_ov_buffer) {
+            // For quantized weights, copy raw data to a temp buffer first because
+            // process_weight_tensor reads from data and writes extracted results
+            // (weights/scales/zp) to output_base_ptr — they would overlap if both
+            // point to tensor->data.
+            size_t raw_size = ggml_nbytes(tensor);
+            std::vector<uint8_t> tmp(raw_size);
+            memcpy(tmp.data(), tensor->data, raw_size);
+            ov_weight = process_weight_tensor(tensor, tmp.data(), tensor->data, use_bias);
+        } else {
+            ov_weight = process_weight_tensor(tensor, tensor->data, nullptr, use_bias);
+        }
+    } else {
+        // For non-quantized weights (F16/F32/BF16), data is already in tensor->data.
+        // process_weight_tensor will create an ov::Tensor wrapping tensor->data directly.
+        ov_weight = process_weight_tensor(tensor, tensor->data, tensor->data);
+    }
+
+    ov_weight.weight_node->set_friendly_name(tensor->name);
+    if (!is_ov_buffer) {
+        return ov_weight.weight_node;
+    }
+
+    ggml_openvino_extra_base * extra;
+    if (ov_weight.is_quantized()) {
+        extra = new ggml_openvino_quantized_weight_extra(std::move(ov_weight.weights), std::move(ov_weight.scales),
+                                                         std::move(ov_weight.zp), ov_weight.weight_node);
+    } else {
+        extra = new ggml_openvino_weight_extra(std::move(ov_weight.weights), ov_weight.weight_node);
+    }
+    ggml_openvino_buffer_register_extra(tensor, extra);
+
+    return ov_weight.weight_node;
+}
+
+void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filename) {
+    std::ofstream file(filename);
+    if (!file.is_open()) {
+        std::cerr << "Failed to open file" << std::endl;
+        return;
+    }
+
+    file << "=== GRAPH ===\n";
+
+    // clang-format off
+    file << "n_nodes = " << cgraph->n_nodes << "\n";
+    file << " " << std::setw(3) << "nodes"
+                <<  std::setw(15) << "shape"
+                << std::setw(20) << "op"
+                << std::setw(20) << "name"
+                << std::setw(3) << "    "
+                << std::setw(62) << "stride"
+                << std::setw(20) << "buffer_type"
+                << "\n";
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        // Get buffer type name
+        const char * buf_name = "none";
+        ggml_backend_buffer_t buf = node->view_src ? node->view_src->buffer : node->buffer;
+        if (buf) {
+            buf_name = ggml_backend_buffer_name(buf);
+        }
+
+        file << " - " << std::setw(3) << i << ": [ "
+             << std::setw(5) << node->ne[0] << ", "
+             << std::setw(5) << node->ne[1] << ", "
+             << std::setw(5) << node->ne[2] << ", "
+             << std::setw(5) << node->ne[3] << "] "
+             << std::left << std::setw(20) << ggml_op_name(node->op) << std::right << " "
+             << std::left << std::setw(45) << node->name << std::right
+             << std::setw(2) << "[ "
+             << std::setw(0) << node->nb[0] << ", "
+             << std::setw(5) << node->nb[1] << ", "
+             << std::setw(5) << node->nb[2] << ", "
+             << std::setw(5) << node->nb[3] << "] "
+             << std::right << std::setw(15) << buf_name << std::right
+             << "\n";
+
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (auto* src = node->src[i]) {
+                // Get buffer type name for source
+                const char * src_buf_name = "none";
+                ggml_backend_buffer_t src_buf = src->view_src ? src->view_src->buffer : src->buffer;
+                if (src_buf) {
+                    src_buf_name = ggml_backend_buffer_name(src_buf);
+                }
+
+                file << std::setw(10) << " [ "
+                << std::setw(5) << src->ne[0] << ", "
+                << std::setw(5) << src->ne[1] << ", "
+                << std::setw(5) << src->ne[2] << ", "
+                << std::setw(5) << src->ne[3] << "] "
+                << std::setw(12)
+                << i << ": " << std::left << std::setw(12) << ggml_op_name(src->op) << std::right;
+                file << std::left << std::setw(30) << src->name << std::right
+                << std::setw(16) << "[ "
+                << std::setw(0) << src->nb[0] << ", "
+                << std::setw(5) << src->nb[1] << ", "
+                << std::setw(5) << src->nb[2] << ", "
+                << std::setw(5) << src->nb[3] << "] "
+                << std::right << std::setw(15) << src_buf_name << std::right
+                << "\n";
+            }
+        }
+    }
+
+    file << "n_leafs = " << cgraph->n_leafs << "\n";
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        ggml_tensor * node = cgraph->leafs[i];
+
+        // Get buffer type name for leaf
+        const char * leaf_buf_name = "none";
+        ggml_backend_buffer_t leaf_buf = node->view_src ? node->view_src->buffer : node->buffer;
+        if (leaf_buf) {
+            leaf_buf_name = ggml_backend_buffer_name(leaf_buf);
+        }
+
+        file << " - " << std::setw(3) << i << ": [ "
+             << std::setw(5) << node->ne[0] << ", "
+             << std::setw(5) << node->ne[1] << "] "
+             << std::setw(8) << ggml_op_name(node->op) << " "
+             << std::setw(16) << ggml_get_name(node)
+             << std::setw(20) << leaf_buf_name << "\n";
+    }
+    // clang-format on
+    file << "========================================\n";
+
+    file.close();
+}
+
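+// Debug helper: groups the nodes of a cgraph by their data pointer and prints the tensor
+// names that share each address, which makes aliasing through views easy to spot.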
+void print_tensor_address_map(const ggml_cgraph * cgraph) {
+    std::map<void *, std::vector<std::string>> address_map;
+    for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
+        auto * node = cgraph->nodes[node_n];
+        if (node->data) {
+            auto it = address_map.find(node->data);
+            if (it == address_map.end()) {
+                address_map[node->data] = std::vector<std::string>();
+            }
+            address_map[node->data].push_back(node->name);
+        }
+    }
+    for (const auto & pair : address_map) {
+        std::cout << "Address: " << pair.first << std::endl;
+        for (const auto & name : pair.second) {
+            std::cout << name << " ; ";
+        }
+        std::cout << std::endl << std::endl;
+    }
+}
+
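+// ggml stores ne[0]/nb[0] as the innermost (fastest-varying) dimension, while ov::Shape is
+// listed outermost-first, so shapes and strides are built by walking the ggml dims in reverse.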
+ov::Shape GgmlOvDecoder::get_shape(const ggml_tensor * tensor) {
+    std::vector<size_t> shape;
+    for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
+        shape.push_back(static_cast<size_t>(tensor->ne[i]));
+    }
+    return shape;
+}
+
+std::vector<size_t> GgmlOvDecoder::get_stride(const ggml_tensor * tensor) {
+    std::vector<size_t> stride;
+    for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
+        stride.push_back(static_cast<size_t>(tensor->nb[i]));
+    }
+    return stride;
+}
+
+ov::element::Type GgmlOvDecoder::get_ov_type(const ggml_tensor * tensor) {
+    switch (tensor->type) {
+    case GGML_TYPE_F64:
+        return ov::element::f64;
+    case GGML_TYPE_F32:
+        return ov::element::f32;
+    case GGML_TYPE_F16:
+        return ov::element::f16;
+    case GGML_TYPE_BF16:
+        return ov::element::bf16;
+    case GGML_TYPE_I8:
+        return ov::element::i8;
+    case GGML_TYPE_I16:
+        return ov::element::i16;
+    case GGML_TYPE_I32:
+        return ov::element::i32;
+    case GGML_TYPE_I64:
+        return ov::element::i64;
+    default:
+        return ov::element::dynamic;
+    }
+}
+
+ov::PartialShape GgmlOvDecoder::get_input_shape(int node_idx, const std::string & name) const {
+    return ov::PartialShape(get_shape(m_node_info_list[node_idx].node_inputs.at(name)));
+}
+
+std::vector<size_t> GgmlOvDecoder::get_input_stride(int node_idx, const std::string & name) const {
+    return get_stride(m_node_info_list[node_idx].node_inputs.at(name));
+}
+
+ov::element::Type GgmlOvDecoder::get_input_type(int node_idx, const std::string & name) const {
+    return get_ov_type(m_node_info_list[node_idx].node_inputs.at(name));
+}
+
+size_t GgmlOvDecoder::get_input_size() const {
+    return m_model_inputs.size();
+}
+
+size_t GgmlOvDecoder::get_input_size(int node_idx) const {
+    return m_node_info_list[node_idx].node_inputs_names.size();
+}
+
+std::vector<std::string> GgmlOvDecoder::get_input_names(int node_idx) const {
+    return m_node_info_list[node_idx].node_inputs_names;
+}
+
+ov::PartialShape GgmlOvDecoder::get_output_shape(int node_idx) const {
+    auto * output_tensor = m_node_info_list[node_idx].node_output;
+    return ov::PartialShape(get_shape(output_tensor));
+}
+
+ov::element::Type GgmlOvDecoder::get_output_type(const int node_idx) const {
+    return get_ov_type(m_node_info_list[node_idx].node);
+}
+
+std::vector<std::string> GgmlOvDecoder::get_output_names(int node_idx) const {
+    return {m_node_info_list[node_idx].node_output_name};
+}
+
+const std::string & GgmlOvDecoder::get_op_name() const {
+    static const std::string unknown_name = "UNKNOWN_OP_NAME";
+    return unknown_name;
+}
+
+const std::string & GgmlOvDecoder::get_op_name(int node_idx) const {
+    return m_node_info_list[node_idx].node_name;
+}
+
+int32_t * GgmlOvDecoder::get_input_op_params(int node_idx, const std::string & name) const {
+    return m_node_info_list[node_idx].node_inputs.at(name)->op_params;
+}
+
+int32_t * GgmlOvDecoder::get_output_op_params(int node_idx) const {
+    return m_node_info_list[node_idx].node->op_params;
+}
+
+void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const {
+    for (int node_idx = 0; node_idx < m_cgraph->n_nodes; node_idx++) {
+        if (m_cgraph->nodes[node_idx]->op == GGML_OP_NONE) {
+            continue;
+        }
+        node_visitor(std::make_shared<GgmlOvDecoder>(*this), node_idx);
+    }
+}
+
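+// Returns a canonical op-type string for a node; GGML_OP_UNARY and GGML_OP_GLU are resolved
+// to their specific unary/GLU sub-op names.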
+std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) {
+    static const std::map<ggml_op, std::string> ops = {
+        {GGML_OP_NONE,           "GGML_OP_NONE"          },
+        {GGML_OP_ACC,            "GGML_OP_ACC"           },
+        {GGML_OP_ADD,            "GGML_OP_ADD"           },
+        {GGML_OP_ADD1,           "GGML_OP_ADD1"          },
+        {GGML_OP_CONT,           "GGML_OP_CONT"          },
+        {GGML_OP_DIV,            "GGML_OP_DIV"           },
+        {GGML_OP_DUP,            "GGML_OP_DUP"           },
+        {GGML_OP_GET_ROWS,       "GGML_OP_GET_ROWS"      },
+        {GGML_OP_MUL,            "GGML_OP_MUL"           },
+        {GGML_OP_MUL_MAT,        "GGML_OP_MUL_MAT"       },
+        {GGML_OP_PERMUTE,        "GGML_OP_PERMUTE"       },
+        {GGML_OP_RESHAPE,        "GGML_OP_RESHAPE"       },
+        {GGML_OP_RMS_NORM,       "GGML_OP_RMS_NORM"      },
+        {GGML_OP_ROPE,           "GGML_OP_ROPE"          },
+        {GGML_OP_SCALE,          "GGML_OP_SCALE"         },
+        {GGML_OP_SOFT_MAX,       "GGML_OP_SOFT_MAX"      },
+        {GGML_OP_SUB,            "GGML_OP_SUB"           },
+        {GGML_OP_TRANSPOSE,      "GGML_OP_TRANSPOSE"     },
+        {GGML_OP_VIEW,           "GGML_OP_VIEW"          },
+        {GGML_OP_SET_ROWS,       "GGML_OP_SET_ROWS"      },
+        {GGML_OP_CPY,            "GGML_OP_CPY"           },
+        {GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT"},
+    };
+    static const std::map<ggml_unary_op, std::string> unary_ops = {
+        {GGML_UNARY_OP_ABS,         "GGML_UNARY_OP_ABS"        },
+        {GGML_UNARY_OP_SGN,         "GGML_UNARY_OP_SGN"        },
+        {GGML_UNARY_OP_NEG,         "GGML_UNARY_OP_NEG"        },
+        {GGML_UNARY_OP_STEP,        "GGML_UNARY_OP_STEP"       },
+        {GGML_UNARY_OP_TANH,        "GGML_UNARY_OP_TANH"       },
+        {GGML_UNARY_OP_ELU,         "GGML_UNARY_OP_ELU"        },
+        {GGML_UNARY_OP_RELU,        "GGML_UNARY_OP_RELU"       },
+        {GGML_UNARY_OP_SIGMOID,     "GGML_UNARY_OP_SIGMOID"    },
+        {GGML_UNARY_OP_GELU,        "GGML_UNARY_OP_GELU"       },
+        {GGML_UNARY_OP_GELU_QUICK,  "GGML_UNARY_OP_GELU_QUICK" },
+        {GGML_UNARY_OP_SILU,        "GGML_UNARY_OP_SILU"       },
+        {GGML_UNARY_OP_HARDSWISH,   "GGML_UNARY_OP_HARDSWISH"  },
+        {GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"},
+        {GGML_UNARY_OP_EXP,         "GGML_UNARY_OP_EXP"        },
+        {GGML_UNARY_OP_COUNT,       "GGML_UNARY_OP_COUNT"      }
+    };
+    static const std::map<ggml_glu_op, std::string> glu_ops = {
+        {GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"},
+        {GGML_GLU_OP_GEGLU,  "GGML_GLU_OP_GEGLU" },
+        {GGML_GLU_OP_REGLU,  "GGML_GLU_OP_REGLU" }
+    };
+
+    switch (node->op) {
+    case GGML_OP_UNARY:
+        return unary_ops.at(ggml_get_unary_op(node));
+    case GGML_OP_GLU:
+        return glu_ops.at(ggml_get_glu_op(node));
+    default:
+        return ops.at(node->op);
+    }
+}
+
+const std::string & GgmlOvDecoder::get_op_type(int node_idx) const {
+    return m_node_info_list[node_idx].node_op_type;
+}
+
+const std::string & GgmlOvDecoder::get_op_type() const {
+    static const std::string unknown_op = "UNKNOWN_GGML_OP";
+    return unknown_op;
+}
diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h
new file mode 100644
index 00000000..3ae25ddd
--- /dev/null
+++ b/ggml/src/ggml-openvino/ggml-decoder.h
@@ -0,0 +1,294 @@
+#pragma once
+
+#include "ggml-quants.h"
+#include "ggml.h"
+#include "openvino/decoder.h"
+
+#include <algorithm>
+#include <cstring>
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
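+// ModelParams captures per-model, compile-time properties (context sizes, head counts, RoPE
+// parameters, SWA layers); the can_reuse_* helpers decide whether an already compiled model can
+// serve a new graph. ComputeParams below holds the per-graph values that change between decodes.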
+struct ModelParams {
+    int ctx = -1;
+    int ctx_swa = -1;
+    int ctx_per_seq = -1;
+    int ctx_per_seq_swa = -1;
+    int n_seq = 1;
+    int n_heads = -1;
+    int n_heads_kv = -1;
+    int head_size = -1;
+    int32_t rope_params[15];
+    std::vector<int> swa_layers;
+
+    std::vector<std::string> kv_names;
+    size_t kv_buffer_ctx_id = 0;
+
+    bool same_rope_params(const ModelParams & other) const {
+        return memcmp(rope_params, other.rope_params, sizeof(int32_t) * 15) == 0;
+    }
+
+    bool can_reuse_dynamically(const ModelParams & other) const { return same_rope_params(other); }
+
+    bool can_reuse_statically(const ModelParams & other) const { return same_rope_params(other) && ctx == other.ctx; }
+
+    bool kv_buffer_changed(const ModelParams & other) const { return kv_buffer_ctx_id != other.kv_buffer_ctx_id; }
+};
+
+struct ComputeParams {
+    int n_seq_active = 1;
+    int seq_active_start = 0;
+    int attention_size = -1;
+    int attention_size_swa = -1;
+    int input_len = -1;
+    int token_len_per_seq = -1;
+    int past_kv_len = -1;
+    int output_len = 1;
+};
+
+class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
+public:
+    struct NodeInfo {
+        ggml_tensor * node;
+        std::string node_name;
+        std::string node_op_type;
+        std::map<std::string, ggml_tensor *> node_inputs;
+        std::vector<std::string> node_inputs_names;
+        ggml_tensor * node_output;
+        std::string node_output_name;
+        int node_op_case = 0;
+        void * data_addr;
+    };
+    // Graph decoder
+    GgmlOvDecoder(ggml_cgraph * cgraph,
+                  ModelParams & model_params,
+                  ComputeParams & compute_params,
+                  std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
+                  bool is_static,
+                  bool is_stateful = false,
+                  bool is_prefill = false,
+                  int prefill_chunk_size = 256);
+
+    // Naive graph decoder
+    GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights);
+
+    virtual ov::Any get_attribute(const std::string & name) const override {
+        return nullptr;
+        GGML_UNUSED(name);
+    }
+
+    virtual ov::PartialShape get_input_shape(int node_idx, const std::string & name) const override;
+
+    virtual std::vector<size_t> get_input_stride(int node_idx, const std::string & name) const override;
+
+    virtual ov::element::Type get_input_type(int node_idx, const std::string & name) const override;
+
+    virtual size_t get_input_size() const override;
+
+    virtual size_t get_input_size(int node_idx) const override;
+
+    virtual void get_input_node(size_t input_port_idx,
+                                std::string & producer_name,
+                                std::string & producer_output_port_name,
+                                size_t & producer_output_port_index) const override {
+        GGML_UNUSED(input_port_idx);
+        GGML_UNUSED(producer_name);
+        GGML_UNUSED(producer_output_port_name);
+        GGML_UNUSED(producer_output_port_index);
+    }
+
+    virtual std::vector<std::string> get_input_names(int node_idx) const override;
+
+    virtual ov::PartialShape get_output_shape(int node_idx) const override;
+
+    virtual ov::element::Type get_output_type(int node_idx) const override;
+
+    virtual int32_t * get_input_op_params(int node_idx, const std::string & name) const override;
+
+    virtual int32_t * get_output_op_params(int node_idx) const override;
+
+    virtual std::vector<std::string> get_output_names(int node_idx) const override;
+
+    virtual const std::string & get_op_type() const override;
+
+    virtual const std::string & get_op_type(int node_idx) const override;
+
+    virtual const std::string & get_op_name() const override;
+
+    virtual const std::string & get_op_name(int node_idx) const override;
+
+    virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const override;
+
+    ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); }
+
+    virtual int get_op_case(int node_idx) const override { return m_node_info_list[node_idx].node_op_case; }
+
+    virtual const std::map<std::string, std::shared_ptr<ov::op::v0::Parameter>> & get_model_inputs() const override {
+        return m_model_inputs;
+    }
+
+    virtual const std::map<std::string, std::shared_ptr<ov::op::v0::Parameter>> & get_model_extra_inputs() const override {
+        return m_model_extra_inputs;
+    }
+
+    virtual const std::map<std::string, std::shared_ptr<ov::Tensor>> & get_model_extra_input_values() const {
+        return m_model_extra_input_values;
+    }
+
+    virtual const std::map<std::string, std::shared_ptr<ov::Node>> & get_model_weights() const override {
+        return m_model_weights;
+    }
+
+    virtual std::vector<std::string> get_model_output_names() const override {
+        return m_model_output_names;
+    }
+
+    const std::map<std::string, ggml_tensor *> & get_model_outputs() const { return m_model_outputs; }
+
+    virtual int get_ctx_size() const { return m_model_params.ctx; }
+
+    virtual int get_ctx_swa_size() const { return m_model_params.ctx_swa; }
+
+    virtual int get_ctx_per_seq() const { return m_model_params.ctx_per_seq; }
+
+    virtual int get_ctx_per_seq_swa() const { return m_model_params.ctx_per_seq_swa; }
+
+    virtual int get_n_seq() const { return m_model_params.n_seq; }
+
+    virtual int is_swa_layer(int layer) const override {
+        return std::find(m_model_params.swa_layers.begin(), m_model_params.swa_layers.end(), layer) !=
+               m_model_params.swa_layers.end();
+    }
+
+    int get_past_kv_len() const { return m_compute_params.past_kv_len; }
+
+    int get_input_len() const { return m_compute_params.input_len; }
+
+    virtual int32_t * get_rope_params() const override { return const_cast<int32_t *>(m_model_params.rope_params); }
+
+    virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
+
+    virtual bool is_static() const override { return m_is_static; }
+
+    virtual bool is_stateful() const override { return m_is_stateful; }
+
+    ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const;
+
+    static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);
+
+    static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor * tensor, bool naive = false);
+
+    static std::map<std::string, std::shared_ptr<ov::Node>> create_weight_nodes(ggml_cgraph * cgraph,
+                                                                                bool naive = false);
+
+    const ggml_tensor * get_tensor_used_op(const ggml_tensor * tensor) const;
+
+    const ggml_tensor * get_tensor_from_name(const std::string & name) const;
+
+    void clear_model_weights() { m_model_weights.clear(); }
+
+    static std::pair<ModelParams, ComputeParams> compute_llm_params(ggml_cgraph * cgraph, bool is_static);
+
+    ModelParams get_model_params() const { return m_model_params; }
+
+    ComputeParams get_compute_params() const { return m_compute_params; }
+
+    void set_model_params(const ModelParams & model_params) { m_model_params = model_params; }
+
+    void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; }
+
+    bool m_is_static = false;
+    bool m_is_stateful = false;
+    bool m_is_prefill = false;
+    bool m_naive = false;
+    int m_prefill_chunk_size = 0;
+
+    static ov::Shape get_shape(const ggml_tensor * tensor);
+    static std::vector get_stride(const ggml_tensor * tensor);
+    static ov::element::Type get_ov_type(const ggml_tensor * tensor);
+    static std::string compute_op_type(const ggml_tensor * node);
+    void add_extra_inputs();
+
+    void update_io(ggml_cgraph * cgraph);
+
+    inline static bool is_inp_tok(const ggml_tensor * tensor, const ggml_tensor * op) {
+        return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op == GGML_OP_NONE;
+    }
+
+    inline static bool is_inp_pos(const ggml_tensor * tensor, const ggml_tensor * op) {
+        return op->op == GGML_OP_ROPE && tensor == op->src[1];
+    }
+
+    inline static bool is_inp_emb(const ggml_tensor * tensor, const ggml_tensor * op) {
+        return tensor->op == GGML_OP_GET_ROWS && op->op == GGML_OP_RMS_NORM;
+    }
+
+    inline static bool is_inp_mask(const ggml_tensor * tensor, const ggml_tensor * op) {
+        return op->op == GGML_OP_CPY || (op->op == GGML_OP_FLASH_ATTN_EXT && tensor == op->src[3]);
+    }
+
+    inline static bool is_rope_freqs_weight(const ggml_tensor * tensor, const ggml_tensor * op) {
+        return op->op == GGML_OP_ROPE && tensor == op->src[2];
+    }
+
+    inline static bool is_kvcache(const ggml_tensor * tensor, const ggml_tensor * op) {
+        return op->op == GGML_OP_SET_ROWS && op->src[2] == tensor;
+    }
+
+    inline static bool is_kv_idx(const ggml_tensor * tensor, const ggml_tensor * op) {
+        return op->op == GGML_OP_SET_ROWS && op->src[1] == tensor;
+    }
+
+    inline static bool is_output_idx(const ggml_tensor * tensor, const ggml_tensor * op) {
+        return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op != GGML_OP_NONE;
+    }
+
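+    // Maps a graph input tensor to a stable OpenVINO input name, based on which op consumes
+    // it and in which source slot (see the is_* predicates above).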
+    static std::string get_graph_input_ov_name(const ggml_tensor * tensor, const ggml_tensor * op) {
+        if (is_inp_tok(tensor, op)) {
+            return "inp_tokens";
+        }
+        if (is_inp_pos(tensor, op)) {
+            return "inp_pos";
+        }
+        if (is_inp_emb(tensor, op)) {
+            return "embd";
+        }
+        if (is_output_idx(tensor, op)) {
+            return "inp_out_ids";
+        }
+        if (is_inp_mask(tensor, op)) {
+            return std::string(tensor->name).find("swa") == std::string::npos ? "self_kq_mask" : "self_kq_mask_swa";
+        }
+        return tensor->name;
+    }
+
+private:
+    void set_input_output();
+    int compute_op_case(const ggml_tensor * node) const;
+    bool node_is_used_as_src(const int node_idx);
+    void compute_model_inputs();
+    void compute_model_outputs();
+
+    void validate_cgraph() const;
+
+    ggml_cgraph * m_cgraph = nullptr;
+    std::map<std::string, ggml_tensor *> m_inputs;
+
+    std::map<std::string, std::shared_ptr<ov::op::v0::Parameter>> m_model_inputs;
+    std::map<std::string, std::shared_ptr<ov::op::v0::Parameter>> m_model_extra_inputs;
+    std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
+    std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
+    std::map<std::string, ggml_tensor *> m_model_outputs;
+    std::vector<std::string> m_model_output_names;
+    std::vector<NodeInfo> m_node_info_list;
+
+    ModelParams m_model_params;
+    ComputeParams m_compute_params;
+};
+
+void print_tensor_address_map(const ggml_cgraph * cgraph);
+
+int extract_layer_from_name(const std::string & name);
diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp
new file mode 100644
index 00000000..cc3cb458
--- /dev/null
+++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp
@@ -0,0 +1,373 @@
+#include "ggml-openvino-extra.h"
+
+#include "ggml-impl.h"
+#include "ggml.h"
+
+#include 
+#include 
+#include 
+#include 
+
+ov::Core & ov_singleton_core() {
+    static ov::Core core;
+    return core;
+}
+
+// =====================================================
+// Device Configuration Implementations
+// =====================================================
+
+void ggml_openvino_device_config::init() {
+    if (initialized) {
+        return;
+    }
+    device_name = getenv("GGML_OPENVINO_DEVICE") ? getenv("GGML_OPENVINO_DEVICE") : "CPU";
+    auto available_devices = ov_singleton_core().get_available_devices();
+    if (std::find(available_devices.begin(), available_devices.end(), device_name) == available_devices.end()) {
+        GGML_LOG_WARN("GGML OpenVINO Backend: device %s is not available, fallback to CPU\n", device_name.c_str());
+        device_name = "CPU";
+    }
+    is_npu = (device_name == "NPU");
+
+    auto * cache_dir = getenv("GGML_OPENVINO_CACHE_DIR");
+    if (device_name == "NPU") {
+        compile_config = {
+            {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES"   },
+            {"NPU_USE_NPUW",                      "YES"   },
+            {"NPUW_DEVICES",                      "NPU"   },
+            {"NPUW_FOLD",                         "YES"   },
+            {"NPUW_WEIGHTS_BANK",                 "shared"},
+            {"NPUW_FUNCALL_FOR_ALL",              "YES"   },
+            {"NPUW_FUNCALL_ASYNC",                "YES"   },
+            {"NPUW_DQ",                           "YES"   },
+            {"NPUW_DQ_FULL",                      "NO"    },
+        };
+        if (cache_dir) {
+            compile_config["NPUW_CACHE_DIR"] = cache_dir;
+        }
+    } else if (cache_dir) {
+        ov_singleton_core().set_property(ov::cache_dir(cache_dir));
+    }
+
+    // Initialize remote context with queue sharing for GPU
+    if (device_name == "GPU") {
+        // Create OpenCL context and queue
+        cl_int err;
+        cl_platform_id platform;
+        err = clGetPlatformIDs(1, &platform, nullptr);
+        if (err != CL_SUCCESS) {
+            GGML_LOG_ERROR("Failed to get OpenCL platform: %d\n", err);
+            return;
+        }
+
+        cl_device_id cl_device;
+        err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &cl_device, nullptr);
+        if (err != CL_SUCCESS) {
+            GGML_LOG_ERROR("Failed to get OpenCL device: %d\n", err);
+            return;
+        }
+
+        cl_context cl_ctx = clCreateContext(nullptr, 1, &cl_device, nullptr, nullptr, &err);
+        if (err != CL_SUCCESS) {
+            GGML_LOG_ERROR("Failed to create OpenCL context: %d\n", err);
+            return;
+        }
+
+        cl_queue = clCreateCommandQueueWithProperties(cl_ctx, cl_device, nullptr, &err);
+        if (err != CL_SUCCESS) {
+            GGML_LOG_ERROR("Failed to create OpenCL command queue: %d\n", err);
+            clReleaseContext(cl_ctx);
+            return;
+        }
+
+        // Create OpenVINO remote context with queue sharing
+        remote_context = ov::intel_gpu::ocl::ClContext(ov_singleton_core(), cl_queue);
+
+        // Release the context (queue keeps a reference)
+        clReleaseContext(cl_ctx);
+    } else if (device_name == "NPU") {
+        // remote tensor is not used for NPU yet
+        // remote_context = ov_singleton_core().get_default_context(device_name);
+    }
+
+    initialized = true;
+}
+
+ggml_openvino_device_config::~ggml_openvino_device_config() {
+    if (cl_queue != nullptr) {
+        clReleaseCommandQueue(cl_queue);
+        cl_queue = nullptr;
+    }
+}
+
+// Get the global device config singleton
+ggml_openvino_device_config & ggml_openvino_get_device_config() {
+    static ggml_openvino_device_config config;
+    return config;
+}
+
+// Initialize device config (call during backend init)
+void ggml_openvino_init_device_config() {
+    ggml_openvino_get_device_config().init();
+}
+
+// Get the device name
+const std::string & ggml_openvino_get_device_name() {
+    return ggml_openvino_get_device_config().device_name;
+}
+
+// Check if running on NPU
+bool ggml_openvino_is_npu() {
+    return ggml_openvino_get_device_config().is_npu;
+}
+
+// Get the remote context for the current device (returns empty optional for CPU)
+std::optional<ov::RemoteContext> ggml_openvino_get_remote_context() {
+    return ggml_openvino_get_device_config().remote_context;
+}
+
+// Get the compile config for the current device
+const ov::AnyMap & ggml_openvino_get_compile_config() {
+    return ggml_openvino_get_device_config().compile_config;
+}
+
+// Get the OpenCL command queue for GPU operations
+cl_command_queue ggml_openvino_get_cl_queue() {
+    return ggml_openvino_get_device_config().cl_queue;
+}
+
+// Get the clEnqueueMemFillINTEL function pointer (lazy load)
+clEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL() {
+    static clEnqueueMemFillINTEL_fn fn = nullptr;
+    static bool loaded = false;
+    if (!loaded) {
+        loaded = true;
+        cl_platform_id platform;
+        if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) {
+            fn = (clEnqueueMemFillINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, "clEnqueueMemFillINTEL");
+        }
+    }
+    return fn;
+}
+
+// Get the clEnqueueMemcpyINTEL function pointer (lazy load)
+clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL() {
+    static clEnqueueMemcpyINTEL_fn fn = nullptr;
+    static bool loaded = false;
+    if (!loaded) {
+        loaded = true;
+        cl_platform_id platform;
+        if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) {
+            fn = (clEnqueueMemcpyINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, "clEnqueueMemcpyINTEL");
+        }
+    }
+    return fn;
+}
+
+// Get requantization type for a tensor type (returns nullopt if no requant needed)
+std::optional<ExtraQuantType> ggml_openvino_get_requant_type(const ggml_tensor * tensor, bool no_requant) {
+    if (no_requant) {
+        return std::nullopt;
+    }
+    if (strncmp(tensor->name, "token_embd.weight", 17) == 0) {
+        return ((ggml_openvino_is_npu() && tensor->type == GGML_TYPE_Q6_K) ? ExtraQuantType::F16 : ExtraQuantType::Q8_0_C);
+    }
+    if (strncmp(tensor->name, "output.weight", 13) == 0) {
+        return ExtraQuantType::Q8_0_C;
+    }
+    if (ggml_openvino_is_npu()) {
+        return ExtraQuantType::Q4_0_128;
+    }
+    switch (tensor->type) {
+    case GGML_TYPE_Q6_K:
+    case GGML_TYPE_Q5_K:
+        return ExtraQuantType::Q8_0_C;
+    default:
+        return std::nullopt;
+    }
+}
+
+// =====================================================
+// Extracted Layout Calculation
+// =====================================================
+
+ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor, bool use_bias) {
+    ggml_openvino_extracted_layout layout = {};
+    layout.is_symmetric = false;
+
+    if (!ggml_is_quantized(tensor->type)) {
+        return layout;
+    }
+
+    // Only handle 2D weight tensors
+    if (tensor->ne[2] != 1 || tensor->ne[3] != 1) {
+        return layout;
+    }
+
+    int64_t n_elements = ggml_nelements(tensor);
+    const size_t alignment = 64;  // Good for SIMD
+
+    // Check if requantization is needed (NPU-specific)
+    auto requant_type = ggml_openvino_get_requant_type(tensor, use_bias);
+    if (requant_type.has_value()) {
+        layout.is_requant = true;
+        layout.requant_type = requant_type;
+
+        // Special case: requant to F16 - just store F16 weights, no scales/zp
+        if (requant_type.value() == ExtraQuantType::F16) {
+            layout.weights_size = n_elements * sizeof(uint16_t);  // F16 = 2 bytes
+            layout.total_size = layout.weights_size;
+            layout.weights_offset = 0;
+            // No scales/zp for F16
+            return layout;
+        }
+
+        // Requant to different quantized format (e.g., Q4_0_128)
+        switch (requant_type.value()) {
+        case ExtraQuantType::Q4_0_128:
+            layout.is_u4 = true;
+            layout.weights_per_block = 128;
+            layout.is_symmetric = true;
+            break;
+        case ExtraQuantType::Q4_0_C:
+            layout.is_u4 = true;
+            layout.weights_per_block = tensor->ne[0];
+            layout.is_symmetric = true;
+            break;
+        case ExtraQuantType::Q8_0_32:
+            layout.is_u4 = false;
+            layout.weights_per_block = 32;
+            layout.is_symmetric = true;
+            break;
+        case ExtraQuantType::Q8_0_C:
+            layout.is_u4 = false;
+            layout.weights_per_block = tensor->ne[0];
+            layout.is_symmetric = true;
+            break;
+        case ExtraQuantType::Q8_1_C:
+            layout.is_u4 = false;
+            layout.weights_per_block = tensor->ne[0];
+            break;
+        default:
+            layout.weights_per_block = -1;
+            GGML_ABORT("Code of re-quantizing to channel-wise is not updated");
+            break;
+        }
+
+        if (layout.is_requant) {
+            // Calculate sizes for requantized format
+            layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements;
+            int64_t n_blocks = n_elements / layout.weights_per_block;
+            layout.scales_size = n_blocks * sizeof(uint16_t);
+            // For symmetric quantization, we only need one zp value (not one per block)
+            // Zero points are stored in U4 or U8 format matching the weight type
+            size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks;
+            layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements;
+
+            layout.weights_offset = 0;
+            layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment;
+            layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment;
+            layout.total_size = layout.zp_offset + layout.zp_size;
+            layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor));
+            return layout;
+        }
+    }
+
+    // Normal extraction (no requant) - determine format based on tensor type
+    layout.is_u4 = false;
+    layout.weights_per_block = 32;
+    layout.is_symmetric = false;
+
+    switch (tensor->type) {
+    case GGML_TYPE_Q4_0:
+        layout.is_u4 = true;
+        layout.is_symmetric = true;
+        break;
+
+    case GGML_TYPE_Q4_1:
+    case GGML_TYPE_Q4_K:
+        layout.is_u4 = true;
+        break;
+
+    case GGML_TYPE_Q8_0:
+        layout.is_symmetric = true;
+        break;
+
+    case GGML_TYPE_Q6_K:
+        layout.weights_per_block = 16;
+        layout.is_symmetric = true;
+        break;
+
+    case GGML_TYPE_Q5_K:
+        break;
+
+    default:
+        // Unsupported quantization type
+        return layout;
+    }
+
+    // Calculate sizes
+    // Weights: U4 = n_elements/2 bytes, U8 = n_elements bytes
+    layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements;
+
+    // Scales: F16 per block
+    int64_t n_blocks = n_elements / layout.weights_per_block;
+    layout.scales_size = n_blocks * sizeof(uint16_t);  // F16 = 2 bytes
+    // Zero points: U4 or U8 matching weight type
+    // For symmetric quantization, we only need one zp value (not one per block)
+    size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks;
+    layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements;
+
+    // Layout in buffer: [weights | scales | zp] with alignment
+    layout.weights_offset = 0;
+    layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment;
+    layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment;
+    layout.total_size = layout.zp_offset + layout.zp_size;
+    layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor));
+
+    return layout;
+}
+
+ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote) {
+    ov::Shape shape;
+    for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
+        shape.push_back(static_cast<size_t>(tensor->ne[i]));
+    }
+
+    ov::element::Type element_type;
+    switch (tensor->type) {
+    case GGML_TYPE_F32:
+        element_type = ov::element::f32;
+        break;
+    case GGML_TYPE_F16:
+        element_type = ov::element::f16;
+        break;
+    case GGML_TYPE_BF16:
+        element_type = ov::element::bf16;
+        break;
+    case GGML_TYPE_I32:
+        element_type = ov::element::i32;
+        break;
+    case GGML_TYPE_I64:
+        element_type = ov::element::i64;
+        break;
+    default:
+        // GGML_LOG_WARN("%s: unsupported tensor type for ov::Tensor: %s\n", __func__, ggml_type_name(tensor->type));
+        return nullptr;
+    }
+
+    const auto & device_name = ggml_openvino_get_device_name();
+    auto remote_context = ggml_openvino_get_remote_context();
+
+    std::shared_ptr<ov::Tensor> ov_tensor;
+    if (is_remote) {
+        GGML_ASSERT(device_name == "GPU");
+        auto gpu_context = remote_context->as<ov::intel_gpu::ocl::ClContext>();
+        auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data);
+        ov_tensor = std::make_shared<ov::Tensor>(std::move(usm_tensor));
+    } else {
+        ov_tensor = std::make_shared<ov::Tensor>(element_type, shape, tensor->data);
+    }
+
+    return new ggml_openvino_tensor_extra(ov_tensor);
+}
diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h
new file mode 100644
index 00000000..cd0baf4a
--- /dev/null
+++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h
@@ -0,0 +1,182 @@
+#pragma once
+
+#include "ggml.h"
+#include "openvino/runtime/core.hpp"
+
+#define CL_TARGET_OPENCL_VERSION 300
+#include <CL/cl.h>
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// ExtraQuantType enum - defines requantization target formats
+enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 };
+
+ov::Core & ov_singleton_core();
+
+// Get the remote context for the current device (returns empty optional for CPU)
+std::optional<ov::RemoteContext> ggml_openvino_get_remote_context();
+
+// Get the compile config for the current device
+const ov::AnyMap & ggml_openvino_get_compile_config();
+
+// Get the OpenCL command queue for GPU operations (returns nullptr for CPU/NPU)
+cl_command_queue ggml_openvino_get_cl_queue();
+
+// Intel USM extension function type
+typedef cl_int(CL_API_CALL * clEnqueueMemFillINTEL_fn)(cl_command_queue queue,
+                                                       void * dst_ptr,
+                                                       const void * pattern,
+                                                       size_t pattern_size,
+                                                       size_t size,
+                                                       cl_uint num_events_in_wait_list,
+                                                       const cl_event * event_wait_list,
+                                                       cl_event * event);
+
+typedef cl_int(CL_API_CALL * clEnqueueMemcpyINTEL_fn)(cl_command_queue queue,
+                                                      cl_bool blocking,
+                                                      void * dst_ptr,
+                                                      const void * src_ptr,
+                                                      size_t size,
+                                                      cl_uint num_events_in_wait_list,
+                                                      const cl_event * event_wait_list,
+                                                      cl_event * event);
+
+// Get the clEnqueueMemFillINTEL function pointer (returns nullptr if not available)
+clEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL();
+
+// Get the clEnqueueMemcpyINTEL function pointer (returns nullptr if not available)
+clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL();
+
+// =====================================================
+// Global Device Configuration (singleton)
+// =====================================================
+// Initialized once during backend init from GGML_OPENVINO_DEVICE env var
+
+struct ggml_openvino_device_config {
+    std::string device_name = "CPU";
+    bool is_npu = false;
+    bool initialized = false;
+    std::optional<ov::RemoteContext> remote_context;
+    ov::AnyMap compile_config;
+    cl_command_queue cl_queue = nullptr;
+
+    void init();
+    ~ggml_openvino_device_config();
+};
+
+// Get the global device config singleton
+ggml_openvino_device_config & ggml_openvino_get_device_config();
+
+// Initialize device config (call during backend init)
+void ggml_openvino_init_device_config();
+
+// Get the device name
+const std::string & ggml_openvino_get_device_name();
+
+// Check if running on NPU
+bool ggml_openvino_is_npu();
+
+// Get requantization type for a tensor type (returns nullopt if no requant needed)
+std::optional<ExtraQuantType> ggml_openvino_get_requant_type(const ggml_tensor * tensor, bool no_requant = false);
+
+// =====================================================
+// OpenVINO Tensor Extra Types
+// =====================================================
+// These types are stored in tensor->extra by the OpenVINO backend buffer.
+// They allow:
+// 1. Pre-built ov::Constant nodes for weights (avoiding memcpy during graph construction)
+// 2. ov::Tensor wrappers for KV cache / compute tensors (for direct use with infer_request)
+
+// Base class for OpenVINO tensor extra data
+struct ggml_openvino_extra_base {
+    enum class Type { WEIGHT, QUANTIZED_WEIGHT, TENSOR };
+    Type type;
+    virtual ~ggml_openvino_extra_base() = default;
+protected:
+    explicit ggml_openvino_extra_base(Type t) : type(t) {}
+};
+
+// Extra data for F16/F32/BF16 weight tensors - stores the pre-built weight node
+struct ggml_openvino_weight_extra : public ggml_openvino_extra_base {
+    ov::Tensor weights;                     // The underlying weight data tensor
+    std::shared_ptr<ov::Node> weight_node;  // Pre-built OpenVINO weight node
+
+    ggml_openvino_weight_extra(ov::Tensor w, std::shared_ptr<ov::Node> n) :
+        ggml_openvino_extra_base(Type::WEIGHT),
+        weights(std::move(w)),
+        weight_node(std::move(n)) {}
+};
+
+// Extra data for quantized weight tensors - stores extracted weights/scales/zp and weight node
+struct ggml_openvino_quantized_weight_extra : public ggml_openvino_extra_base {
+    ov::Tensor weights;   // U4 or U8 extracted weights
+    ov::Tensor scales;    // F16 scales
+    ov::Tensor zp;        // U4 or U8 zero points (same type as weights)
+    std::shared_ptr<ov::Node> weight_node;  // Pre-built OpenVINO weight subgraph
+
+    ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor z, std::shared_ptr<ov::Node> n) :
+        ggml_openvino_extra_base(Type::QUANTIZED_WEIGHT),
+        weights(std::move(w)),
+        scales(std::move(s)),
+        zp(std::move(z)),
+        weight_node(std::move(n)) {}
+};
+
+// Extra data for KV cache / compute tensors - stores ov::Tensor for infer_request
+struct ggml_openvino_tensor_extra : public ggml_openvino_extra_base {
+    std::shared_ptr<ov::Tensor> tensor;  // For direct use with infer_request
+
+    explicit ggml_openvino_tensor_extra(std::shared_ptr<ov::Tensor> t)
+        : ggml_openvino_extra_base(Type::TENSOR), tensor(std::move(t)) {}
+};
+
+// =====================================================
+// Extracted Size Calculation for Quantized Tensors
+// =====================================================
+// For quantized tensors, we need extra space to store extracted weights, scales, and zero points.
+// Returns the total size needed in the buffer for extracted data.
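+// Worked example (sketch): a 4096 x 4096 GGML_TYPE_Q4_0 tensor has 16,777,216 elements, so the
+// extracted U4 weights take 8 MiB, its 524,288 blocks of 32 need 1 MiB of F16 scales, and the
+// symmetric quantization needs only a single shared zero point; each section then starts at an
+// alignment-padded offset.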
+
+struct ggml_openvino_extracted_layout {
+    size_t total_size = 0;      // Total bytes needed
+    size_t weights_offset = 0;  // Offset to weights in buffer
+    size_t weights_size = 0;    // Size of weights in bytes
+    size_t scales_offset = 0;   // Offset to scales in buffer
+    size_t scales_size = 0;     // Size of scales in bytes
+    size_t zp_offset = 0;       // Offset to zero points in buffer
+    size_t zp_size = 0;         // Size of zero points in bytes (U4 or U8)
+    bool is_u4;                 // true for U4 weights, false for U8
+    int64_t weights_per_block;  // weights per scale/zp block
+    bool is_symmetric;        // true for symmetric quantization
+
+    // Requantization info
+    bool is_requant = false;                      // true if this tensor needs requantization
+    std::optional<ExtraQuantType> requant_type;  // target requant type if is_requant
+};
+
+// Calculate the buffer layout for extracted quantized data
+ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor, bool use_bias = false);
+
+ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote);
+
+// Register an extra with the tensor's OpenVINO buffer context for proper lifetime management.
+// This sets tensor->extra and tracks the extra in the buffer context for cleanup.
+void ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra);
+
+// =====================================================
+// OpenVINO Backend Context and Interface
+// =====================================================
+struct ggml_backend_openvino_context {
+    int device = 0;
+    std::string name = "OpenVINO";
+    std::string description = "OpenVINO Backend Context";
+
+    std::shared_ptr runtime_context = nullptr;
+
+    ggml_backend_openvino_context() = default;
+};
diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp
new file mode 100644
index 00000000..0031cb73
--- /dev/null
+++ b/ggml/src/ggml-openvino/ggml-openvino.cpp
@@ -0,0 +1,1110 @@
+#include "ggml-openvino.h"
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "ggml-openvino-extra.h"
+#include "ggml-openvino/utils.h"
+#include "ggml-quants.h"
+#include "ggml.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if defined(_WIN32)
+#    define WIN32_LEAN_AND_MEAN
+#    ifndef NOMINMAX
+#        define NOMINMAX
+#    endif
+#    include <windows.h>
+#else
+#    include <unistd.h>
+#endif
+
+// =====================================================
+// OpenVINO Buffer Implementation using ov::Tensor
+// =====================================================
+//
+// Design: This implementation uses a hybrid approach:
+// 1. For weight tensors: Store a pre-built ov::op::v0::Constant in tensor->extra
+//    - This avoids the memcpy during graph construction
+//    - For quantized weights, the constant is already converted to OpenVINO format
+// 2. For KV cache / compute tensors: Store an ov::Tensor in tensor->extra
+//    - This can be directly passed to infer_request
+//    - Future: can be changed to ov::RemoteTensor for GPU/NPU
+//
+// This design is similar to:
+// - CUDA split buffer: tensor->extra stores device pointers
+// - CPU repack buffer: tensor->extra stores tensor_traits with repacked data
+// =====================================================
+
+// Buffer context that manages per-tensor allocations (no contiguous buffer for weights)
+struct ggml_backend_openvino_buffer_context {
+    int device;
+    std::string name;
+    size_t id;
+
+    // For non-weight buffers (KV cache, compute), we still use contiguous allocation
+    void * data;
+    size_t size;
+    bool is_remote;
+
+    // Wrapping of the buffer
+    std::shared_ptr<ov::Tensor> ov_buffer;
+
+    // Track all extras for cleanup
+    std::map<ggml_tensor *, ggml_openvino_extra_base *> tensor_extras;
+
+    // Used for re-allocation on device for kvcache
+    void * data_prev;
+
+    ggml_backend_openvino_buffer_context(int device, size_t size, bool is_remote = false) :
+        device(device),
+        name(std::string(GGML_OPENVINO_NAME) + std::to_string(device)),
+        id([]() {
+            static std::atomic<size_t> next_id{1};
+            return next_id.fetch_add(1);
+        }()),
+        data(nullptr),
+        size(size),
+        is_remote(is_remote) {
+        if (size == 0) {
+            return;
+        }
+
+        const auto & device_name = ggml_openvino_get_device_name();
+
+        if (is_remote) {
+            GGML_ASSERT(device_name == "GPU");
+            auto remote_context = ggml_openvino_get_remote_context();
+            auto gpu_context = remote_context->as<ov::intel_gpu::ocl::ClContext>();
+            ov::intel_gpu::ocl::USMTensor usm_tensor =
+                gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size});
+            data = usm_tensor.get();
+            ov_buffer = std::make_shared<ov::Tensor>(std::move(usm_tensor));
+        } else {
+            data = ggml_aligned_malloc(size);
+            ov_buffer = std::make_shared<ov::Tensor>(ov::element::u8, ov::Shape{size}, data);
+        }
+
+        if (data == nullptr) {
+            GGML_LOG_ERROR("%s: failed to allocate %zu bytes\n", __func__, size);
+            return;
+        }
+
+        if (reinterpret_cast<uintptr_t>(data) % TENSOR_ALIGNMENT != 0) {
+            GGML_LOG_ERROR("%s: %s buffer is not aligned to %d bytes\n", __func__, device_name.c_str(),
+                           TENSOR_ALIGNMENT);
+            GGML_ABORT("fatal error");
+        }
+    }
+
+    ~ggml_backend_openvino_buffer_context() {
+        // Clean up all tensor extras
+        // GGML_LOG_DEBUG("Deleting OpenVINO buffer context #%zu for device %d, size %zu MB\n", id, device,
+        //                size / 1024 / 1024);
+        for (auto & pair : tensor_extras) {
+            delete pair.second;
+        }
+        tensor_extras.clear();
+        if (!is_remote && data != nullptr) {
+            ggml_aligned_free(data, size);
+        }
+    }
+};
+
+// Buffer type context (per-device)
+struct ggml_backend_openvino_buffer_type_context {
+    int device;
+    std::string name;
+};
+
+// Buffer interface functions
+static void ggml_backend_openvino_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+    delete ctx;
+}
+
+static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer) {
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+    return ctx->data;
+}
+
+static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name);
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+
+    // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache)
+    if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" &&
+        !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) {
+        GGML_ASSERT(ctx->tensor_extras.empty());
+        auto device = ctx->device;
+        auto size = ctx->size;
+        auto * data_prev = ctx->data;
+        delete ctx;
+        ctx = new ggml_backend_openvino_buffer_context(device, size, true);
+        buffer->context = ctx;
+        tensor->data = (char *) ctx->data + ((char *) tensor->data - (char *) data_prev);
+    }
+
+    // Views share the extra from view_src
+    if (tensor->view_src != nullptr) {
+        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
+        if (tensor->view_src->extra != nullptr) {
+            tensor->extra = tensor->view_src->extra;
+        }
+        return GGML_STATUS_SUCCESS;
+    }
+
+    ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+
+    if (tensor->data != nullptr && !ggml_is_quantized(tensor->type)) {
+        ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote);
+        if (extra != nullptr) {
+            auto it = ctx->tensor_extras.find(tensor);
+            if (it != ctx->tensor_extras.end()) {
+                delete it->second;
+            }
+            ctx->tensor_extras[tensor] = extra;
+            tensor->extra = extra;
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_openvino_buffer_memset_tensor(ggml_backend_buffer_t buffer,
+                                                       ggml_tensor * tensor,
+                                                       uint8_t value,
+                                                       size_t offset,
+                                                       size_t size) {
+    // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name);
+    GGML_ASSERT(tensor != nullptr && tensor->data != nullptr);
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+
+    if (ctx->is_remote) {
+        // For remote (device) buffers, use OpenCL USM memfill
+        cl_command_queue queue = ggml_openvino_get_cl_queue();
+        auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL();
+        if (queue != nullptr && mem_fill_fn != nullptr) {
+            uint8_t pattern = value;
+            cl_int err = mem_fill_fn(queue, (char *) tensor->data + offset, &pattern, sizeof(pattern), size, 0, nullptr,
+                                     nullptr);
+            if (err != CL_SUCCESS) {
+                GGML_LOG_ERROR("%s: clEnqueueMemFillINTEL failed with error %d\n", __func__, err);
+            }
+            clFinish(queue);
+        } else {
+            GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer\n", __func__);
+        }
+    } else {
+        memset((char *) tensor->data + offset, value, size);
+    }
+}
+
+static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer,
+                                                    ggml_tensor * tensor,
+                                                    const void * data,
+                                                    size_t offset,
+                                                    size_t size) {
+    // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name);
+    GGML_ASSERT(tensor != nullptr && tensor->data != nullptr);
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+
+    // Check if this is a weight buffer (usage is set BEFORE set_tensor is called, except in test-backend-ops)
+    bool is_weight_buffer = (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+    // Full tensor set: offset=0, full size, not a view
+    bool is_full_tensor_set = (offset == 0 && size == ggml_nbytes(tensor) && tensor->view_src == nullptr);
+    // 2D tensor (typical weight shape)
+    bool is_2d = (tensor->ne[2] == 1 && tensor->ne[3] == 1);
+
+    if (is_weight_buffer && is_full_tensor_set && is_2d) {
+        try {
+            auto result = process_weight_tensor(tensor, data, tensor->data);
+            result.weight_node->set_friendly_name(tensor->name);
+
+            // const auto & layout = result.layout;
+            ggml_openvino_extra_base * extra;
+
+            // Quantized path with extracted weight/scale/zp tensors
+            if (result.is_quantized()) {
+                extra = new ggml_openvino_quantized_weight_extra(std::move(result.weights), std::move(result.scales),
+                                                                 std::move(result.zp), result.weight_node);
+
+                // if (layout.is_requant) {
+                //     GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name,
+                //                    extra_quant_type_name(layout.requant_type.value()), layout.is_u4 ? 4 : 8,
+                //                    layout.weights_per_block);
+                // } else {
+                //     int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block;
+                //     GGML_LOG_DEBUG("%s: extracted quantized weight node for %s (u%d, %zu weights, %ld blocks)\n",
+                //                    __func__, tensor->name, layout.is_u4 ? 4 : 8, layout.weights_size, n_blocks);
+                // }
+            } else {
+                // F16/F32/BF16 weight or F16-requant
+                extra = new ggml_openvino_weight_extra(std::move(result.weights), result.weight_node);
+
+                // if (layout.total_size > 0) {
+                //     GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name);
+                // } else {
+                //     GGML_LOG_DEBUG("%s: created shared-memory weight node for %s\n", __func__, tensor->name);
+                // }
+            }
+
+            ctx->tensor_extras[tensor] = extra;
+            tensor->extra = extra;
+
+        } catch (const std::exception & e) {
+            GGML_LOG_ERROR("%s: failed to process weight tensor for %s: %s\n", __func__, tensor->name, e.what());
+            memcpy((char *) tensor->data + offset, data, size);
+        }
+    } else {
+        // Non-weight tensor (KV cache, activations, etc.) - copy data. test-backend-ops also goes here
+        if (ctx->is_remote) {
+            cl_command_queue queue = ggml_openvino_get_cl_queue();
+            auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL();
+            if (queue != nullptr && mem_cpy_fn != nullptr) {
+                cl_int err =
+                    mem_cpy_fn(queue, CL_TRUE, (char *) tensor->data + offset, data, size, 0, nullptr, nullptr);
+                if (err != CL_SUCCESS) {
+                    GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL failed with error %d\n", __func__, err);
+                }
+            } else {
+                GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__);
+            }
+        } else {
+            memcpy((char *) tensor->data + offset, data, size);
+        }
+
+        ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote);
+        if (extra == nullptr) {
+            // GGML_LOG_ERROR("%s: failed to create tensor extra for %s\n", __func__, tensor->name);
+            return;
+        }
+
+        auto it = ctx->tensor_extras.find(tensor);
+        if (it != ctx->tensor_extras.end()) {
+            delete it->second;
+        }
+        ctx->tensor_extras[tensor] = extra;
+        tensor->extra = extra;
+    }
+}
+
+static void ggml_backend_openvino_buffer_get_tensor(ggml_backend_buffer_t buffer,
+                                                    const ggml_tensor * tensor,
+                                                    void * data,
+                                                    size_t offset,
+                                                    size_t size) {
+    // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name);
+    GGML_ASSERT(tensor != nullptr && tensor->data != nullptr);
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+
+    if (ctx->is_remote) {
+        // For remote (device) buffers, use OpenCL USM memcpy (device-to-host)
+        cl_command_queue queue = ggml_openvino_get_cl_queue();
+        auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL();
+        if (queue != nullptr && mem_cpy_fn != nullptr) {
+            cl_int err =
+                mem_cpy_fn(queue, CL_TRUE, data, (const char *) tensor->data + offset, size, 0, nullptr, nullptr);
+            if (err != CL_SUCCESS) {
+                GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL failed with error %d\n", __func__, err);
+            }
+        } else {
+            GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__);
+        }
+    } else {
+        memcpy(data, (const char *) tensor->data + offset, size);
+    }
+}
+
+static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
+                                                    const ggml_tensor * src,
+                                                    ggml_tensor * dst) {
+    // GGML_LOG_DEBUG("%s: src tensor name=%s, dst tensor name=%s\n", __func__, src->name, dst->name);
+    GGML_ASSERT(src != nullptr && dst != nullptr);
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+
+    if (ctx->is_remote) {
+        // For remote (device) buffers, use OpenCL USM memcpy
+        cl_command_queue queue = ggml_openvino_get_cl_queue();
+        auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL();
+        if (queue == nullptr || mem_cpy_fn == nullptr) {
+            GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__);
+            return false;
+        }
+        // Can copy from host to device
+        if (ggml_backend_buffer_is_host(src->buffer)) {
+            cl_int err = mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr);
+            if (err != CL_SUCCESS) {
+                GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL (host-to-device) failed with error %d\n", __func__, err);
+                return false;
+            }
+            return true;
+        }
+        // Can also copy from device to device if both are OpenVINO remote buffers
+        if (ggml_backend_buffer_is_openvino(src->buffer)) {
+            ggml_backend_openvino_buffer_context * src_ctx =
+                (ggml_backend_openvino_buffer_context *) src->buffer->context;
+            if (src_ctx->is_remote) {
+                cl_int err =
+                    mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr);
+                if (err != CL_SUCCESS) {
+                    GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL (device-to-device) failed with error %d\n", __func__,
+                                   err);
+                    return false;
+                }
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Host buffer - can copy from any host buffer
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        memcpy(dst->data, src->data, ggml_nbytes(src));
+        return true;
+    }
+    return false;
+}
+
+static void ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+    GGML_ASSERT(ctx->data != nullptr);
+    if (ctx->is_remote) {
+        cl_command_queue queue = ggml_openvino_get_cl_queue();
+        auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL();
+        if (queue != nullptr && mem_fill_fn != nullptr) {
+            uint8_t pattern = value;
+            cl_int err = mem_fill_fn(queue, ctx->data, &pattern, sizeof(pattern), ctx->size, 0, nullptr, nullptr);
+            if (err != CL_SUCCESS) {
+                GGML_LOG_WARN("%s: clEnqueueMemFillINTEL failed with error %d\n", __func__, err);
+            }
+            clFinish(queue);
+        } else {
+            GGML_LOG_WARN("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer clear\n",
+                          __func__);
+        }
+    } else {
+        memset(ctx->data, value, ctx->size);
+    }
+}
+
+static const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_openvino_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_openvino_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_openvino_buffer_init_tensor,
+    /* .memset_tensor   = */ ggml_backend_openvino_buffer_memset_tensor,
+    /* .set_tensor      = */ ggml_backend_openvino_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_openvino_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_openvino_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_openvino_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// Buffer type interface functions
+static const char * ggml_backend_openvino_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context;
+    return ctx->name.c_str();
+}
+
+static ggml_backend_buffer_t ggml_backend_openvino_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+                                                                            size_t size) {
+    ggml_backend_openvino_buffer_type_context * buft_ctx = (ggml_backend_openvino_buffer_type_context *) buft->context;
+
+    // Create buffer context with contiguous memory allocation
+    ggml_backend_openvino_buffer_context * ctx = new ggml_backend_openvino_buffer_context(buft_ctx->device, size);
+
+    if (ctx->data == nullptr && size > 0) {
+        GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size);
+        delete ctx;
+        return nullptr;
+    }
+
+    return ggml_backend_buffer_init(buft, ggml_backend_openvino_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_openvino_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(buft);
+    return TENSOR_ALIGNMENT;
+}
+
+static size_t ggml_backend_openvino_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(buft);
+    return SIZE_MAX;
+}
+
+static size_t ggml_backend_openvino_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
+                                                               const ggml_tensor * tensor) {
+    GGML_UNUSED(buft);
+
+    // For quantized 2D tensors (weights), we need extra space for extracted data
+    if (ggml_is_quantized(tensor->type) && tensor->ne[2] == 1 && tensor->ne[3] == 1) {
+        ggml_openvino_extracted_layout layout = ggml_openvino_get_extracted_layout(tensor);
+        if (layout.total_size > 0) {
+            // GGML_LOG_DEBUG("%s: tensor %s needs %zu bytes (original %zu, extracted: weights=%zu scales=%zu zp=%zu)\n",
+            //                __func__, tensor->name, layout.total_size, ggml_nbytes(tensor), layout.weights_size,
+            //                layout.scales_size, layout.zp_size);
+            return layout.total_size;
+        }
+    }
+
+    return ggml_nbytes(tensor);
+}
+
+static const ggml_backend_buffer_type_i ggml_backend_openvino_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_openvino_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_openvino_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_openvino_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_openvino_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ ggml_backend_openvino_buffer_type_get_alloc_size,
+    /* .is_host          = */ nullptr,
+};
+
+// Get buffer type for a specific device
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device) {
+    GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count());
+
+    static std::mutex mutex;
+    std::lock_guard lock(mutex);
+
+    static std::vector<ggml_backend_buffer_type> buffer_types;
+    static std::vector<ggml_backend_openvino_buffer_type_context> buffer_type_contexts;
+
+    if (buffer_types.empty()) {
+        int device_count = ggml_backend_openvino_get_device_count();
+        buffer_types.resize(device_count);
+        buffer_type_contexts.resize(device_count);
+
+        for (int i = 0; i < device_count; i++) {
+            buffer_type_contexts[i].device = i;
+            buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i);
+
+            buffer_types[i] = ggml_backend_buffer_type{
+                /* .iface   = */ ggml_backend_openvino_buffer_type_interface,
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i),
+                /* .context = */ &buffer_type_contexts[i],
+            };
+        }
+    }
+
+    return &buffer_types[device];
+}
+
+// =====================================================
+// OpenVINO Host Buffer Implementation
+// =====================================================
+
+static const char * ggml_backend_openvino_host_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context;
+    static std::string name;
+    name = ctx->name + "_HOST";
+    return name.c_str();
+}
+
+static bool ggml_backend_openvino_host_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(buft);
+    return true;
+}
+
+static const ggml_backend_buffer_type_i ggml_backend_openvino_host_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_openvino_host_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_openvino_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_openvino_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_openvino_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ ggml_backend_openvino_buffer_type_get_alloc_size,
+    /* .is_host          = */ ggml_backend_openvino_host_buffer_type_is_host,
+};
+
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device) {
+    GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count());
+
+    static std::mutex mutex;
+    std::lock_guard lock(mutex);
+
+    static std::vector<ggml_backend_buffer_type> buffer_types;
+    static std::vector<ggml_backend_openvino_buffer_type_context> buffer_type_contexts;
+
+    if (buffer_types.empty()) {
+        int device_count = ggml_backend_openvino_get_device_count();
+        buffer_types.resize(device_count);
+        buffer_type_contexts.resize(device_count);
+
+        for (int i = 0; i < device_count; i++) {
+            buffer_type_contexts[i].device = i;
+            buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i);
+
+            buffer_types[i] = ggml_backend_buffer_type{
+                /* .iface   = */ ggml_backend_openvino_host_buffer_type_interface,
+                /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i),
+                /* .context = */ &buffer_type_contexts[i],
+            };
+        }
+    }
+
+    return &buffer_types[device];
+}
+
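+// OpenVINO buffers are identified by comparing the free_buffer callback, the same
+// pointer-comparison trick used by other ggml backends.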
+bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) {
+    return buffer->iface.free_buffer == ggml_backend_openvino_buffer_free_buffer;
+}
+
+size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer) {
+    if (!ggml_backend_buffer_is_openvino(buffer)) {
+        return 0;
+    }
+    ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context;
+    return ctx->id;
+}
+
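+// Register an extra for a tensor that lives in an OpenVINO buffer. A previously
+// registered extra for the same tensor, if any, is deleted and replaced.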
+void ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra) {
+    GGML_ASSERT(tensor != nullptr);
+    GGML_ASSERT(tensor->buffer != nullptr);
+    GGML_ASSERT(ggml_backend_buffer_is_openvino(tensor->buffer));
+
+    auto * ctx = static_cast<ggml_backend_openvino_buffer_context *>(tensor->buffer->context);
+
+    auto it = ctx->tensor_extras.find(tensor);
+    if (it != ctx->tensor_extras.end()) {
+        delete it->second;
+    }
+
+    ctx->tensor_extras[tensor] = extra;
+    tensor->extra = extra;
+}
+
+bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name;
+}
+
+bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_openvino_host_buffer_type_get_name;
+}
+
+static void ggml_backend_openvino_free(ggml_backend_t backend) {
+    ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context;
+    delete ctx;
+    delete backend;
+}
+
+static const char * ggml_backend_openvino_get_name(ggml_backend_t backend) {
+    return GGML_OPENVINO_NAME;
+    GGML_UNUSED(backend);
+}
+
+static enum ggml_status ggml_backend_openvino_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    return ov_graph_compute(cgraph, backend);
+    GGML_UNUSED(backend);
+}
+
+static const ggml_backend_i ggml_backend_openvino_interface = {
+    /* .get_name                = */ ggml_backend_openvino_get_name,
+    /* .free                    = */ ggml_backend_openvino_free,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_openvino_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .graph_optimize          = */ NULL,
+};
+
+int ggml_backend_openvino_get_device_count() {
+    return 1;
+}
+
+static ggml_guid_t ggml_backend_openvino_guid(void) {
+    static ggml_guid guid = {0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97,
+                             0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d};
+    return &guid;
+}
+
+static std::shared_ptr get_ov_runtime_context_ptr() {
+    static std::shared_ptr r_ctx = std::make_shared();
+    return r_ctx;
+}
+
+// backend API
+GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) {
+    if (device < 0 || device >= ggml_backend_openvino_get_device_count()) {
+        GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device);
+        return nullptr;
+    }
+
+    ggml_backend_openvino_context * ctx = new ggml_backend_openvino_context;
+    if (ctx == nullptr) {
+        GGML_LOG_ERROR("%s: failed to allocate context\n", __func__);
+        return nullptr;
+    }
+
+    ctx->runtime_context = get_ov_runtime_context_ptr();
+    if (ctx->runtime_context == nullptr) {
+        GGML_LOG_ERROR("%s: failed to allocate runtime context\n", __func__);
+        delete ctx;
+        return nullptr;
+    }
+
+    std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context);
+    r_ctx->device = ggml_openvino_get_device_name();
+    r_ctx->stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !ggml_openvino_is_npu();
+
+    ggml_backend_t openvino_backend = new ggml_backend{
+        /* .guid      = */ ggml_backend_openvino_guid(),
+        /* .interface = */ ggml_backend_openvino_interface,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), device),
+        /* .context   = */ ctx,
+    };
+
+    return openvino_backend;
+}
+
+GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_openvino_guid());
+}
+
+struct ggml_backend_openvino_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_openvino_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_openvino_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+#ifdef _WIN32
+    MEMORYSTATUSEX status;
+    status.dwLength = sizeof(status);
+    GlobalMemoryStatusEx(&status);
+    *total = status.ullTotalPhys;
+    *free = status.ullAvailPhys;
+#else
+    long pages = sysconf(_SC_PHYS_PAGES);
+    long page_size = sysconf(_SC_PAGE_SIZE);
+    *total = pages * page_size;
+
+    // "free" system memory is ill-defined, for practical purposes assume that all of it is free:
+    *free = *total;
+#endif  // _WIN32
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
+}
+
+static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+    props->name = ggml_backend_openvino_device_get_name(dev);
+    props->description = ggml_backend_openvino_device_get_description(dev);
+    props->type = ggml_backend_openvino_device_get_type(dev);
+    ggml_backend_openvino_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_openvino_device_init(ggml_backend_dev_t dev, const char * params) {
+    GGML_UNUSED(params);
+    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;
+    return ggml_backend_openvino_init(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;
+    return ggml_backend_openvino_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context;
+    return ggml_backend_openvino_host_buffer_type(ctx->device);
+}
+
+static bool has_view_op_input(const ggml_tensor * op) {
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] == nullptr) {
+            break;
+        }
+        if (op->src[i]->op == GGML_OP_VIEW) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static bool is_supported_flash_attn_pattern(const ggml_tensor * op) {
+    // q, k and v must each match the pattern: src->op == PERMUTE, src->src[0]->op == VIEW,
+    // and src->src[0]->src[0] is not itself a view (view_src == nullptr)
+    for (int i = 0; i < 3; i++) {
+        const ggml_tensor * src = op->src[i];
+        if (src->op != GGML_OP_PERMUTE || src->src[0] == nullptr || src->src[0]->op != GGML_OP_VIEW ||
+            src->src[0]->src[0] == nullptr || src->src[0]->src[0]->view_src != nullptr) {
+            return false;
+        }
+    }
+    return true;
+}
+
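+// Reject operand shapes and parameters that the OpenVINO backend (or OpenVINO itself)
+// cannot handle even though the op type is otherwise supported.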
+static bool is_op_unsupported_case(const ggml_tensor * op) {
+    switch (op->op) {
+    case GGML_OP_GET_ROWS:
+    case GGML_OP_SET_ROWS: {
+        if (op->ne[3] != 1) {
+            return true;
+        }
+        break;
+    }
+    case GGML_OP_ADD:
+    case GGML_OP_MUL: {
+        if (op->src[1]->op == GGML_OP_PERMUTE) {
+            return true;
+        }
+        for (int i = 0; i < 4; i++) {
+            if (op->src[0]->ne[i] != op->src[1]->ne[i] && (op->src[0]->ne[i] != 1 && op->src[1]->ne[i] != 1)) {
+                return true;
+            }
+        }
+        break;
+    }
+    case GGML_OP_SOFT_MAX: {
+        if (op->src[2] != nullptr) {
+            // GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with sinks\n");
+            return true;
+        }
+        float scale = 1.0f;
+        float max_bias = 0.0f;
+        const auto * op_params = op->op_params;
+        memcpy(&scale, (const float *) op_params + 0, sizeof(float));
+        memcpy(&max_bias, (const float *) op_params + 1, sizeof(float));
+        if (max_bias > 0) {
+            // GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n");
+            return true;
+        }
+        break;
+    }
+    case GGML_OP_FLASH_ATTN_EXT: {
+        if (op->src[4] != nullptr) {
+            // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n");
+            return true;
+        }
+        if (!is_supported_flash_attn_pattern(op)) {
+            return true;
+        }
+        float scale = 1.0f;
+        float max_bias = 0.0f;
+        float logit_softcap = 0.0f;
+        const auto * op_params = op->op_params;
+        memcpy(&scale, (const float *) op_params + 0, sizeof(float));
+        memcpy(&max_bias, (const float *) op_params + 1, sizeof(float));
+        memcpy(&logit_softcap, (const float *) op_params + 2, sizeof(float));
+        if (max_bias > 0) {
+            // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n");
+            return true;
+        }
+        if (logit_softcap != 0) {
+            // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n");
+            return true;
+        }
+        break;
+    }
+    case GGML_OP_PERMUTE: {
+        if (op->type == GGML_TYPE_BF16) {
+            // err msg: [GPU] Could not find a suitable kernel for transpose
+            // GGML_LOG_WARN("OpenVINO backend does not support PERMUTE with BF16 type\n");
+            return true;
+        }
+        break;
+    }
+    case GGML_OP_CPY: {
+        if (op->src[1] != op) {
+            // GGML_LOG_WARN("OpenVINO backend only supports CPY that is a cast\n");
+            return true;
+        }
+        break;
+    }
+    case GGML_OP_MUL_MAT: {
+        if (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16) {
+            // Accuracy issue; to reproduce, allow this case and run `test-backend-ops -o "MUL_MAT"`
+            // GGML_LOG_WARN("OpenVINO backend does not support MUL_MAT with two F16 tensors\n");
+            return true;
+        }
+        if (op->src[0]->ne[3] != op->src[1]->ne[3] && op->src[0]->ne[3] != 1 && op->src[1]->ne[3] != 1) {
+            return true;
+        }
+        if (op->src[0]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_PERMUTE) {
+            return true;
+        }
+        if (ggml_is_quantized(op->src[0]->type) && op->src[0]->ne[1] == 1) {
+            // MUL_MAT(type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1)
+            // triggers a bug in ov matmul_shape_inference.hpp
+            return true;
+        }
+        if (op->src[0]->op == GGML_OP_VIEW && op->src[1]->op == GGML_OP_VIEW) {
+            return true;
+        }
+        break;
+    }
+    case GGML_OP_ROPE: {
+        const int32_t * op_params = op->op_params;
+        const int n_dims = op_params[1];
+        const int mode = op_params[2];
+        if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+            // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode);
+            return true;
+        }
+        if (n_dims != 0 && n_dims != op->src[0]->ne[0]) {
+            // GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n", n_dims,
+            //               op->src[0]->ne[0]);
+            return true;
+        }
+        if (op->type != GGML_TYPE_F32) {
+            // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type));
+            return true;
+        }
+        float freq_scale;
+        float ext_factor;
+        memcpy(&freq_scale, op_params + 6, sizeof(float));
+        memcpy(&ext_factor, op_params + 7, sizeof(float));
+        if (ext_factor != 0.0f) {
+            // GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor);
+            return true;
+        }
+        if (op->src[0]->op == GGML_OP_VIEW) {
+            if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) {
+                // GGML_LOG_WARN(
+                //     "OpenVINO backend does not support ROPE with src[0]->view_src->ne[1] %ld != src[0]->ne[2] "
+                //     "%ld\n",
+                //     op->src[0]->view_src->ne[1], op->src[0]->ne[2]);
+                return true;
+            }
+        }
+        break;
+    }
+    default:
+        break;
+    }
+    if (op->op == GGML_OP_GET_ROWS) {
+        if (op->ne[0] == 256 && (op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q5_K)) {
+            // ERR = 0.000000306 > 0.000000100   GET_ROWS(type=q4_K,n=256,m=5,r=4,be1=1,be2=1,v=0)
+            // ERR = 0.000000197 > 0.000000100   GET_ROWS(type=q5_K,n=256,m=5,r=4,be1=1,be2=1,v=0)
+            return true;
+        }
+    }
+    return false;
+}
+
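+// Op support is decided in three steps: the op must appear in one of the whitelists
+// below, every involved tensor type must be supported, and the op must not fall into
+// one of the special cases rejected by is_op_unsupported_case().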
+static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    GGML_ASSERT(dev->reg != nullptr);
+
+    static std::set<ggml_type> supported_types{GGML_TYPE_F32,  GGML_TYPE_F16,  GGML_TYPE_BF16, GGML_TYPE_I64,
+                                               GGML_TYPE_I32,  GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K,
+                                               GGML_TYPE_Q5_K, GGML_TYPE_Q8_0, GGML_TYPE_Q6_K};
+
+    static const std::set<ggml_op> supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT, GGML_OP_VIEW,
+                                                 /*GGML_OP_CONT,*/ GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE,
+                                                 GGML_OP_GET_ROWS, GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_SCALE,
+                                                 // SOFT_MAX is not listed here since it is replaced by FLASH_ATTN_EXT
+                                                 // GGML_OP_SOFT_MAX,
+                                                 GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY};
+    static const std::set<ggml_unary_op> supported_unary_ops{
+        GGML_UNARY_OP_SILU,
+    };
+    static const std::set<ggml_glu_op> supported_glu_ops{
+        GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_GEGLU,
+    };
+
+    switch (op->op) {
+    case GGML_OP_UNARY: {
+        auto supported = supported_unary_ops.find(ggml_get_unary_op(op)) != supported_unary_ops.end();
+        if (!supported) {
+            // GGML_LOG_WARN("OpenVINO backend does not support unary op %s\n", ggml_unary_op_name(ggml_get_unary_op(op)));
+            return false;
+        }
+        if (has_view_op_input(op)) {
+            // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n",
+            //               ggml_unary_op_name(ggml_get_unary_op(op)));
+            return false;
+        }
+        break;
+    }
+    case GGML_OP_GLU: {
+        auto supported = supported_glu_ops.find(ggml_get_glu_op(op)) != supported_glu_ops.end();
+        if (!supported) {
+            // GGML_LOG_WARN("OpenVINO backend does not support GLU op %s\n", ggml_glu_op_name(ggml_get_glu_op(op)));
+            return false;
+        }
+        if (has_view_op_input(op)) {
+            // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n",
+            //               ggml_glu_op_name(ggml_get_glu_op(op)));
+            return false;
+        }
+        if (op->src[1] == nullptr && op->src[0]->ne[0] % 2 != 0) {
+            // triggers bug in ov gpu
+            return false;
+        }
+        break;
+    }
+    default: {
+        auto supported = supported_ops.find(op->op) != supported_ops.end();
+        if (!supported) {
+            // GGML_LOG_WARN("OpenVINO backend does not support op %s\n", ggml_op_name(op->op));
+            return false;
+        }
+        static std::set<ggml_op> ops_not_support_view_input{
+            GGML_OP_GET_ROWS,
+            GGML_OP_RMS_NORM,
+        };
+        if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_op_input(op)) {
+            // GGML_LOG_WARN("OpenVINO backend does not support op %s with view input\n", ggml_op_name(op->op));
+            return false;
+        }
+    }
+    }
+
+    if (supported_types.find(op->type) == supported_types.end()) {
+        // GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(op->type));
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        auto * src = op->src[i];
+        if (src == nullptr) {
+            break;
+        }
+        if (supported_types.find(src->type) == supported_types.end()) {
+            // GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(src->type));
+            return false;
+        }
+        if (ggml_is_quantized(src->type) && src->ne[2] != 1) {
+            // GGML_LOG_WARN("OpenVINO backend does not support 3D quantized tensors\n");
+            return false;
+        }
+    }
+
+    if (is_op_unsupported_case(op)) {
+        return false;
+    }
+    return true;
+}
+
+static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_openvino(buft) || ggml_backend_buft_is_host(buft);
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_openvino_device_interface = {
+    /* .get_name             = */ ggml_backend_openvino_device_get_name,
+    /* .get_description      = */ ggml_backend_openvino_device_get_description,
+    /* .get_memory           = */ ggml_backend_openvino_device_get_memory,
+    /* .get_type             = */ ggml_backend_openvino_device_get_type,
+    /* .get_props            = */ ggml_backend_openvino_device_get_props,
+    /* .init_backend         = */ ggml_backend_openvino_device_init,
+    /* .get_buffer_type      = */ ggml_backend_openvino_device_get_buffer_type,
+    /* .get_host_buffer_type = */ ggml_backend_openvino_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_openvino_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_openvino_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+struct ggml_backend_openvino_reg_context {
+    std::vector<ggml_backend_dev_t> devices;
+};
+
+static const char * ggml_backend_openvino_reg_get_name(ggml_backend_reg_t reg) {
+    return GGML_OPENVINO_NAME;
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_openvino_reg_get_device_count(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return (size_t) ggml_backend_openvino_get_device_count();
+}
+
+static ggml_backend_dev_t ggml_backend_openvino_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    ggml_backend_openvino_reg_context * ctx = (ggml_backend_openvino_reg_context *) reg->context;
+    GGML_ASSERT(index < ctx->devices.size());
+    return ctx->devices[index];
+}
+
+static const struct ggml_backend_reg_i ggml_backend_openvino_reg_interface = {
+    /* .get_name         = */ ggml_backend_openvino_reg_get_name,
+    /* .get_device_count = */ ggml_backend_openvino_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_openvino_reg_get_device,
+    /* .get_proc_address = */ NULL,
+};
+
+static void ggml_openvino_init() {
+    // Initialize device config singleton from env var
+    ggml_openvino_init_device_config();
+    GGML_LOG_INFO("OpenVINO: using device %s\n", ggml_openvino_get_device_name().c_str());
+}
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void) {
+    static ggml_backend_reg reg;
+
+    static bool initialized = false;
+    {
+        static std::mutex mutex;
+        std::lock_guard lock(mutex);
+        if (!initialized) {
+            ggml_openvino_init();
+
+            ggml_backend_openvino_reg_context * ctx = new ggml_backend_openvino_reg_context;
+
+            for (int i = 0; i < ggml_backend_openvino_get_device_count(); i++) {
+                ggml_backend_openvino_device_context * dev_ctx = new ggml_backend_openvino_device_context;
+                dev_ctx->device = i;
+                dev_ctx->name = GGML_OPENVINO_NAME + std::to_string(i);
+
+                dev_ctx->description = ov::get_openvino_version().description;
+
+                ggml_backend_dev_t dev =
+                    new ggml_backend_device{/* .interface = */ ggml_backend_openvino_device_interface,
+                                            /* .reg       = */ &reg,
+                                            /* .context   = */ dev_ctx};
+                ctx->devices.push_back(dev);
+            }
+
+            reg = ggml_backend_reg{/* .api_version = */ GGML_BACKEND_API_VERSION,
+                                   /* .iface       = */ ggml_backend_openvino_reg_interface,
+                                   /* .context     = */ ctx};
+        }
+
+        initialized = true;
+    }
+
+    return &reg;
+}
diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp
new file mode 100644
index 00000000..dbf38646
--- /dev/null
+++ b/ggml/src/ggml-openvino/ggml-quants.cpp
@@ -0,0 +1,884 @@
+#include "ggml-quants.h"
+
+#include "ggml-common.h"
+#include "ggml-impl.h"
+#include "ggml.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include <openvino/core/parallel.hpp>
+#include <openvino/op/add.hpp>
+#include <openvino/op/constant.hpp>
+#include <openvino/op/convert.hpp>
+#include <openvino/op/multiply.hpp>
+#include <openvino/op/reshape.hpp>
+#include <openvino/op/subtract.hpp>
+#include <openvino/openvino.hpp>
+
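+// Repack one Q4_0-style block of 32 4-bit weights: ggml stores weights 0..15 in the
+// low nibbles and weights 16..31 in the high nibbles of 16 bytes; the output packs
+// them in order, two consecutive weights per byte (16 bytes total).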
+void unpack_32_4(const uint8_t * data, uint8_t * dst) {
+    std::fill_n(dst, 16, 0);
+    for (int j = 0; j < 16; ++j) {
+        uint8_t x = (data[j] & 0x0F);
+        uint8_t y = (data[j] >> 4);
+        if (j % 2 != 0) {
+            x <<= 4;
+            y <<= 4;
+        }
+        dst[j / 2] |= x;
+        dst[8 + j / 2] |= y;  // Last 16 weights are in the higher bits
+    }
+}
+
+// Extracts (weight, scales, zp) from Q4_0 tensors.
+// Data layout is: |16 bit scale|32 x 4bit weights|.
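+// ggml dequantizes Q4_0 as x = d * (q - 8), so the unsigned zero point is always 8.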
+void extract_q4_0_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr) {
+    const uint64_t bytes_per_block = 18;  // 2 bytes scale, 32x0.5 byte weights
+
+    auto * data = static_cast<uint8_t *>(tensor->data);
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+    auto * zp = static_cast<uint8_t *>(zp_arr.data());
+
+    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization
+
+    // For Q4_0, zero point is always 8
+    if (is_scalar_zp) {
+        zp[0] = 8 | (8 << 4);  // Pack two 4-bit values
+    }
+
+    ov::parallel_for(scales_arr.get_size(), [&](size_t i) {
+        scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block)));
+        // For asymmetric quantization, compute per-block zero points
+        if (!is_scalar_zp) {
+            // Pack two 4-bit zero points per byte
+            if (i % 2 == 0) {
+                zp[i / 2] = 8;          // Lower nibble
+            } else {
+                zp[i / 2] |= (8 << 4);  // Upper nibble
+            }
+        }
+        unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16);
+    });
+}
+
+// Extracts (weight, scales, zp) from Q4_1 tensors.
+// Data layout is: |16 bit scale|16 bit min|32 x 4bit weights|.
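+// ggml dequantizes Q4_1 as x = d * q + m, i.e. (q - zp) * d with zp = -m / d,
+// or equivalently w * s + bias with bias = m.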
+void extract_q4_1_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr,
+                       bool use_bias) {
+    const uint64_t bytes_per_block = 20;  // 2 bytes scale, 2 bytes min, 32x0.5 byte weights
+
+    auto * data = static_cast<uint8_t *>(tensor->data);
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+
+    if (use_bias) {
+        // Store bias (min) directly as f16 instead of computing u4 zero points
+        auto * bias = zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+        ov::parallel_for(scales_arr.get_size(), [&](size_t i) {
+            float scale = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))));
+            float min = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2))));
+            scales[i] = ov::float16(scale);
+            bias[i] = ov::float16(min);  // bias = min, dequant: w*s + bias
+            unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16);
+        });
+    } else {
+        auto * zp = static_cast<uint8_t *>(zp_arr.data());
+        ov::parallel_for(scales_arr.get_size(), [&](size_t i) {
+            float scale = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))));
+            float min = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2))));
+            scales[i] = ov::float16(scale);
+            // zp = -min / scale (bias = min, so zp = -bias/scale)
+            uint8_t zp_val = (scale != 0.0f) ? (uint8_t) std::round(-min / scale) : 0;
+            // Pack two 4-bit zero points per byte
+            if (i % 2 == 0) {
+                zp[i / 2] = zp_val & 0x0F;   // Lower nibble
+            } else {
+                zp[i / 2] |= (zp_val << 4);  // Upper nibble
+            }
+            unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16);
+        });
+    }
+}
+
+// Extracts (weight, scales, zp) from Q8_0 tensors.
+// Data layout is: |16 bit scale|32 x 8bit weights|.
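+// ggml dequantizes Q8_0 as x = d * q with q a signed int8; the weights are stored
+// here as q + 128 (unsigned), so the zero point is always 128.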
+void extract_q8_0_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr) {
+    const uint64_t weights_per_block = 32;
+    const uint64_t bytes_per_block = 34;  // 2 bytes scale, 32x1 byte weights
+
+    auto * data = static_cast<uint8_t *>(tensor->data);
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+    auto * zp = static_cast<uint8_t *>(zp_arr.data());
+
+    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization
+
+    // For Q8_0, zero point is always 128
+    if (is_scalar_zp) {
+        zp[0] = 128;
+    }
+
+    ov::parallel_for(scales_arr.get_size(), [&](size_t i) {
+        uint8_t * block_data = data + i * bytes_per_block;
+        scales[i] = ov::float16::from_bits(*(uint16_t *) block_data);
+        // For asymmetric quantization, store per-block zero points
+        if (!is_scalar_zp) {
+            zp[i] = 128;
+        }
+        for (size_t j = 0; j < weights_per_block; ++j) {
+            uint8_t x = block_data[j + 2];  // j+2 to skip the scale bytes.
+            // The original value is a signed int8; flip the sign bit to store it as q + 128 (unsigned).
+            x ^= 1 << 7;
+            weights[i * weights_per_block + j] = x;
+        }
+    });
+}
+
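+// Repack a Q4_K super-block of 256 4-bit weights (four chunks of 64 weights), using the
+// same nibble reordering as unpack_32_4.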
+void unpack_256_4(const uint8_t * data, uint8_t * dst) {
+    // Initialize the output array with zeros
+    std::fill_n(dst, 128, 0);
+
+    for (size_t i = 0; i < 4; ++i) {
+        for (int j = 0; j < 32; ++j) {
+            uint8_t x = (data[i * 32 + j] & 0x0F);
+            uint8_t y = (data[i * 32 + j] >> 4);
+            if (j % 2 != 0) {
+                x <<= 4;
+                y <<= 4;
+            }
+            dst[i * 32 + j / 2] |= x;
+            dst[i * 32 + 16 + j / 2] |= y;  // Last 16 weights are in the higher bits
+        }
+    }
+}
+
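+// Q4_K super-block (144 bytes, 256 weights): 2-byte scale d, 2-byte min scale dmin,
+// 12 bytes of packed 6-bit sub-block scales/mins, 128 bytes of 4-bit weights.
+// ggml dequantizes as x = d * sc * q - dmin * m for each of the 8 sub-blocks of 32.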
+void extract_q4_k_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr,
+                       bool use_bias) {
+    const uint64_t bytes_per_block = 2 + 2 + 12 + 128;
+    const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;
+
+    auto * data = static_cast<uint8_t *>(tensor->data);
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+
+    // For bias path, zp_arr holds f16 bias values; for zp path, it holds packed u4 zero points
+    auto * zp_u4 = use_bias ? nullptr : static_cast<uint8_t *>(zp_arr.data());
+    auto * bias_f16 = use_bias ? zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>() : nullptr;
+
+    ov::parallel_for(n_super_block, [&](size_t i) {
+        uint8_t * block_data = data + i * bytes_per_block;
+
+        // Extract scale factors and offsets
+        float scale_scales = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data)));
+        float scale_mins = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 1)));
+
+        // Extract qs1 and qs2
+        uint8_t * qs1 = block_data + 4;
+
+        // Calculate scales
+        float scale_vals[8];
+        scale_vals[0] = scale_scales * static_cast<float>((*(qs1) & 0b111111));
+        scale_vals[1] = scale_scales * static_cast<float>((*(qs1 + 1) & 0b111111));
+        scale_vals[2] = scale_scales * static_cast<float>((*(qs1 + 2) & 0b111111));
+        scale_vals[3] = scale_scales * static_cast<float>((*(qs1 + 3) & 0b111111));
+        scale_vals[4] = scale_scales * static_cast<float>((*(qs1 + 8) & 0b00001111) | ((*(qs1) >> 6) << 4));
+        scale_vals[5] = scale_scales * static_cast<float>((*(qs1 + 9) & 0b00001111) | ((*(qs1 + 1) >> 6) << 4));
+        scale_vals[6] = scale_scales * static_cast<float>((*(qs1 + 10) & 0b00001111) | ((*(qs1 + 2) >> 6) << 4));
+        scale_vals[7] = scale_scales * static_cast<float>((*(qs1 + 11) & 0b00001111) | ((*(qs1 + 3) >> 6) << 4));
+
+        // Calculate min values (bias = -min)
+        float min_vals[8];
+        min_vals[0] = scale_mins * static_cast<float>((*(qs1 + 4) & 0b111111));
+        min_vals[1] = scale_mins * static_cast<float>((*(qs1 + 5) & 0b111111));
+        min_vals[2] = scale_mins * static_cast<float>((*(qs1 + 6) & 0b111111));
+        min_vals[3] = scale_mins * static_cast<float>((*(qs1 + 7) & 0b111111));
+        min_vals[4] = scale_mins * static_cast<float>((*(qs1 + 8) >> 4) | ((*(qs1 + 4) >> 6) << 4));
+        min_vals[5] = scale_mins * static_cast<float>((*(qs1 + 9) >> 4) | ((*(qs1 + 5) >> 6) << 4));
+        min_vals[6] = scale_mins * static_cast<float>((*(qs1 + 10) >> 4) | ((*(qs1 + 6) >> 6) << 4));
+        min_vals[7] = scale_mins * static_cast<float>((*(qs1 + 11) >> 4) | ((*(qs1 + 7) >> 6) << 4));
+
+        // Store scales and compute zero points or bias
+        for (int j = 0; j < 8; j++) {
+            scales[i * 8 + j] = ov::float16(scale_vals[j]);
+            if (use_bias) {
+                // Store bias = -min directly as f16, dequant: w*s + bias
+                bias_f16[i * 8 + j] = ov::float16(-min_vals[j]);
+            } else {
+                // zp = min / scale (since bias = -min and zp = -bias/scale)
+                uint8_t zp_val = (scale_vals[j] != 0.0f) ? (uint8_t) std::round(min_vals[j] / scale_vals[j]) : 0;
+                // Pack two 4-bit zero points per byte
+                size_t idx = i * 8 + j;
+                if (idx % 2 == 0) {
+                    zp_u4[idx / 2] = zp_val & 0x0F;
+                } else {
+                    zp_u4[idx / 2] |= (zp_val << 4);
+                }
+            }
+        }
+        unpack_256_4(block_data + 16, weights + i * 128);
+    });
+}
+
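+// Q6_K super-block (210 bytes, 256 weights): 128 bytes of low 4 bits (ql), 64 bytes of
+// high 2 bits (qh), 16 int8 sub-block scales, 2-byte super-block scale d.
+// ggml dequantizes as x = d * sc * (q - 32), so the zero point is always 32.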
+void extract_q6_k_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr) {
+    const uint64_t bytes_per_block = 128 + 64 + 16 + 2;
+    const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;
+
+    auto * data = static_cast<uint8_t *>(tensor->data);
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+    auto * zp = static_cast<uint8_t *>(zp_arr.data());
+
+    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization
+
+    // For Q6_K, zero point is always 32
+    if (is_scalar_zp) {
+        zp[0] = 32;
+    }
+
+    ov::parallel_for(n_super_block, [&](size_t i) {
+        uint8_t * block_data = data + i * bytes_per_block;
+
+        float scale_factor =
+            static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 104)));  // (128+64+16)/2
+
+        for (size_t j = 0; j < 16; j++) {
+            scales[j + i * 16] =
+                ov::float16(scale_factor * static_cast<float>(*((int8_t *) (block_data + 128 + 64 + j))));
+            // For asymmetric quantization, store per-block zero points
+            if (!is_scalar_zp) {
+                zp[j + i * 16] = 32;
+            }
+        }
+
+        uint8_t * ql = block_data;
+        uint8_t * qh = block_data + 128;
+
+        for (int64_t j = 0; j < 32; ++j) {
+            weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4);
+            weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4);
+            weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4);
+            weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4);
+            weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4);
+            weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4);
+            weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4);
+            weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4);
+        }
+    });
+}
+
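+// Unpack the j-th 6-bit (scale, min) pair from the 12-byte K-quant scales field.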
+static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
+    if (j < 4) {
+        *d = q[j] & 63;
+        *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
+        *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
+    }
+}
+
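+// Q5_K super-block (176 bytes, 256 weights): 2-byte d, 2-byte dmin, 12 bytes of packed
+// 6-bit sub-block scales/mins, 32 bytes of high bits (qh), 128 bytes of low 4 bits (ql).
+// ggml dequantizes as x = d * sc * q - dmin * m, with q = (low 4 bits) + (high bit ? 16 : 0).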
+void extract_q5_k_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr,
+                       bool use_bias) {
+    const uint64_t bytes_per_block = 4 + 12 + 32 + 128;
+    const uint64_t n_super_block = tensor->nb[3] / bytes_per_block;
+
+    auto * data = static_cast<uint8_t *>(tensor->data);
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+
+    // For bias path, zp_arr holds f16 bias values; for zp path, it holds u8 zero points
+    auto * zp_u8 = use_bias ? nullptr : static_cast<uint8_t *>(zp_arr.data());
+    auto * bias_f16 = use_bias ? zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>() : nullptr;
+
+    ov::parallel_for(n_super_block, [&](size_t i) {
+        uint8_t * block_data = data + i * bytes_per_block;
+
+        const float d = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data)));
+        const float min_factor = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 1)));
+
+        const uint8_t * scales_data = block_data + 4;   // 12 bytes of scales
+        const uint8_t * qh = block_data + 4 + 12;       // 32 bytes of high bits
+        const uint8_t * ql = block_data + 4 + 12 + 32;  // 128 bytes of low bits
+
+        int is = 0;
+        uint8_t u1 = 1;
+        uint8_t u2 = 2;
+
+        // Process 2 blocks in one iteration
+        for (int j = 0; j < 256; j += 64) {  // 256 = QK_K, so 4 iterations of 64
+            uint8_t sc;
+            uint8_t m;
+
+            // Get scale and min for first 32 elements
+            get_scale_min_k4(is + 0, scales_data, &sc, &m);
+            const float d1 = d * sc;
+            const float m1 = min_factor * m;
+
+            // Get scale and min for second 32 elements
+            get_scale_min_k4(is + 1, scales_data, &sc, &m);
+            const float d2 = d * sc;
+            const float m2 = min_factor * m;
+
+            scales[i * 8 + is] = ov::float16(d1);
+            scales[i * 8 + is + 1] = ov::float16(d2);
+            if (use_bias) {
+                // Store bias = -min directly as f16, dequant: w*s + bias
+                bias_f16[i * 8 + is] = ov::float16(-m1);
+                bias_f16[i * 8 + is + 1] = ov::float16(-m2);
+            } else {
+                // zp = min / scale (since bias = -min and zp = -bias/scale)
+                zp_u8[i * 8 + is] = (d1 != 0.0f) ? (uint8_t) std::round(m1 / d1) : 0;
+                zp_u8[i * 8 + is + 1] = (d2 != 0.0f) ? (uint8_t) std::round(m2 / d2) : 0;
+            }
+
+            // Extract weights for first 32 elements (matching deq formula exactly)
+            for (int l = 0; l < 32; ++l) {
+                weights[i * 256 + j + l] = (ql[l] & 0xF) + ((qh[l] & u1) ? 16 : 0);
+            }
+
+            // Extract weights for second 32 elements
+            for (int l = 0; l < 32; ++l) {
+                weights[i * 256 + j + l + 32] = (ql[l] >> 4) + ((qh[l] & u2) ? 16 : 0);
+            }
+
+            ql += 32;
+            is += 2;
+            u1 <<= 2;
+            u2 <<= 2;
+        }
+    });
+}
+
+// TODO Reorder for make_intX_weights
+
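+// Build the OpenVINO decompression subgraph for u8 weights:
+// Constant(u8) -> Convert(f16) -> Subtract(zp) -> Multiply(scale) [-> Reshape] -> Convert(f32),
+// or Multiply(scale) followed by Add(bias) when use_bias is set. group_size is the number
+// of weights that share one scale (and zero point).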
+ov::Output<ov::Node> make_int8_weights(ov::Tensor & weight,
+                                       ov::Tensor & scales,
+                                       ov::Tensor & zp,
+                                       size_t group_size,
+                                       bool use_bias) {
+    ov::Shape orig_shape = weight.get_shape();
+
+    // Expand dimensions for scales and zp/bias
+    auto scale_shape = scales.get_shape();
+    auto zp_shape = zp.get_shape();
+    bool is_scalar_zp = zp_shape.empty();  // Symmetric quantization
+
+    ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size};
+
+    if (packed_shape[1] == 1) {
+        // Requantized channel-wise case
+        packed_shape.erase(packed_shape.begin() + 1);
+    } else {
+        scale_shape.push_back(1);
+        scales.set_shape(scale_shape);
+        // For symmetric quantization, zp remains scalar (don't resize)
+        if (!is_scalar_zp) {
+            zp_shape.push_back(1);
+            zp.set_shape(zp_shape);
+        }
+    }
+
+    // Create graph nodes
+    auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::u8, packed_shape,
+                                                               static_cast<uint8_t *>(weight.data()), nullptr);
+    weights_node->get_rt_info()["__gguf_tensor_holder"] = weight;
+    auto scales_f16 = std::make_shared<ov::op::v0::Constant>(scales);
+    auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16);
+
+    ov::Output<ov::Node> result;
+    if (use_bias && !is_scalar_zp) {
+        // Bias path: w * s + b (zp tensor holds f16 bias values)
+        auto bias_f16 = std::make_shared<ov::op::v0::Constant>(zp);
+        auto w_s = std::make_shared<ov::op::v1::Multiply>(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY);
+        result = std::make_shared<ov::op::v1::Add>(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY);
+    } else {
+        // Zero point path: (w - zp) * s
+        auto zero_point = std::make_shared<ov::op::v0::Constant>(zp);
+        float zp_value;
+        if (ov::op::util::get_single_value(zero_point, zp_value)) {
+            zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value});
+        }
+        auto zero_point_f16 = std::make_shared<ov::op::v0::Convert>(zero_point, ov::element::f16);
+        auto w_zp =
+            std::make_shared<ov::op::v1::Subtract>(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY);
+        result = std::make_shared<ov::op::v1::Multiply>(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY);
+    }
+
+    if (packed_shape.size() != 2) {
+        // If not requantized channel-wise case, reshape back to original shape
+        auto final_shape =
+            std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{orig_shape.size()}, orig_shape);
+        result = std::make_shared<ov::op::v1::Reshape>(result, final_shape, false);
+    }
+
+    return std::make_shared<ov::op::v0::Convert>(result, ov::element::f32);
+}
+
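+// Same decompression subgraph as make_int8_weights, but with u4-packed weights and
+// (for asymmetric quantization) u4-packed zero points.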
+ov::Output<ov::Node> make_int4_weights(ov::Tensor & weight,
+                                       ov::Tensor & scales,
+                                       ov::Tensor & zp,
+                                       size_t group_size,
+                                       bool use_bias) {
+    ov::Shape orig_weight_shape = weight.get_shape();
+
+    // Expand dimensions for scales and zp/bias
+    ov::Shape scale_shape = scales.get_shape();
+    auto zp_shape = zp.get_shape();
+    bool is_scalar_zp = zp_shape.empty();  // Symmetric quantization
+
+    // Create INT4 weight tensor
+    ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size};
+
+    if (packed_shape[1] == 1) {
+        // Requantized channel-wise case
+        packed_shape.erase(packed_shape.begin() + 1);
+    } else {
+        scale_shape.push_back(1);
+        scales.set_shape(scale_shape);
+        // For symmetric quantization, zp remains scalar (don't resize)
+        if (!is_scalar_zp) {
+            zp_shape.push_back(1);
+            zp.set_shape(zp_shape);
+        }
+    }
+
+    auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::u4, packed_shape,
+                                                               static_cast<uint8_t *>(weight.data()), nullptr);
+    weights_node->get_rt_info()["__gguf_tensor_holder"] = weight;
+    auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16);
+    auto scales_f16 = std::make_shared<ov::op::v0::Constant>(scales);
+
+    ov::Output<ov::Node> result;
+    if (use_bias && !is_scalar_zp) {
+        // Bias path: w * s + b (zp tensor holds f16 bias values)
+        auto bias_f16 = std::make_shared<ov::op::v0::Constant>(zp);
+        auto w_s = std::make_shared<ov::op::v1::Multiply>(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY);
+        result = std::make_shared<ov::op::v1::Add>(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY);
+    } else {
+        // Zero point path: (w - zp) * s
+        auto zero_points_node = std::make_shared<ov::op::v0::Constant>(zp);
+        float zp_value;
+        if (ov::op::util::get_single_value(zero_points_node, zp_value)) {
+            zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value});
+        }
+        auto zero_points_f16 = std::make_shared<ov::op::v0::Convert>(zero_points_node, ov::element::f16);
+        auto w_zp =
+            std::make_shared<ov::op::v1::Subtract>(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY);
+        result = std::make_shared<ov::op::v1::Multiply>(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY);
+    }
+
+    if (packed_shape.size() != 2) {
+        // If not requantized channel-wise case, reshape back to original shape
+        auto final_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{orig_weight_shape.size()},
+                                                                  orig_weight_shape);
+        result = std::make_shared<ov::op::v1::Reshape>(result, final_shape, false);
+    }
+
+    return std::make_shared<ov::op::v0::Convert>(result, ov::element::f32);
+}
+
+// Extract quantized weights from tensor and create weight subgraph
+std::shared_ptr<ov::Node> extract_quantized_weights(const ggml_tensor * tensor,
+                                                    const void * data,
+                                                    ov::Tensor & weights,
+                                                    ov::Tensor & scales,
+                                                    ov::Tensor & zp,
+                                                    bool use_bias) {
+    // Create a temporary tensor for extraction functions that read from tensor->data
+    ggml_tensor temp_tensor = *tensor;
+    temp_tensor.data = const_cast<void *>(data);
+
+    // Determine block size based on tensor type
+    int64_t weights_per_block;
+    bool is_u4;
+    switch (tensor->type) {
+    case GGML_TYPE_Q4_0:
+    case GGML_TYPE_Q4_1:
+    case GGML_TYPE_Q4_K:
+        is_u4 = true;
+        weights_per_block = 32;
+        break;
+    case GGML_TYPE_Q8_0:
+    case GGML_TYPE_Q5_K:
+        is_u4 = false;
+        weights_per_block = 32;
+        break;
+    case GGML_TYPE_Q6_K:
+        is_u4 = false;
+        weights_per_block = 16;
+        break;
+    default:
+        throw std::runtime_error("Unsupported quantized type for extraction: " +
+                                 std::string(ggml_type_name(tensor->type)));
+    }
+
+    // Extract quantized data
+    switch (tensor->type) {
+    case GGML_TYPE_Q4_0:
+        extract_q4_0_data(&temp_tensor, weights, scales, zp);
+        break;
+    case GGML_TYPE_Q4_1:
+        extract_q4_1_data(&temp_tensor, weights, scales, zp, use_bias);
+        break;
+    case GGML_TYPE_Q4_K:
+        extract_q4_k_data(&temp_tensor, weights, scales, zp, use_bias);
+        break;
+    case GGML_TYPE_Q8_0:
+        extract_q8_0_data(&temp_tensor, weights, scales, zp);
+        break;
+    case GGML_TYPE_Q6_K:
+        extract_q6_k_data(&temp_tensor, weights, scales, zp);
+        break;
+    case GGML_TYPE_Q5_K:
+        extract_q5_k_data(&temp_tensor, weights, scales, zp, use_bias);
+        break;
+    default:
+        throw std::runtime_error("Unsupported quantized type: " + std::string(ggml_type_name(tensor->type)));
+    }
+
+    // Create the OpenVINO weight subgraph
+    ov::Output<ov::Node> weight_node;
+    if (is_u4) {
+        weight_node = make_int4_weights(weights, scales, zp, weights_per_block, use_bias);
+    } else {
+        weight_node = make_int8_weights(weights, scales, zp, weights_per_block, use_bias);
+    }
+
+    auto result = weight_node.get_node_shared_ptr();
+    result->set_friendly_name(tensor->name);
+    return result;
+}
+
+// Requantize weights to target format, writing to provided buffers
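+// Dequantizes the source tensor to F32 and re-quantizes it into the requested
+// ExtraQuantType (or simply converts to F16), writing the result into the provided
+// weights/scales/zp tensors before building the weight subgraph.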
+std::shared_ptr<ov::Node> requantize_to_buffers(const ggml_tensor * tensor,
+                                                const void * data,
+                                                ExtraQuantType requant_type,
+                                                int64_t block_size,
+                                                ov::Tensor & weights,
+                                                ov::Tensor & scales,
+                                                ov::Tensor & zp) {
+    int64_t n_elements = ggml_nelements(tensor);
+
+    // First dequantize to F32
+    std::vector<float> weights_f32(n_elements);
+    ggml_get_type_traits(tensor->type)->to_float(data, weights_f32.data(), n_elements);
+
+    // Handle F16 case - just convert and create constant
+    if (requant_type == ExtraQuantType::F16) {
+        ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), n_elements);
+        auto result = std::make_shared<ov::op::v0::Constant>(weights);
+        result->set_friendly_name(tensor->name);
+        return result;
+    }
+
+    // Requantize to target quantized format
+    bool is_u4 = (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128);
+
+    if (is_u4) {
+        quantize_q4_0(weights_f32.data(), weights, scales, zp, n_elements, block_size);
+    } else if (requant_type == ExtraQuantType::Q8_1_C) {
+        quantize_q8_1(weights_f32.data(), weights, scales, zp, n_elements, block_size);
+    } else {
+        quantize_q8_0(weights_f32.data(), weights, scales, zp, n_elements, block_size);
+    }
+
+    // Create the OpenVINO weight subgraph
+    ov::Output<ov::Node> weight_node;
+    if (is_u4) {
+        weight_node = make_int4_weights(weights, scales, zp, block_size);
+    } else {
+        weight_node = make_int8_weights(weights, scales, zp, block_size);
+    }
+
+    auto result = weight_node.get_node_shared_ptr();
+    result->set_friendly_name(tensor->name);
+    return result;
+}
+
+OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, void * output_base_ptr, bool use_bias) {
+    GGML_ASSERT(tensor != nullptr);
+    GGML_ASSERT(data != nullptr);
+
+    OvWeight result;
+
+    // Get 2D shape for weights [rows, cols]
+    ov::Shape node_shape = {static_cast<size_t>(tensor->ne[1]), static_cast<size_t>(tensor->ne[0])};
+
+    // Handle F16/F32/BF16 weights
+    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
+        ov::element::Type element_type;
+        switch (tensor->type) {
+        case GGML_TYPE_F32:
+            element_type = ov::element::f32;
+            break;
+        case GGML_TYPE_F16:
+            element_type = ov::element::f16;
+            break;
+        case GGML_TYPE_BF16:
+            element_type = ov::element::bf16;
+            break;
+        default:
+            OPENVINO_THROW("Unexpected tensor type in F16/F32/BF16 path");
+        }
+
+        if (output_base_ptr && output_base_ptr != data) {
+            // Using external buffer - copy data and create shared-memory constant
+            size_t tensor_bytes = ggml_nbytes(tensor);
+            memcpy(output_base_ptr, data, tensor_bytes);
+            result.weights = ov::Tensor(element_type, node_shape, output_base_ptr);
+        } else {
+            result.weights = ov::Tensor(element_type, node_shape, data);
+        }
+        result.weight_node = std::make_shared<ov::op::v0::Constant>(result.weights);
+        return result;
+    }
+
+    // Handle quantized weights
+    if (!ggml_is_quantized(tensor->type)) {
+        OPENVINO_THROW("Unsupported weight tensor type: ", ggml_type_name(tensor->type));
+    }
+
+    result.layout = ggml_openvino_get_extracted_layout(tensor, use_bias);
+    const auto & layout = result.layout;
+    if (layout.total_size == 0) {
+        OPENVINO_THROW("Unsupported quantized type: ", ggml_type_name(tensor->type));
+    }
+
+    if (use_bias) {
+        OPENVINO_ASSERT(!layout.is_requant,
+                        "use_bias is only used for test-backend-ops, which should not have requantization");
+        // the bias node is created on the fly and does not use the backend buffer
+        output_base_ptr = nullptr;
+    }
+
+    // F16 requant path - no separate scales/zp needed in result
+    if (layout.is_requant && layout.requant_type.has_value() && layout.requant_type.value() == ExtraQuantType::F16) {
+        if (output_base_ptr) {
+            result.weights = ov::Tensor(ov::element::f16, node_shape,
+                                        static_cast<uint8_t *>(output_base_ptr) + layout.weights_offset);
+        } else {
+            result.weights = ov::Tensor(ov::element::f16, node_shape);
+        }
+        ov::Tensor dummy_scales, dummy_zp;  // Not used for F16
+        result.weight_node =
+            requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, result.weights, dummy_scales, dummy_zp);
+        return result;
+    }
+
+    // Quantized path (normal extraction or quantized requant)
+    // Create weight/scale/zp tensors - shared between both paths
+    ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8;
+    ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block};
+    ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape;
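+    // One scale per block of `weights_per_block` elements along each row; symmetric formats use a
+    // single scalar zero point, asymmetric formats one zero point per block.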
+
+    if (output_base_ptr) {
+        uint8_t * buf_base = static_cast<uint8_t *>(output_base_ptr);
+        result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset);
+        result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset);
+        result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset);
+    } else {
+        result.weights = ov::Tensor(weight_type, node_shape);
+        result.scales = ov::Tensor(ov::element::f16, scale_shape);
+        if (use_bias && !layout.is_symmetric) {
+            // bias only has an effect for asymmetric quantization
+            result.zp = ov::Tensor(ov::element::f16, zp_shape);
+        } else {
+            result.zp = ov::Tensor(weight_type, zp_shape);
+        }
+    }
+
+    if (layout.is_requant && layout.requant_type.has_value()) {
+        result.weight_node = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block,
+                                                   result.weights, result.scales, result.zp);
+    } else {
+        result.weight_node =
+            extract_quantized_weights(tensor, data, result.weights, result.scales, result.zp, use_bias);
+    }
+
+    return result;
+}
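+
+// Illustrative only (not called from this file): a hypothetical caller could map a quantized ggml
+// weight to an OpenVINO constant subgraph roughly like this, where `buf_base` stands for a
+// pre-allocated backend buffer laid out according to ggml_openvino_get_extracted_layout():
+//
+//     OvWeight w  = process_weight_tensor(tensor, tensor->data);            // internal allocation
+//     OvWeight w2 = process_weight_tensor(tensor, tensor->data, buf_base);  // shared backend buffer
+//     std::shared_ptr<ov::Node> node = w.weight_node;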
+
+void quantize_q4_0(const float * x,
+                   ov::Tensor & weights_arr,
+                   ov::Tensor & scales_arr,
+                   ov::Tensor & zp_arr,
+                   int64_t k,
+                   int64_t qk) {
+    assert(k % qk == 0);
+    const int nb = k / qk;
+
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+    auto * zp = static_cast<uint8_t *>(zp_arr.data());
+    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization
+
+    // For Q4_0, zero point is always 8
+    if (is_scalar_zp) {
+        zp[0] = 8 | (8 << 4);  // Pack two 4-bit values
+    }
+
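+    // Per block of qk values: take the element with the largest magnitude, set the scale to
+    // d = max / -8, and store each element as an unsigned nibble, roughly min(15, round(x / d) + 8),
+    // with zero point 8 and two nibbles packed per byte.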
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f;  // absolute max
+        float max = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i * qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
+            }
+        }
+
+        const float d = max / -8;
+
+        if (d == 0) {
+            scales[i] = ov::float16(1.0f);
+            // zp is already set to 8 for symmetric, or set per-block for asymmetric
+            if (!is_scalar_zp) {
+                if (i % 2 == 0) {
+                    zp[i / 2] = 8;
+                } else {
+                    zp[i / 2] |= (8 << 4);
+                }
+            }
+            memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2);
+            continue;
+        }
+
+        const float id = 1.0f / d;
+        scales[i] = ov::float16(d);
+        // For asymmetric quantization, store per-block zero points
+        if (!is_scalar_zp) {
+            if (i % 2 == 0) {
+                zp[i / 2] = 8;
+            } else {
+                zp[i / 2] |= (8 << 4);
+            }
+        }
+
+        for (int j = 0; j < qk / 2; ++j) {
+            const float x0 = x[i * qk + 2 * j] * id;
+            const float x1 = x[i * qk + 2 * j + 1] * id;
+            const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f));
+            weights[i * qk / 2 + j] = xi0 | (xi1 << 4);
+        }
+    }
+}
+
+void quantize_q8_0(const float * x,
+                   ov::Tensor & weights_arr,
+                   ov::Tensor & scales_arr,
+                   ov::Tensor & zp_arr,
+                   int64_t k,
+                   int64_t qk) {
+    assert(k % qk == 0);
+    const int nb = k / qk;
+
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+    auto * zp = static_cast<uint8_t *>(zp_arr.data());
+    bool is_scalar_zp = (zp_arr.get_size() == 1);  // Symmetric quantization
+
+    // For Q8_0, zero point is always 128
+    if (is_scalar_zp) {
+        zp[0] = 128;
+    }
+
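+    // Per block: scale d = amax / 127, zero point fixed at 128, and each element is stored as
+    // q = round(x / d) + 128 (symmetric int8 quantization shifted into the u8 range).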
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f;  // absolute max
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i * qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+            }
+        }
+
+        const float d = amax / 127.0f;
+        const float id = d ? 1.0f / d : 0.0f;
+        scales[i] = ov::float16(d);
+        // For asymmetric quantization, store per-block zero points
+        if (!is_scalar_zp) {
+            zp[i] = 128;
+        }
+
+        for (int j = 0; j < qk; ++j) {
+            const float x0 = x[i * qk + j] * id;
+            const int8_t xi0 = roundf(x0);
+            weights[i * qk + j] = (uint8_t) (xi0 + 128);
+        }
+    }
+}
+
+void quantize_q8_1(const float * x,
+                   ov::Tensor & weights_arr,
+                   ov::Tensor & scales_arr,
+                   ov::Tensor & zp_arr,
+                   int64_t k,
+                   int64_t qk) {
+    assert(k % qk == 0);
+    const int nb = k / qk;
+
+    auto * weights = static_cast<uint8_t *>(weights_arr.data());
+    auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>();
+    auto * zp = static_cast<uint8_t *>(zp_arr.data());
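+
+    // Q8_1 is asymmetric: per block, d = (max - min) / 255, zp = round(-min / d), and each element
+    // is stored as q = round((x - min) / d) in [0, 255].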
+    for (int i = 0; i < nb; i++) {
+        float min = std::numeric_limits<float>::max();
+        float max = std::numeric_limits<float>::lowest();
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i * qk + j];
+            if (v < min) {
+                min = v;
+            }
+            if (v > max) {
+                max = v;
+            }
+        }
+
+        const float d = (max - min) / ((1 << 8) - 1);
+        const float id = d ? 1.0f / d : 0.0f;
+        scales[i] = ov::float16(d);
+        // zp = -min / scale (Q8_1 is asymmetric)
+        zp[i] = (d != 0.0f) ? (uint8_t) std::round(-min / d) : 0;
+
+        for (int j = 0; j < qk; ++j) {
+            const float x0 = (x[i * qk + j] - min) * id;
+            const uint8_t xi0 = roundf(x0);
+            weights[i * qk + j] = xi0;
+        }
+    }
+}
diff --git a/ggml/src/ggml-openvino/ggml-quants.h b/ggml/src/ggml-openvino/ggml-quants.h
new file mode 100644
index 00000000..e4a02297
--- /dev/null
+++ b/ggml/src/ggml-openvino/ggml-quants.h
@@ -0,0 +1,153 @@
+#pragma once
+#include "ggml-openvino-extra.h"  // For ExtraQuantType
+#include "ggml.h"
+
+#include 
+#include 
+#include 
+
+void unpack_32_4(const uint8_t* data, uint8_t* dst);
+
+void extract_q4_0_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr);
+
+void extract_q4_1_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr,
+                       bool use_bias = false);
+
+void extract_q8_0_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr);
+
+void unpack_256_4(const uint8_t* data, uint8_t* dst);
+
+void extract_q4_k_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr,
+                       bool use_bias = false);
+
+void extract_q5_k_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr,
+                       bool use_bias = false);
+
+void extract_q6_k_data(const ggml_tensor * tensor,
+                       ov::Tensor & weights_arr,
+                       ov::Tensor & scales_arr,
+                       ov::Tensor & zp_arr);
+
+static constexpr size_t GGML_QUANTIZATION_GROUP_SIZE = 32;
+
+ov::Output<ov::Node> make_int8_weights(ov::Tensor & weight,
+                                       ov::Tensor & scales,
+                                       ov::Tensor & zp,
+                                       size_t group_size = GGML_QUANTIZATION_GROUP_SIZE,
+                                       bool use_bias = false);
+
+ov::Output<ov::Node> make_int4_weights(ov::Tensor & weight,
+                                       ov::Tensor & scales,
+                                       ov::Tensor & zp,
+                                       size_t group_size = GGML_QUANTIZATION_GROUP_SIZE,
+                                       bool use_bias = false);
+
+// Extract quantized weights from tensor and create weight subgraph
+// If weights/scales/zp are provided (non-empty), uses them as output buffers
+// Otherwise allocates new ov::Tensors internally
+// Returns the weight node (make_int4_weights or make_int8_weights result)
+std::shared_ptr<ov::Node> extract_quantized_weights(
+    const ggml_tensor * tensor,
+    const void * data,  // Source data pointer (may differ from tensor->data)
+    ov::Tensor & weights,
+    ov::Tensor & scales,
+    ov::Tensor & zp,
+    bool use_bias = false);  // Use fp bias instead of quantized zero_point (for test-backend-ops)
+
+// Requantize weights from tensor to target format, writing to provided buffers
+// For F16 target, only weights buffer is used (scales/zp ignored)
+// Returns the weight node
+std::shared_ptr<ov::Node> requantize_to_buffers(const ggml_tensor * tensor,
+                                                const void * data,  // Source data pointer
+                                                ExtraQuantType requant_type,
+                                                int64_t block_size,
+                                                ov::Tensor & weights,
+                                                ov::Tensor & scales,
+                                                ov::Tensor & zp);
+
+inline const char * extra_quant_type_name(ExtraQuantType t) {
+    switch (t) {
+    case ExtraQuantType::F16:
+        return "F16";
+    case ExtraQuantType::Q4_0_C:
+        return "Q4_0_C";
+    case ExtraQuantType::Q4_0_128:
+        return "Q4_0_128";
+    case ExtraQuantType::Q8_0_C:
+        return "Q8_0_C";
+    case ExtraQuantType::Q8_0_32:
+        return "Q8_0_32";
+    case ExtraQuantType::Q8_1_C:
+        return "Q8_1_C";
+    default:
+        return "unknown";
+    }
+}
+
+// Result from process_weight_tensor containing the weight node and tensors.
+// For quantized weights, also contains the extracted layout and scale/zp tensors.
+struct OvWeight {
+    std::shared_ptr<ov::Node> weight_node;
+    ggml_openvino_extracted_layout layout;  // Only meaningful for quantized (layout.total_size > 0)
+    ov::Tensor weights;
+    ov::Tensor scales;
+    ov::Tensor zp;
+
+    bool is_quantized() const { return layout.scales_size > 0; }
+};
+
+// Process weight tensor and create an OpenVINO weight node
+// Handles F16/F32/BF16 and quantized weights, with optional requantization
+// If output_base_ptr is nullptr, allocates internal buffers (for decoder use)
+// If output_base_ptr is provided, uses pre-allocated buffers at specified offsets (for backend buffer use)
+// Returns OvWeight with the weight node and optional quantized tensors
+OvWeight process_weight_tensor(
+    const ggml_tensor * tensor,
+    const void * data,                 // Source data pointer (may differ from tensor->data)
+    void * output_base_ptr = nullptr,  // Base pointer for output buffers (or nullptr for internal allocation)
+    bool use_bias = false);            // Use fp bias instead of quantized zero_point, only used in test-backend-ops
+
+void quantize_q4_0(const float * x,
+                   ov::Tensor & weights_arr,
+                   ov::Tensor & scales_arr,
+                   ov::Tensor & zp_arr,
+                   int64_t k,
+                   int64_t qk);
+void quantize_q8_1(const float * x,
+                   ov::Tensor & weights_arr,
+                   ov::Tensor & scales_arr,
+                   ov::Tensor & zp_arr,
+                   int64_t k,
+                   int64_t qk);
+void quantize_q8_0(const float * x,
+                   ov::Tensor & weights_arr,
+                   ov::Tensor & scales_arr,
+                   ov::Tensor & zp_arr,
+                   int64_t k,
+                   int64_t qk);
+
+namespace ov {
+namespace op {
+namespace util {
+// From /src/common/transformations/include/transformations/utils/utils.hpp
+bool get_single_value(const std::shared_ptr<ov::op::v0::Constant>& const_node,
+                      float& value,
+                      bool check_value_range = true);
+}  // namespace util
+}  // namespace op
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/decoder.h b/ggml/src/ggml-openvino/openvino/decoder.h
new file mode 100644
index 00000000..3b8da2be
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/decoder.h
@@ -0,0 +1,74 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+class GgmlDecoder : public DecoderBase {
+public:
+    virtual ov::Any get_attribute(const std::string& name) const = 0;
+
+    virtual PartialShape get_input_shape(int node_idx, const std::string& name) const = 0;
+
+    virtual std::vector get_input_stride(int node_idx, const std::string& name) const = 0;
+
+    virtual element::Type get_input_type(int node_idx, const std::string& name) const = 0;
+
+    virtual size_t get_input_size() const = 0;
+
+    virtual size_t get_input_size(int node_idx) const = 0;
+
+    virtual void get_input_node(size_t input_port_idx,
+                                std::string& producer_name,
+                                std::string& producer_output_port_name,
+                                size_t& producer_output_port_index) const = 0;
+
+    virtual std::vector get_input_names(int node_idx) const = 0;
+
+    virtual PartialShape get_output_shape(int node_idx) const = 0;
+
+    virtual element::Type get_output_type(const int node_idx) const = 0;
+
+    virtual int32_t* get_input_op_params(int node_idx, const std::string& name) const = 0;
+
+    virtual int32_t * get_output_op_params(int node_idx) const = 0;
+
+    virtual std::vector get_output_names(int node_idx) const = 0;
+
+    virtual const std::string& get_op_type() const = 0;
+
+    virtual const std::string& get_op_type(int node_idx) const = 0;
+
+    virtual const std::string& get_op_name() const = 0;
+
+    virtual const std::string& get_op_name(int node_idx) const = 0;
+
+    virtual void visit_subgraph(std::function, int node_idx)> node_visitor) const = 0;
+
+    virtual int get_op_case(int node_idx) const = 0;
+
+    virtual const std::map>& get_model_inputs() const = 0;
+    virtual const std::map>& get_model_extra_inputs() const = 0;
+    virtual const std::map>& get_model_weights() const = 0;
+    virtual std::vector get_model_output_names() const = 0;
+
+    virtual int32_t* get_rope_params() const = 0;
+
+    virtual std::map get_kv_param_res_names() const = 0;
+
+    virtual bool is_static() const = 0;
+
+    virtual bool is_stateful() const = 0;
+
+    virtual int is_swa_layer(int layer) const = 0;
+};
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/frontend.cpp b/ggml/src/ggml-openvino/openvino/frontend.cpp
new file mode 100644
index 00000000..c2ba14e6
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/frontend.cpp
@@ -0,0 +1,27 @@
+#include "frontend.h"
+
+#include "input_model.h"
+#include "op_table.h"
+#include "translate_session.h"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+FrontEnd::FrontEnd() {}
+
+std::shared_ptr FrontEnd::convert(const InputModel::Ptr & model, bool naive) {
+    auto ggml_model = std::dynamic_pointer_cast(model);
+    FRONT_END_GENERAL_CHECK(ggml_model, "Invalid input model");
+    std::shared_ptr converted_model;
+    const auto & supported_ops = get_supported_ops();
+    {
+        TranslateSession translate_session(model, supported_ops, naive);
+        converted_model = translate_session.get_converted_model();
+    }
+    return converted_model;
+}
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/frontend.h b/ggml/src/ggml-openvino/openvino/frontend.h
new file mode 100644
index 00000000..f1c6f0c3
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/frontend.h
@@ -0,0 +1,23 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+class FrontEnd {
+public:
+    using Ptr = std::shared_ptr;
+    FrontEnd();
+
+    static std::shared_ptr convert(const InputModel::Ptr& model, bool naive = false);
+};
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/input_model.cpp b/ggml/src/ggml-openvino/openvino/input_model.cpp
new file mode 100644
index 00000000..39b004c9
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/input_model.cpp
@@ -0,0 +1,17 @@
+#include "input_model.h"
+
+#include "decoder.h"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+InputModel::InputModel(const std::shared_ptr & gdecoder) : m_decoder(gdecoder) {}
+
+const std::shared_ptr & InputModel::get_model_decoder() const {
+    return m_decoder;
+}
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/input_model.h b/ggml/src/ggml-openvino/openvino/input_model.h
new file mode 100644
index 00000000..ce843442
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/input_model.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include 
+
+#include "decoder.h"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+class FrontEnd;
+class GgmlDecoder;
+using ov::frontend::ggml::GgmlDecoder;
+
+class InputModel : public ov::frontend::InputModel {
+    friend class ::ov::frontend::ggml::FrontEnd;
+
+public:
+    explicit InputModel(const std::shared_ptr& gdecoder);
+
+    const std::shared_ptr& get_model_decoder() const;
+
+private:
+    std::shared_ptr m_decoder;
+};
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/node_context.h b/ggml/src/ggml-openvino/openvino/node_context.h
new file mode 100644
index 00000000..aa484128
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/node_context.h
@@ -0,0 +1,112 @@
+#pragma once
+
+#include 
+#include 
+#include 
+
+#include "decoder.h"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+class TranslateSession;
+
+typedef std::map> TensorMap;
+
+class NodeContext : public frontend::NodeContext {
+public:
+    NodeContext(const std::shared_ptr& decoder,
+                std::shared_ptr& tensor_map,
+                int node_idx,
+                TranslateSession* translate_session = nullptr)
+        : ov::frontend::NodeContext(decoder->get_op_type(node_idx)),
+          m_decoder(decoder),
+          m_tensor_map(tensor_map),
+          m_node_idx(node_idx),
+          m_translate_session(translate_session) {
+        m_input_names = decoder->get_input_names(m_node_idx);
+        m_output_names = decoder->get_output_names(m_node_idx);
+    }
+
+    TranslateSession* get_translate_session() const {
+        return m_translate_session;
+    }
+
+    const std::vector& get_input_names() const { return m_input_names; }
+
+    size_t get_input_size() const override {
+        return m_decoder->get_input_size(m_node_idx);
+    }
+
+    ov::element::Type get_input_type(size_t index) const {
+        return m_decoder->get_input_type(m_node_idx, m_input_names[index]);
+    }
+
+    PartialShape get_input_shape(size_t input_index) const {
+        return m_decoder->get_input_shape(m_node_idx, m_input_names[input_index]);
+    }
+
+    std::vector get_input_stride(size_t index) const {
+        return m_decoder->get_input_stride(m_node_idx, m_input_names[index]);
+    }
+
+    std::string get_output_name() const { return m_output_names[0]; }
+
+    PartialShape get_output_shape() const { return m_decoder->get_output_shape(m_node_idx); }
+
+    int32_t* get_input_op_params(size_t index) const {
+        return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]);
+    }
+
+    int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); }
+
+    ov::element::Type get_output_type() const {
+        return m_decoder->get_output_type(m_node_idx);
+    }
+
+    Output get_input(int idx) const override {
+        return m_tensor_map->at(m_input_names[idx]);
+    }
+
+    Output get_input(const std::string& name) const override {
+        if (m_tensor_map->find(name) == m_tensor_map->end()) {
+            throw std::runtime_error("'" + name + "' not found in tensor map.");
+        }
+        return m_tensor_map->at(name);
+    }
+
+    bool has_input(const std::string& name) const {
+        return m_tensor_map->find(name) != m_tensor_map->end();
+    }
+
+    const std::string& get_name() const override {
+        return m_decoder->get_op_name(m_node_idx);
+    }
+
+    ov::Any get_attribute_as_any(const std::string& name) const override {
+        return m_decoder->get_attribute(name);
+    }
+
+    int get_op_case() const {
+        return m_decoder->get_op_case(m_node_idx);
+    }
+
+    bool is_static() const { return m_decoder->is_static(); }
+
+    bool is_stateful() const { return m_decoder->is_stateful(); }
+
+private:
+    std::shared_ptr m_decoder;
+    std::shared_ptr& m_tensor_map;
+    int m_node_idx;
+    TranslateSession* m_translate_session;
+    std::vector m_input_names;
+    std::vector m_output_names;
+};
+
+using CreatorFunction = std::function;
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/cont.cpp b/ggml/src/ggml-openvino/openvino/op/cont.cpp
new file mode 100644
index 00000000..6160dd74
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/cont.cpp
@@ -0,0 +1,48 @@
+
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_cont(const NodeContext & context) {
+    num_inputs_check(context, 1, 1);
+
+    int op_case = context.get_op_case();
+    FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported CONT case");
+
+    auto src_shape = context.get_input_shape(0).to_shape();
+    auto dst_shape = context.get_output_shape().to_shape();
+    ov::Output res;
+
+    if (op_case == 1) {
+        // The input comes from a PERMUTE
+        throw std::runtime_error("Code of this case might be outdated");
+        dst_shape[1] = -1;
+        res = std::make_shared(
+            context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false);
+    } else if (op_case == 2) {
+        // The input comes from a TRANSPOSE
+        return {context.get_input(0)};
+    } else {
+        // The input comes from a VIEW
+        res = process_view_input(context, 0);
+    }
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/cpy.cpp b/ggml/src/ggml-openvino/openvino/op/cpy.cpp
new file mode 100644
index 00000000..83111720
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/cpy.cpp
@@ -0,0 +1,21 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_cpy(const NodeContext & context) {
+    auto res = std::make_shared(context.get_input(0), context.get_output_type());
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
new file mode 100644
index 00000000..42602a73
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp
@@ -0,0 +1,90 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_flash_attn_ext(const NodeContext & context) {
+    num_inputs_check(context, 4, 4);
+    auto q_f32 = context.get_input(0);
+    auto k = context.get_input(1);
+    auto v = context.get_input(2);
+    auto mask = context.get_input(3);
+
+    float * params = reinterpret_cast(context.get_output_op_params());
+    float scale = params[0];
+    // float max_bias      = params[1];
+    // float logit_softcap = params[2];
+
+    auto q = std::make_shared(q_f32, ov::element::f16);
+    auto scale_node = std::make_shared(ov::element::f16, ov::Shape{}, std::vector{scale});
+
+    ov::Output mask_sliced, res;
+    std::string mask_name = "KQ_mask_sliced";
+    if (context.get_input_names()[3].find("swa") != std::string::npos) {
+        mask_name = "KQ_mask_swa_sliced";
+    }
+    if (context.has_input(mask_name)) {
+        mask_sliced = context.get_input(mask_name);
+    } else {
+        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
+        auto token_len = get_dimensions(q, {2});
+        mask_sliced = std::make_shared(mask, zero, token_len, one, two);
+    }
+
+    if (mask_sliced.get_element_type() != ov::element::f16) {
+        mask_sliced = std::make_shared(mask_sliced, ov::element::f16);
+    }
+
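+    // Grouped-query attention: K/V may carry fewer heads than Q, so each KV head is repeated
+    // num_heads / num_heads_kv times (unsqueeze -> broadcast -> reshape) before SDPA.
+    // E.g. 32 query heads over 8 KV heads repeats each KV head 4 times.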
+    auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output kv) {
+        int64_t factor = num_heads / num_heads_kv;
+        if (factor > 1 && num_heads_kv > 1) {
+            ov::Output kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
+            auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
+            kv_unsqueezed = std::make_shared(kv, unsqueeze_axes);
+
+            kv_broadcast_shape = ov::op::v0::Constant::create(
+                ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});
+            new_kv_shape =
+                ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 0, num_heads, (int64_t) -1, head_size});
+
+            kv = std::make_shared(kv_unsqueezed, kv_broadcast_shape,
+                                                         ov::op::BroadcastType::BIDIRECTIONAL);
+            kv = std::make_shared(kv, new_kv_shape, true);
+        }
+        return kv;
+    };
+
+    auto q_shape = context.get_input_shape(0).to_shape();
+    auto k_shape = context.get_input_shape(1).to_shape();
+    k = tile_kv(q_shape[1], k_shape[1], q_shape[3], k);
+    v = tile_kv(q_shape[1], k_shape[1], q_shape[3], v);
+
+    auto sdpa = std::make_shared(q, k, v, mask_sliced, scale_node, false);
+    res = std::make_shared(sdpa,
+                                                  ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
+    res = std::make_shared(res, ov::element::f32);
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp
new file mode 100644
index 00000000..49f51b7c
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp
@@ -0,0 +1,69 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_get_rows(const NodeContext & context) {
+    num_inputs_check(context, 2, 2);
+
+    int op_case = context.get_op_case();
+
+    Output res;
+    auto data = context.get_input(0);
+    auto indices = context.get_input(1);
+
+    if (op_case == 2) {
+        // The input comes from a VIEW
+        indices = process_view_input(context, 1);
+    }
+
+    // data[1,b,x,y] ind[1,1,b,x'] test-backend-ops case
+    // data[x,y] ind[1,1,1,x'] normal case
+    indices =
+        std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
+    if (data.get_partial_shape().rank() == 4) {
+        if (!(data.get_partial_shape()[1].is_dynamic()) && data.get_partial_shape()[1].get_length() == 1) {
+            // Work-around for a bug in ov cpu plugin for test-backend-ops
+            data = std::make_shared(data,
+                                                         ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
+            auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
+            res = std::make_shared(data, indices, axis);
+        } else {
+            auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
+            data =
+                std::make_shared(data, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
+            res = std::make_shared(data, indices, axis, 1);
+        }
+    } else if (context.is_stateful() && data.get_partial_shape().rank() == 3) {
+        auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
+        res = std::make_shared(data, indices, axis, 1);
+    } else {
+        auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
+        res = std::make_shared(data, indices, axis);
+    }
+
+    if (res.get_element_type() != context.get_output_type()) {
+        res = std::make_shared(res, context.get_output_type());
+    }
+    if (!(context.is_stateful())) {
+        res = std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
+    }
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp
new file mode 100644
index 00000000..d9fa4c24
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp
@@ -0,0 +1,61 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_glu_geglu(const NodeContext & context) {
+    num_inputs_check(context, 1, 2);
+
+    ov::Output src0;
+    ov::Output src1;
+    if (context.get_input_size() == 2) {
+        src0 = context.get_input(0);
+        src1 = context.get_input(1);
+    } else {
+        // GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2.
+        // Both halves are nc elements; if the dimension is odd, the last element is dropped.
+        // Use Slice instead of Split to handle odd dimensions correctly.
+        auto combined = context.get_input(0);
+        auto combined_shape = combined.get_partial_shape();
+        int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length();
+        int64_t nc = last_dim_val / 2;
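+        // e.g. ne[0] = 7 gives nc = 3: the halves are elements [0, 3) and [3, 6); element 6 is dropped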
+
+        auto axis   = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
+        auto step   = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto stop0  = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});
+        auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});
+        auto stop1  = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc});
+
+        src0 = std::make_shared(combined, start0, stop0, step, axis);
+        src1 = std::make_shared(combined, start1, stop1, step, axis);
+    }
+
+    int32_t * params = context.get_output_op_params();
+    const int32_t swapped = params[1];
+    if (swapped) {
+        std::swap(src0, src1);
+    }
+
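+    // GEGLU: out = gelu(src0) * src1 (after the optional swap above).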
+    auto gelu = std::make_shared(src0);
+    auto res = std::make_shared(gelu, src1);
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp
new file mode 100644
index 00000000..00ed7951
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp
@@ -0,0 +1,62 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_glu_swiglu(const NodeContext & context) {
+    num_inputs_check(context, 1, 2);
+
+    ov::Output src0;
+    ov::Output src1;
+    if (context.get_input_size() == 2) {
+        src0 = context.get_input(0);
+        src1 = context.get_input(1);
+    } else {
+        // GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2.
+        // Both halves are nc elements; if the dimension is odd, the last element is dropped.
+        // Use Slice instead of Split to handle odd dimensions correctly.
+        auto combined = context.get_input(0);
+        auto combined_shape = combined.get_partial_shape();
+        int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length();
+        int64_t nc = last_dim_val / 2;
+
+        auto axis   = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
+        auto step   = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto stop0  = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});
+        auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});
+        auto stop1  = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc});
+
+        src0 = std::make_shared(combined, start0, stop0, step, axis);
+        src1 = std::make_shared(combined, start1, stop1, step, axis);
+    }
+
+    int32_t * params = context.get_output_op_params();
+    const int32_t swapped = params[1];
+    if (swapped) {
+        std::swap(src0, src1);
+    }
+
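+    // SwiGLU: out = silu(src0) * src1 (after the optional swap above), with silu(x) = x * sigmoid(x).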
+    auto sigmoid = std::make_shared(src0);
+    auto silu = std::make_shared(src0, sigmoid);
+    auto res = std::make_shared(silu, src1);
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp
new file mode 100644
index 00000000..38edec85
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp
@@ -0,0 +1,90 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_mulmat(const NodeContext & context) {
+    num_inputs_check(context, 2, 2);
+
+    int op_case = context.get_op_case();
+
+    ov::Output res;
+    ov::Output B = context.get_input(0);
+    ov::Output A = context.get_input(1);
+
+    bool transpose_b = true;
+    if (op_case == 2) {
+        B = B.get_node_shared_ptr()->input_value(0);
+        transpose_b = false;
+    } else if (op_case == 3) {
+        B = process_view_input(context, 0);
+        A = process_view_input(context, 1);
+    }
+    if (A.get_element_type() != B.get_element_type()) {
+        B = std::make_shared(context.get_input(0), context.get_input_type(1));
+    }
+
+    auto B_shape = context.get_input_shape(0).to_shape();
+    auto A_shape = context.get_input_shape(1).to_shape();
+    int64_t A_batch = A_shape[1];
+    int64_t B_batch = B_shape[1];
+
+    auto A_batch_larger = A_batch > B_batch;
+    auto batch_large = A_batch_larger ? A_batch : B_batch;
+    auto batch_small = A_batch_larger ? B_batch : A_batch;
+
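+    // When the operands have different head (batch) counts, broadcast the smaller one by the ratio
+    // batch_large / batch_small (unsqueeze -> broadcast -> reshape) so the MatMul sees matching
+    // batch dimensions, e.g. for grouped-query attention layouts.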
+    Output Z = A_batch_larger ? B : A;
+    int64_t factor = batch_large / batch_small;
+    if (factor > 1 && batch_small > 1) {
+        auto batch_large_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{batch_large});
+        auto batch_small_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{batch_small});
+        auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{factor});
+
+        auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
+        auto Z_unsqueezed = std::make_shared(Z, unsqueeze_axes);
+
+        auto broadcast_shape = ov::op::v0::Constant::create(
+            ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});
+        auto new_Z_shape = ov::op::v0::Constant::create(ov::element::i64, {4},
+                                                        {(int64_t) 0, batch_large, (int64_t) -1, (int64_t) A_shape[3]});
+
+        auto Z_broadcasted = std::make_shared(Z_unsqueezed, broadcast_shape,
+                                                                     ov::op::BroadcastType::BIDIRECTIONAL);
+        Z = std::make_shared(Z_broadcasted, new_Z_shape, true);
+    }
+    if (A_batch_larger) {
+        B = Z;
+    } else {
+        A = Z;
+    }
+
+    res = std::make_shared(A, B, false, transpose_b);
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp
new file mode 100644
index 00000000..4c800f9e
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp
@@ -0,0 +1,102 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_permute(const NodeContext & context) {
+    num_inputs_check(context, 1, 1);
+
+    int op_case = context.get_op_case();
+    FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4,
+                                "Unsupported PERMUTE case");
+
+    ov::Output res;
+    auto src = context.get_input(0);
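+    // Every supported case ends in the same [0, 2, 1, 3] transpose (swapping the two middle axes);
+    // the cases only differ in how the source is reshaped/sliced beforehand.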
+    auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3});
+
+    if (op_case == 1 || context.is_stateful()) {
+        res = std::make_shared(src, perm);
+    } else if (op_case == 4) {
+        auto output_shape = context.get_output_shape().to_shape();
+        auto n_heads = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[1]});
+        auto head_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
+        auto n_seq_active = context.has_input("n_seq_active") ?
+                                context.get_input("n_seq_active") :
+                                ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[0]});
+        auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
+
+        auto new_shape =
+            std::make_shared(ov::OutputVector{n_seq_active, neg_one, n_heads, head_size}, 0);
+
+        // // Alternative
+        // auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        // auto new_shape = std::make_shared(ov::OutputVector{n_seq_active, neg_one, zero, zero}, 0);
+
+        auto reshaped = std::make_shared(src, new_shape, true);
+        res = std::make_shared(reshaped, perm);
+    } else {
+        auto cache_shape = src.get_partial_shape();
+        auto output_shape = context.get_output_shape().to_shape();
+        int64_t head_size = output_shape[3];
+        int64_t n_heads = output_shape[1];
+        int64_t ctx_per_seq = cache_shape[2].is_static() ? cache_shape[2].get_length() : -1;
+        int64_t n_seq = cache_shape[1].get_length();
+
+        Output attention_size;
+        if (!context.has_input("attention_size")) {
+            attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[2]});
+        } else if (op_case == 2) {
+            attention_size = context.get_input("attention_size");
+        } else {
+            attention_size = context.get_input("attention_size_swa");
+        }
+
+        Output seq_active_start;
+        Output seq_active_end;
+        if (context.has_input("seq_active_start")) {
+            seq_active_start = context.get_input("seq_active_start");
+            seq_active_end = context.get_input("seq_active_end");
+        } else {
+            int64_t n_seq_active = output_shape[0];
+            size_t offset = *((size_t *) context.get_input_op_params(0));
+            int64_t seq_active_start_val = offset / context.get_input_stride(0)[0];
+            int64_t seq_active_end_val = seq_active_start_val + n_seq_active;
+            seq_active_start = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_start_val});
+            seq_active_end = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_end_val});
+        }
+
+        // 1. reshape to [n_seq, ctx_per_seq, n_heads, head_size]
+        // 2. slice out the active sequences
+        // 3. slice out the attention part in each sequence
+        // 4. permute
+        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+
+        auto src_reshaped = std::make_shared(
+            src, ov::op::v0::Constant::create(ov::element::i64, {4}, {n_seq, ctx_per_seq, n_heads, head_size}), false);
+        auto slice1 = std::make_shared(src_reshaped, seq_active_start, seq_active_end, one, zero);
+        auto slice2 = std::make_shared(slice1, zero, attention_size, one, one);
+        res = std::make_shared(slice2, perm);
+    }
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/reshape.cpp b/ggml/src/ggml-openvino/openvino/op/reshape.cpp
new file mode 100644
index 00000000..efd9a5a8
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/reshape.cpp
@@ -0,0 +1,83 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_reshape(const NodeContext & context) {
+    num_inputs_check(context, 1, 1);
+    if (context.get_input_shape(0) == context.get_output_shape()) {
+        return {context.get_input(0)};
+    }
+
+    int op_case = context.get_op_case();
+    FRONT_END_CHECK_IMPLEMENTED(
+        op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4 || op_case == 5 || op_case == 6,
+        "Unsupported RESHAPE case");
+
+    auto output_shape = context.get_output_shape().to_shape();
+    std::shared_ptr new_shape_node;
+    if (op_case == 1) {
+        if (context.is_stateful()) {
+            new_shape_node = ov::op::v0::Constant::create(
+                ov::element::i64, {3},
+                std::vector{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
+        } else {
+            new_shape_node = ov::op::v0::Constant::create(
+                ov::element::i64, {4},
+                std::vector{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
+        }
+    } else if (op_case == 2) {
+        new_shape_node = ov::op::v0::Constant::create(
+            ov::element::i64, {4},
+            std::vector{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, (int64_t) output_shape[3]});
+
+    } else if (op_case == 3) {
+        throw std::runtime_error("might be outdated RESHAPE case");
+        new_shape_node = ov::op::v0::Constant::create(
+            ov::element::i64, {4}, std::vector{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, 1});
+
+    } else if (op_case == 4) {
+        return {context.get_input(0).get_node_shared_ptr()->input_value(0)};
+
+    } else if (op_case == 5) {
+        if (context.is_stateful()) {
+            std::vector shape_vec = {1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
+            new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {3}, shape_vec);
+        } else {
+            std::vector shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
+            new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);
+        }
+
+        // // Alternative
+        // auto token_len = context.get_input("token_len");
+        // auto emb_size =
+        //     ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape().to_shape()[3]});
+        // auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        // new_shape_node = std::make_shared(ov::OutputVector{one, one, token_len, emb_size}, 0);
+
+    } else if (op_case == 6) {
+        new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape().to_shape());
+    }
+    auto res = std::make_shared(context.get_input(0), new_shape_node, false);
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp
new file mode 100644
index 00000000..72cf9228
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp
@@ -0,0 +1,46 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_rms_norm(const NodeContext & context) {
+    num_inputs_check(context, 1, 1);
+
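+    // RMSNorm: y = x / sqrt(mean(x^2, last axis) + eps), built below from a square, a keep-dims
+    // mean, a sqrt, and a reciprocal multiply.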
+    auto input_node = context.get_input(0);
+    auto square = std::make_shared(
+        input_node, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {2.0f}));
+
+    auto mean = std::make_shared(
+        square, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true);
+
+    float eps;
+    memcpy(&eps, context.get_output_op_params(), sizeof(float));
+
+    auto rms = std::make_shared(
+        std::make_shared(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {eps})));
+
+    auto reciprocal =
+        std::make_shared(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}), rms);
+
+    auto res = std::make_shared(input_node, reciprocal);
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp
new file mode 100644
index 00000000..26dc2d24
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp
@@ -0,0 +1,123 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_rope(const NodeContext & context) {
+    num_inputs_check(context, 2, 3);
+
+    int op_case = context.get_op_case();
+
+    ov::Output res;
+
+    auto data_node = context.get_input(0).get_node_shared_ptr();
+    auto output_shape = context.get_output_shape().to_shape();
+    int32_t * op_params = context.get_output_op_params();
+
+    Output cos_theta_node;
+    Output sin_theta_node;
+    if (context.has_input("rope_cos")) {
+        cos_theta_node = context.get_input("rope_cos");
+        sin_theta_node = context.get_input("rope_sin");
+    } else {
+        auto inp_pos = context.get_input(1).get_node_shared_ptr();
+        std::shared_ptr rope_freqs_weight;
+        if (context.get_input_size() == 3) {
+            rope_freqs_weight = context.get_input(2).get_node_shared_ptr();
+        }
+        auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight);
+        sin_theta_node = sin_cos.first;
+        cos_theta_node = sin_cos.second;
+    }
+
+    if (op_case == 2) {
+        // The input comes from a VIEW
+        int slice_len = output_shape[2] * output_shape[3];
+        data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr();
+        if (context.is_stateful()) {
+            auto data_shape = ov::op::v0::Constant::create(
+                ov::element::i64, {3}, std::vector{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
+            data_node = std::make_shared(data_node, data_shape, false);
+        } else {
+            auto data_shape = ov::op::v0::Constant::create(
+                ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
+            data_node = std::make_shared(data_node, data_shape, false);
+        }
+    }
+
+    const int mode = op_params[2];
+    constexpr int ROPE_TYPE_NORMAL = 0;
+    constexpr int ROPE_TYPE_NEOX = 2;
+
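+    // NORMAL mode rotates interleaved pairs (x[2i], x[2i+1]); NEOX mode rotates the first and second
+    // halves of the head dimension. Both apply x0' = x0*cos - x1*sin, x1' = x0*sin + x1*cos.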
+    if (mode == ROPE_TYPE_NORMAL) {
+        auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
+        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
+        auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
+        Output even_slice;
+        Output odd_slice;
+        int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4;
+        even_slice = std::make_shared(data_node, zero, end, two, neg_one);
+        odd_slice = std::make_shared(data_node, one, end, two, neg_one);
+
+        Output first_half =
+            std::make_shared(std::make_shared(even_slice, cos_theta_node),
+                                                   std::make_shared(odd_slice, sin_theta_node));
+        Output second_half =
+            std::make_shared(std::make_shared(even_slice, sin_theta_node),
+                                              std::make_shared(odd_slice, cos_theta_node));
+
+        first_half = std::make_shared(first_half,
+                                                             ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim}));
+        second_half = std::make_shared(second_half,
+                                                              ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim}));
+        auto stack = std::make_shared(OutputVector{first_half, second_half}, unsqueeze_dim);
+
+        auto data_shape = ov::op::v0::Constant::create(
+            ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
+        res = std::make_shared(stack, data_shape, false);
+    } else if (mode == ROPE_TYPE_NEOX) {
+        auto data_split = std::make_shared(
+            data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2);
+        Output slice_data_node_0 = data_split->outputs()[0];
+        Output slice_data_node_1 = data_split->outputs()[1];
+
+        auto first_half_node = std::make_shared(
+            std::make_shared(slice_data_node_0, cos_theta_node),
+            std::make_shared(slice_data_node_1, sin_theta_node));
+
+        auto second_half_node = std::make_shared(
+            std::make_shared(slice_data_node_0, sin_theta_node),
+            std::make_shared(slice_data_node_1, cos_theta_node));
+
+        res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, -1);
+    }
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/scale.cpp b/ggml/src/ggml-openvino/openvino/op/scale.cpp
new file mode 100644
index 00000000..0f3d800c
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/scale.cpp
@@ -0,0 +1,41 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_scale(const NodeContext & context) {
+    num_inputs_check(context, 1, 1);
+
+    float scale;
+    float bias;
+    memcpy(&scale, (float *) context.get_output_op_params() + 0, sizeof(float));
+    memcpy(&bias, (float *) context.get_output_op_params() + 1, sizeof(float));
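+    // GGML_OP_SCALE packs two floats into op_params: dst = src0 * scale + bias
+    // (the Add below is emitted only when bias is non-zero).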
+
+    auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
+    auto scaled = std::make_shared<ov::op::v1::Multiply>(context.get_input(0), scale_node);
+
+    std::shared_ptr<ov::Node> res;
+    if (bias != 0.0f) {
+        auto bias_node =
+            std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{bias});
+        res = std::make_shared<ov::op::v1::Add>(scaled, bias_node);
+    } else {
+        res = scaled;
+    }
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/set_rows.cpp b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp
new file mode 100644
index 00000000..136e4265
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp
@@ -0,0 +1,76 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_set_rows(const NodeContext & context) {
+    num_inputs_check(context, 3, 3);
+
+    auto data = context.get_input(0);
+    auto indices = context.get_input(1);
+    auto dst = context.get_input(2);
+
+    data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type());
+
+    auto dst_shape = context.get_output_shape().to_shape();
+
+    auto ind_squeezed =
+        std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 1, 2}));
+    auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
+        data,
+        ov::op::v0::Constant::create(ov::element::i64, {4},
+                                     {(int64_t) 1, (int64_t) 1, (int64_t) -1, (int64_t) dst_shape[3]}),
+        false);
+    auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});
+
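+    // Two lowering strategies: the stateless graph scatters the new rows into dst along axis 2
+    // using the squeezed row indices, while the stateful graph keeps the destination in model
+    // state and simply concatenates the new rows onto it.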
+    Output<ov::Node> res;
+    if (context.is_stateful()) {
+        int concat_axis = 1;
+        int64_t dim2 = dst.get_partial_shape()[2].get_length();
+        int64_t dim3 = dst.get_partial_shape()[3].get_length();
+        data = std::make_shared<ov::op::v1::Reshape>(
+            data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false);
+        res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, concat_axis);
+    } else {
+        res = std::make_shared<ov::op::v3::ScatterUpdate>(dst, ind_squeezed, data_reshaped, axes);
+    }
+
+    if (auto dst_reshape = std::dynamic_pointer_cast<ov::op::v1::Reshape>(dst.get_node_shared_ptr())) {
+        // Fix the case of multiple sequences, reshape back to original shape [1, n_seq, ctx_per_seq, emb]
+        // ctx_per_seq is not fixed due to llama-bench compatibility
+        auto dst_shape_partial = dst_reshape->get_input_partial_shape(0);
+        std::vector<int64_t> dst_shape = {dst_shape_partial[0].get_length(), dst_shape_partial[1].get_length(),
+                                          dst_shape_partial[2].is_static() ? dst_shape_partial[2].get_length() : -1,
+                                          dst_shape_partial[3].get_length()};
+        res = std::make_shared<ov::op::v1::Reshape>(res, ov::op::v0::Constant::create(ov::element::i64, {4}, dst_shape),
+                                                    false);
+    }
+    }
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/softmax.cpp b/ggml/src/ggml-openvino/openvino/op/softmax.cpp
new file mode 100644
index 00000000..9f633086
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/softmax.cpp
@@ -0,0 +1,89 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_soft_max(const NodeContext & context) {
+    // TODO code is outdated
+    num_inputs_check(context, 1, 2);
+
+    auto input_node = context.get_input(0).get_node_shared_ptr();
+    ov::Output<ov::Node> res;
+
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+    auto * op_params = context.get_output_op_params();
+    memcpy(&scale, (float *) op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) op_params + 1, sizeof(float));
+    auto src0_shape = context.get_input_shape(0).get_shape();
+    const uint32_t h = src0_shape[2];
+    const uint32_t n_head = src0_shape[0];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+    const float slope =
+        (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
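+    // ALiBi-style per-head slope, mirroring the ggml soft_max reference implementation:
+    // heads below n_head_log2 use powf(m0, h + 1), the rest use powf(m1, 2*(h - n_head_log2) + 1);
+    // with max_bias == 0 the slope stays 1.0f and the mask is applied unscaled.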
+
+    auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
+    auto scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);
+
+    if (context.get_input_size() < 2) {
+        res = std::make_shared<ov::op::v8::Softmax>(scaled_input, 2);
+        return rename_outputs_with_suffix({res}, context.get_name());
+    }
+
+    ov::Output<ov::Node> mask_node_sliced;
+    if (context.has_input("KQ_mask_sliced")) {
+        mask_node_sliced = context.get_input("KQ_mask_sliced");
+    } else {
+        auto token_len = get_dimensions(input_node, {1});
+        auto mask_node = context.get_input(1);
+        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
+    }
+
+    if (mask_node_sliced.get_element_type() != context.get_output_type()) {
+        mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type());
+    }
+
+    Output<ov::Node> slope_mask;
+    if (slope != 1.0f) {
+        auto slope_node =
+            std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{slope});
+        slope_mask = std::make_shared<ov::op::v1::Multiply>(mask_node_sliced, slope_node);
+        throw std::runtime_error("Slope != 1.0f in softmax has not been tested, verify it before use.");
+    }
+    slope_mask = mask_node_sliced;
+
+    auto input_slope_mask_node = std::make_shared<ov::op::v1::Add>(scaled_input, slope_mask);
+
+    res = std::make_shared<ov::op::v8::Softmax>(input_slope_mask_node, 2);
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/transpose.cpp b/ggml/src/ggml-openvino/openvino/op/transpose.cpp
new file mode 100644
index 00000000..8e62e83c
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/transpose.cpp
@@ -0,0 +1,23 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_transpose(const NodeContext & context) {
+    num_inputs_check(context, 1, 1);
+
+    auto res = std::make_shared<ov::op::v1::Transpose>(
+        context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 1, 3, 2}));
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp
new file mode 100644
index 00000000..037e0b94
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp
@@ -0,0 +1,27 @@
+#include "../node_context.h"
+#include "../op_table.h"
+#include "../utils.h"
+
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_unary_silu(const NodeContext & context) {
+    num_inputs_check(context, 1, 1);
+
+    auto input = context.get_input(0);
+    auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(input);
+    auto res = std::make_shared<ov::op::v1::Multiply>(input, sigmoid);
+
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op/view.cpp b/ggml/src/ggml-openvino/openvino/op/view.cpp
new file mode 100644
index 00000000..8528d252
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op/view.cpp
@@ -0,0 +1,53 @@
+#include "../op_table.h"
+#include "../utils.h"
+#include 
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace op {
+
+OutputVector translate_view(const NodeContext & context) {
+    num_inputs_check(context, 1, 1);
+
+    if (context.get_op_case() == 2) {
+        auto dst_shape = context.get_output_shape().to_shape();
+        return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[2] * dst_shape[3])},
+                                          context.get_name());
+    }
+    // op_case 3
+    if (context.get_op_case() == 3) {
+        auto input = context.get_input(0);
+        auto input_ov_shape = input.get_partial_shape();
+
+        auto input_llama_shape = context.get_input_shape(0).to_shape();
+
+        // If the input's OV rank differs from its ggml (llama) rank, the input has already been
+        // reshaped; reshape it back to the original shape before slicing.
+        if (input_ov_shape.size() != input_llama_shape.size()) {
+            input = std::make_shared<ov::op::v1::Reshape>(
+                input, ov::op::v0::Constant::create(ov::element::i64, {input_llama_shape.size()}, input_llama_shape),
+                false);
+        }
+
+        auto dst_shape = context.get_output_shape().to_shape();
+
+        // find the index of dst_shape that is different from input shape, and use that index to slice the input
+        int slice_dim = -1;
+        for (size_t i = 0; i < dst_shape.size(); ++i) {
+            if (dst_shape[i] != input_llama_shape[i]) {
+                slice_dim = i;
+                break;
+            }
+        }
+
+        auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+        auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {dst_shape[slice_dim]});
+        auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+        auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_dim});
+        auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
+        return {sliced};
+    }
+    return {context.get_input(0)};
+}
+
+}  // namespace op
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp
new file mode 100644
index 00000000..beadafe8
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op_table.cpp
@@ -0,0 +1,46 @@
+#include "op_table.h"
+
+#include "utils.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
+    using namespace ov::op;
+    return {
+        {"GGML_OP_ADD",            op::translate_1to1_match_2_inputs<v1::Add>     },
+        {"GGML_OP_ADD1",           op::translate_1to1_match_2_inputs<v1::Add>     },
+        {"GGML_OP_CONT",           op::translate_cont                             },
+        {"GGML_OP_DIV",            op::translate_1to1_match_2_inputs<v1::Divide>  },
+        {"GGML_OP_GET_ROWS",       op::translate_get_rows                         },
+        {"GGML_OP_MUL",            op::translate_1to1_match_2_inputs<v1::Multiply>},
+        {"GGML_OP_MUL_MAT",        op::translate_mulmat                           },
+        {"GGML_OP_PERMUTE",        op::translate_permute                          },
+        {"GGML_OP_RESHAPE",        op::translate_reshape                          },
+        {"GGML_OP_RMS_NORM",       op::translate_rms_norm                         },
+        {"GGML_OP_ROPE",           op::translate_rope                             },
+        {"GGML_OP_SCALE",          op::translate_scale                            },
+        {"GGML_OP_SOFT_MAX",       op::translate_soft_max                         },
+        {"GGML_OP_SUB",            op::translate_1to1_match_2_inputs<v1::Subtract>},
+        {"GGML_OP_TRANSPOSE",      op::translate_transpose                        },
+        {"GGML_UNARY_OP_SILU",     op::translate_unary_silu                       },
+        {"GGML_OP_VIEW",           op::translate_view                             },
+        {"GGML_GLU_OP_SWIGLU",     op::translate_glu_swiglu                       },
+        {"GGML_GLU_OP_GEGLU",      op::translate_glu_geglu                        },
+        {"GGML_OP_SET_ROWS",       op::translate_set_rows                         },
+        {"GGML_OP_CPY",            op::translate_cpy                              },
+        {"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext                   },
+    };
+}
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/op_table.h b/ggml/src/ggml-openvino/openvino/op_table.h
new file mode 100644
index 00000000..37f76311
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/op_table.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include "node_context.h"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+namespace op {
+
+#define GGML_OP_CONVERTER(op) OutputVector op(const NodeContext& context)
+
+GGML_OP_CONVERTER(translate_add);
+GGML_OP_CONVERTER(translate_cont);
+GGML_OP_CONVERTER(translate_get_rows);
+GGML_OP_CONVERTER(translate_mul);
+GGML_OP_CONVERTER(translate_mulmat);
+GGML_OP_CONVERTER(translate_permute);
+GGML_OP_CONVERTER(translate_reshape);
+GGML_OP_CONVERTER(translate_rms_norm);
+GGML_OP_CONVERTER(translate_rope);
+GGML_OP_CONVERTER(translate_scale);
+GGML_OP_CONVERTER(translate_unary_silu);
+GGML_OP_CONVERTER(translate_soft_max);
+GGML_OP_CONVERTER(translate_transpose);
+GGML_OP_CONVERTER(translate_view);
+GGML_OP_CONVERTER(translate_glu_swiglu);
+GGML_OP_CONVERTER(translate_glu_geglu);
+GGML_OP_CONVERTER(translate_set_rows);
+GGML_OP_CONVERTER(translate_cpy);
+GGML_OP_CONVERTER(translate_flash_attn_ext);
+
+} // namespace op
+
+std::unordered_map<std::string, CreatorFunction> get_supported_ops();
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp
new file mode 100644
index 00000000..ed2a3ab6
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp
@@ -0,0 +1,123 @@
+#include "eliminate_zp.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace pass {
+
+EliminateZeroPoints::EliminateZeroPoints() {
+    // Find pattern:
+    // (Multiply Any(scale)
+    //           (Subtract (Convert Constant(data)))
+    //                     (Convert Constant(zero_point)))
+    // where zero_point is a scalar
+    // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val
+    // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant
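+    //
+    // Example of the intended rewrite (q4_0): Multiply(scale, Subtract(Convert(u4 data), Convert(8)))
+    // becomes Multiply(scale, Convert(i4 data - 8)), i.e. the zero-point subtraction is folded into
+    // the weight constant at conversion time.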
+
+    auto m_data_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
+    auto m_data_convert = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_data_constant});
+
+    auto m_zp_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
+    auto m_zp_convert = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_zp_constant});
+
+    auto m_subtract = ov::pass::pattern::wrap_type<ov::op::v1::Subtract>({m_data_convert, m_zp_convert});
+    auto m_scale = ov::pass::pattern::any_input();
+    auto m_multiply = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_scale, m_subtract});
+
+    const auto callback = [=](ov::pass::pattern::Matcher & m) {
+        const auto & pattern_map = m.get_pattern_value_map();
+
+        auto multiply_node =
+            std::dynamic_pointer_cast<ov::op::v1::Multiply>(pattern_map.at(m_multiply).get_node_shared_ptr());
+        auto subtract_node =
+            std::dynamic_pointer_cast<ov::op::v1::Subtract>(pattern_map.at(m_subtract).get_node_shared_ptr());
+        auto data_constant =
+            std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(m_data_constant).get_node_shared_ptr());
+        auto zp_constant =
+            std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(m_zp_constant).get_node_shared_ptr());
+
+        if (!multiply_node || !subtract_node || !data_constant || !zp_constant) {
+            return false;
+        }
+
+        if (ov::shape_size(zp_constant->get_shape()) != 1) {
+            return false;
+        }
+
+        auto data_type = data_constant->get_element_type();
+        auto zp_data = zp_constant->cast_vector<int>();
+
+        if (zp_data.empty()) {
+            return false;
+        }
+
+        int zp_value = zp_data[0];
+
+        bool should_eliminate = false;
+        ov::element::Type target_type;
+
+        if (data_type == ov::element::u4 && zp_value == 8) {
+            should_eliminate = true;
+            target_type = ov::element::i4;
+        } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) {
+            should_eliminate = true;
+            target_type = ov::element::i8;
+        }
+
+        if (!should_eliminate) {
+            return false;
+        }
+
+        auto data_shape = data_constant->get_shape();
+        size_t total_elements = ov::shape_size(data_shape);
+
+        std::shared_ptr<ov::op::v0::Constant> new_constant;
+
+        // TODO improve performance
+        if (data_type == ov::element::u4) {
+            auto data_values = data_constant->cast_vector<uint8_t>();
+            std::vector<int8_t> adjusted_values(total_elements);
+
+            ov::parallel_for(total_elements, [&](size_t i) {
+                adjusted_values[i] = static_cast<int8_t>(static_cast<int>(data_values[i]) - 8);
+            });
+
+            new_constant = std::make_shared<ov::op::v0::Constant>(target_type, data_shape, adjusted_values);
+        } else if (data_type == ov::element::u8) {
+            auto data_values = data_constant->cast_vector<uint8_t>();
+            std::vector<int8_t> adjusted_values(total_elements);
+
+            ov::parallel_for(total_elements, [&, zp_value](size_t i) {
+                adjusted_values[i] = static_cast<int8_t>(static_cast<int>(data_values[i]) - zp_value);
+            });
+
+            new_constant = std::make_shared<ov::op::v0::Constant>(target_type, data_shape, adjusted_values);
+        }
+
+        auto new_convert =
+            std::make_shared<ov::op::v0::Convert>(new_constant, subtract_node->get_output_element_type(0));
+        ov::replace_node(subtract_node, new_convert);
+
+        return true;
+    };
+
+    register_matcher(
+        std::make_shared<ov::pass::pattern::Matcher>(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"),
+        callback);
+}
+
+}  // namespace pass
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h
new file mode 100644
index 00000000..edd3cd71
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h
@@ -0,0 +1,17 @@
+#include "openvino/pass/matcher_pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace pass {
+
+class EliminateZeroPoints : public ov::pass::MatcherPass {
+public:
+    OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints")
+    EliminateZeroPoints();
+};
+
+}  // namespace pass
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp
new file mode 100644
index 00000000..0671542e
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp
@@ -0,0 +1,60 @@
+#include "fuse_to_sdpa.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace pass {
+
+FuseToSDPA::FuseToSDPA() {
+    // Not maintained since FLASH_ATTN_EXT has replaced this pattern
+    const auto m_k = ov::pass::pattern::any_input();
+    const auto m_q = ov::pass::pattern::any_input();
+    const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k});
+    const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk});
+    const auto m_scale = ov::pass::pattern::any_input();
+    const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale});
+    const auto m_mask = ov::pass::pattern::any_input();
+    const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask});
+    const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk});
+    const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk});
+    const auto m_v = ov::pass::pattern::any_input();
+    const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v});
+
+    const auto callback = [=](ov::pass::pattern::Matcher & m) {
+        auto & pattern_to_output = m.get_pattern_value_map();
+        auto k = pattern_to_output[m_k];
+        auto q = pattern_to_output[m_q];
+        auto v = pattern_to_output[m_v];
+        auto mask = pattern_to_output[m_mask];
+        auto scale = pattern_to_output[m_scale];
+
+        auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16);
+        auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16);
+        auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_f16, scale_f16, false);
+
+        ov::replace_node(m.get_match_root(), sdpa);
+        ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
+
+        return true;
+    };
+    register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_qkv, "ov::frontend::ggml::pass::FuseToSDPA"),
+                     callback);
+}
+
+}  // namespace pass
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h
new file mode 100644
index 00000000..8b5164d2
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h
@@ -0,0 +1,17 @@
+#include "openvino/pass/matcher_pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace pass {
+
+class FuseToSDPA : public ov::pass::MatcherPass {
+public:
+    OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::FuseToSDPA")
+    FuseToSDPA();
+};
+
+}  // namespace pass
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h b/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h
new file mode 100644
index 00000000..b9538561
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include "mark_decompression_convert_constant_folding.h"
+#include "openvino/pass/matcher_pass.hpp"
+#include "openvino/core/visibility.hpp"
+
+#ifdef OPENVINO_STATIC_LIBRARY
+#    define TRANSFORMATIONS_API
+#else
+#    ifdef IMPLEMENT_OPENVINO_API
+#        define TRANSFORMATIONS_API OPENVINO_CORE_EXPORTS
+#    else
+#        define TRANSFORMATIONS_API OPENVINO_CORE_IMPORTS
+#    endif  // IMPLEMENT_OPENVINO_API
+#endif      // OPENVINO_STATIC_LIBRARY
+
+namespace ov {
+namespace pass {
+
+class TRANSFORMATIONS_API MarkCompressedFloatConstants;
+
+}  // namespace pass
+}  // namespace ov
+
+class ov::pass::MarkCompressedFloatConstants : public MatcherPass {
+public:
+    OPENVINO_MATCHER_PASS_RTTI("MarkCompressedFloatConstants")
+    MarkCompressedFloatConstants();
+};
diff --git a/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp
new file mode 100644
index 00000000..20a3a374
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp
@@ -0,0 +1,58 @@
+#include "squeeze_matmul.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace opp = ov::pass::pattern;
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace pass {
+
+// For quantized models, NPUW expects the activation to be 3d in DQ(DynamicQuantization) opt, e.g. DQMatMulGQ2i
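+// For example, a 4D activation with a leading dim of 1 times a 2D weight is rewritten below as
+// Squeeze(axis 0) -> MatMul -> Unsqueeze(axis 0), leaving the overall result shape unchanged.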
+SqueezeMatmul::SqueezeMatmul() {
+    auto m_act = opp::any_input();
+    auto m_wei = opp::any_input();
+    auto m_matmul = opp::wrap_type<ov::op::v0::MatMul>({m_act, m_wei});
+
+    const auto callback = [=](ov::pass::pattern::Matcher & m) {
+        const auto & pattern_map = m.get_pattern_value_map();
+        auto matmul_node =
+            std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_map.at(m_matmul).get_node_shared_ptr());
+        auto act = pattern_map.at(m_act);
+        auto wei = pattern_map.at(m_wei);
+        auto act_shape = act.get_partial_shape();
+        auto wei_shape = wei.get_partial_shape();
+        if (act_shape.rank().is_dynamic() || wei_shape.rank().is_dynamic()) {
+            return false;
+        }
+        if (act_shape.rank().get_length() == 4 && wei_shape.rank().get_length() == 2) {
+            auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
+            auto squeezed_act = std::make_shared<ov::op::v0::Squeeze>(act, axis);
+            auto new_matmul = std::make_shared<ov::op::v0::MatMul>(squeezed_act, wei, matmul_node->get_transpose_a(),
+                                                                   matmul_node->get_transpose_b());
+            auto unsqueezed_output = std::make_shared<ov::op::v0::Unsqueeze>(new_matmul, axis);
+            unsqueezed_output->set_friendly_name(matmul_node->get_friendly_name());
+            ov::copy_runtime_info(matmul_node, {squeezed_act, new_matmul, unsqueezed_output});
+            ov::replace_node(matmul_node, unsqueezed_output);
+            return true;
+        }
+        return false;
+    };
+
+    register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul, "ov::frontend::ggml::pass::SqueezeMatmul"),
+                     callback);
+}
+
+}  // namespace pass
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h
new file mode 100644
index 00000000..f8fbc69d
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h
@@ -0,0 +1,17 @@
+#include "openvino/pass/matcher_pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+namespace pass {
+
+class SqueezeMatmul : public ov::pass::MatcherPass {
+public:
+    OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::SqueezeMatmul")
+    SqueezeMatmul();
+};
+
+}  // namespace pass
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp
new file mode 100644
index 00000000..23a1dea2
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp
@@ -0,0 +1,293 @@
+#include "translate_session.h"
+
+#include "ggml-openvino/openvino/node_context.h"
+#include "ggml-openvino/openvino/utils.h"
+#include "input_model.h"
+#include "pass/eliminate_zp.h"
+#include "pass/mark_decompression_convert_constant_folding.h"
+#include "pass/squeeze_matmul.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+using namespace ov::op;
+
+namespace {
+
+ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
+    const std::shared_ptr<ov::Model> & model,
+    const std::map<std::string, std::string> & kv_param_res_names) {
+    ov::pass::MakeStateful::ParamResPairs pairs;
+    const auto & params = model->get_parameters();
+    const auto & results = model->get_results();
+
+    for (const auto & param_res : kv_param_res_names) {
+        const auto & param_name = param_res.first;
+        const auto & res_name = param_res.second;
+
+        auto param_it = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr<ov::op::v0::Parameter> & node) {
+            return node->get_friendly_name() == param_name;
+        });
+
+        OPENVINO_ASSERT(param_it != params.end(), "The tensor name ", param_name,
+                        " is not associated with any of "
+                        "Parameters in the network.");
+
+        auto res_it = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr<ov::op::v0::Result> & node) {
+            return node->get_friendly_name() == res_name;
+        });
+
+        OPENVINO_ASSERT(res_it != results.end(), "The tensor name ", res_name,
+                        " is not associated with any of "
+                        "Results in the network.");
+
+        std::shared_ptr<ov::op::v0::Parameter> param = *param_it;
+        std::shared_ptr<ov::op::v0::Result> res = *res_it;
+        pairs.emplace_back(param, res);
+    }
+    return pairs;
+}
+
+void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
+
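+    // The full KQ mask covers the whole context; slice it down to the rows valid for the current
+    // tokens ("token_len_per_seq") once here, so attention/soft_max nodes can reuse the shared
+    // "KQ_mask_sliced" tensor instead of each slicing the mask themselves.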
+    auto create_sliced_mask = [&](const std::string & mask_name, const std::string & sliced_name, bool is_static) {
+        if ((tensor_map.find(mask_name) != tensor_map.end()) &&
+            (tensor_map.find("token_len_per_seq") != tensor_map.end())) {
+            auto token_len_per_seq = tensor_map.at("token_len_per_seq").get_node_shared_ptr();
+            auto mask = tensor_map.at(mask_name).get_node_shared_ptr();
+            std::shared_ptr<ov::Node> mask_sliced;
+            if (is_static) {
+                mask_sliced = mask;
+            } else if (ggml_model_decoder.is_stateful()) {
+                auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
+                auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
+                auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+                auto three_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
+                auto neg_one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
+                auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2,-1});
+                auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
+                auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(inp_pos, neg_one_1d, three_1d);
+                auto reshaped_inp_pos = std::make_shared<ov::op::v1::Reshape>(gather_inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {1}), false);
+                auto inp_pos_incremented = std::make_shared<ov::op::v1::Add>(reshaped_inp_pos, ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1}));
+                auto stop = std::make_shared(ov::OutputVector{token_len_per_seq, std::make_shared(inp_pos_incremented, token_len_per_seq)}, 0);
+                mask_sliced =
+                    std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
+                mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
+                mask_sliced->set_friendly_name(sliced_name);
+            } else {
+                auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+                auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+                auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
+                mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len_per_seq, one, two);
+                mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
+                mask_sliced->set_friendly_name(sliced_name);
+            }
+            tensor_map.insert({sliced_name, mask_sliced->output(0)});
+        }
+    };
+
+    create_sliced_mask("self_kq_mask", "KQ_mask_sliced", ggml_model_decoder.is_static());
+    create_sliced_mask("self_kq_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
+}
+
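+// Precompute shared "rope_sin"/"rope_cos" tensors from inp_pos and the ROPE op params so that
+// every GGML_OP_ROPE node in the graph can reuse them instead of rebuilding the trig subgraph.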
+void add_rope_sin_cos(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
+    int32_t * rope_params = ggml_model_decoder.get_rope_params();
+    if (tensor_map.find("inp_pos") == tensor_map.end() || rope_params == nullptr) {
+        return;
+    }
+    auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
+    std::shared_ptr<ov::Node> rope_freqs_weight;
+    if (tensor_map.find("rope_freqs.weight") != tensor_map.end()) {
+        rope_freqs_weight = tensor_map.at("rope_freqs.weight").get_node_shared_ptr();
+    }
+
+    auto sin_cos = make_sin_cos(rope_params, inp_pos, rope_freqs_weight);
+    auto sin_theta = sin_cos.first;
+    auto cos_theta = sin_cos.second;
+
+    cos_theta.get_node_shared_ptr()->set_friendly_name("rope_cos");
+    sin_theta.get_node_shared_ptr()->set_friendly_name("rope_sin");
+    tensor_map.insert({"rope_cos", cos_theta});
+    tensor_map.insert({"rope_sin", sin_theta});
+}
+
+// Create common patterns
+void preprocess(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
+    add_sliced_mask(tensor_map, ggml_model_decoder);
+    add_rope_sin_cos(tensor_map, ggml_model_decoder);
+}
+
+}  // namespace
+
+TranslateSession::TranslateSession(const frontend::InputModel::Ptr & input_model,
+                                   const std::unordered_map<std::string, CreatorFunction> & translator_map,
+                                   bool naive) :
+    m_input_model(input_model),
+    m_translator_map(translator_map),
+    m_ov_model(nullptr),
+    m_naive(naive) {}
+
+std::shared_ptr<ov::Model> TranslateSession::get_converted_model() {
+    if (m_ov_model) {
+        return m_ov_model;
+    }
+    m_ov_model = translate_graph(m_input_model);
+    return m_ov_model;
+}
+
+std::shared_ptr<ov::Model> TranslateSession::translate_graph(const frontend::InputModel::Ptr & input_model) {
+    ov::ParameterVector params;
+    ov::ResultVector results;
+    auto tensor_map = std::make_shared<TensorMap>();
+    std::shared_ptr<ov::Model> resulting_model;
+
+    const auto & ggml_model = std::dynamic_pointer_cast<InputModel>(input_model);
+    std::shared_ptr<GgmlDecoder> ggml_model_decoder = ggml_model->get_model_decoder();
+
+    for (const auto & it : ggml_model_decoder->get_model_inputs()) {
+        params.push_back(std::dynamic_pointer_cast<ov::op::v0::Parameter>(it.second));
+        (*tensor_map)[it.first] = it.second;
+    }
+
+    for (const auto & it : ggml_model_decoder->get_model_extra_inputs()) {
+        if (std::dynamic_pointer_cast<ov::op::v0::Parameter>(it.second)) {
+            params.push_back(std::dynamic_pointer_cast<ov::op::v0::Parameter>(it.second));
+        }
+        (*tensor_map)[it.first] = it.second;
+    }
+
+    for (const auto & it : ggml_model_decoder->get_model_weights()) {
+        (*tensor_map)[it.first] = it.second;
+    }
+
+    auto node_visitor = [&](std::shared_ptr<GgmlDecoder> decoder, int node_idx) {
+        auto operation_type = decoder->get_op_type(node_idx);
+        if (operation_type == "GGML_OP_NONE") {
+            return;
+        }
+
+        ov::OutputVector converted_outputs;
+        auto it = m_translator_map.find(operation_type);
+        FRONT_END_OP_CONVERSION_CHECK(it != m_translator_map.end(), "Translation for operation type ", operation_type,
+                                      " is not implemented.");
+        NodeContext node_context(decoder, tensor_map, node_idx, this);
+        converted_outputs = it->second(node_context);
+
+        const auto & node_output_names = decoder->get_output_names(node_idx);
+        FRONT_END_OP_CONVERSION_CHECK(node_output_names.size() == converted_outputs.size(), "Number of ",
+                                      operation_type, " outputs greater than number of converted outputs, which are ",
+                                      node_output_names.size(), " and ", converted_outputs.size(), " respectively.");
+
+        for (size_t i = 0; i < node_output_names.size(); ++i) {
+            auto output_name = node_output_names[i];
+            if (i < converted_outputs.size() && converted_outputs[i].get_node_shared_ptr() != nullptr) {
+                (*tensor_map)[output_name] = converted_outputs[i];
+            }
+        }
+    };
+
+    if (!m_naive) {
+        preprocess(*tensor_map, *ggml_model_decoder);
+    }
+    ggml_model_decoder->visit_subgraph(node_visitor);
+
+    for (const auto & name : ggml_model_decoder->get_model_output_names()) {
+        FRONT_END_GENERAL_CHECK(tensor_map->find(name) != tensor_map->end(),
+                                "Output name not found in tensor map: ", name);
+        auto result = std::make_shared<ov::op::v0::Result>(tensor_map->at(name));
+        result->set_friendly_name(name);
+        results.push_back(result);
+    }
+
+    ov::ParameterVector used_params;
+    for (const auto & param : params) {
+        if (!param->output(0).get_target_inputs().empty()) {
+            used_params.push_back(param);
+        }
+    }
+    // if (auto diff = params.size() - used_params.size()) {
+    //     GGML_LOG_INFO("%zu parameters are not used in the model.", diff);
+    // }
+    resulting_model = std::make_shared<ov::Model>(results, used_params);
+
+    apply_transformations(resulting_model);
+    return resulting_model;
+}
+
+std::shared_ptr<ov::Model> TranslateSession::apply_transformations(std::shared_ptr<ov::Model> model) {
+    auto ggml_model_decoder = std::dynamic_pointer_cast<InputModel>(m_input_model)->get_model_decoder();
+    {
+        ov::pass::Manager manager;
+        manager.set_per_pass_validation(true);
+        manager.register_pass();
+
+        if (ggml_model_decoder->is_stateful()) {
+            const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
+            const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
+            manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
+        }
+
+        if (ggml_model_decoder->is_static()) {
+            manager.register_pass();
+            manager.register_pass();
+        }
+        manager.run_passes(model);
+        if (ggml_model_decoder->is_stateful()) {
+            auto output_names = ggml_model_decoder->get_model_output_names();
+            std::map<std::string, size_t> model_output_indexes;
+            for (size_t i = 0; i < output_names.size(); i++) {
+                model_output_indexes[output_names[i]] = i;
+            }
+            ov::preprocess::PrePostProcessor ppp(model);
+            for (size_t i = 0; i < model->get_output_size(); i++) {
+                auto output_friendly_name = model->output(i).get_node_shared_ptr()->get_friendly_name();
+                auto output_id = model_output_indexes[output_friendly_name];
+                auto model_output_shape = model->output(i).get_partial_shape();
+                auto decoder_output_shape = ggml_model_decoder->get_output_shape(output_id);
+                if (model_output_shape.rank().is_static() && decoder_output_shape.rank().is_static()
+                    && model_output_shape.rank().get_length() + 1 == decoder_output_shape.rank().get_length()
+                    && decoder_output_shape[0].is_static() && decoder_output_shape[0].get_length() == 1) {
+                    ppp.output(i).postprocess().custom([](const ov::Output<ov::Node>& node) {
+                        auto axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {0});
+                        return std::make_shared<ov::op::v0::Unsqueeze>(node, axes);
+                    });
+                }
+            }
+            model = ppp.build();
+        }
+    }
+    return model;
+}
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/translate_session.h b/ggml/src/ggml-openvino/openvino/translate_session.h
new file mode 100644
index 00000000..56a14ae7
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/translate_session.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "input_model.h"
+#include "node_context.h"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+class TranslateSession {
+public:
+    TranslateSession(const frontend::InputModel::Ptr& input_model,
+                     const std::unordered_map<std::string, CreatorFunction>& translator_map, bool naive = false);
+
+    std::shared_ptr<ov::Model> get_converted_model();
+    std::shared_ptr<ov::Model> translate_graph(const frontend::InputModel::Ptr& input_model);
+
+private:
+    std::shared_ptr<ov::Model> apply_transformations(std::shared_ptr<ov::Model> model);
+    const frontend::InputModel::Ptr m_input_model;
+    const std::unordered_map<std::string, CreatorFunction>& m_translator_map;
+    std::shared_ptr<ov::Model> m_ov_model;
+    bool m_naive;
+};
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp
new file mode 100644
index 00000000..65356a51
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/utils.cpp
@@ -0,0 +1,226 @@
+#include "utils.h"
+
+#include "ggml-impl.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+std::string getCurrentTime() {
+    std::time_t now = std::time(nullptr);
+    char buf[100];
+    std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now));
+    return buf;
+}
+
+void num_inputs_check(const NodeContext & context, size_t min_inputs, size_t max_inputs) {
+    auto input_size = context.get_input_size();
+    FRONT_END_OP_CONVERSION_CHECK(input_size >= min_inputs, "Got less inputs than expected");
+    FRONT_END_OP_CONVERSION_CHECK(input_size <= max_inputs, "Got more inputs than expected");
+}
+
+int non_cont_dim(std::vector ne, std::vector nb) {
+    int dim = nb.size() - 1;
+    size_t bytes = nb[dim];
+    for (int i = dim; i > 0; i--) {
+        bytes *= ne[i];
+        if (bytes != nb[i - 1]) {
+            return i;
+        }
+    }
+    return 0;
+}
+
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf> & shape,
+                                         const std::vector<int> & dims) {
+    using namespace ov::op;
+    const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
+    const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
+    return std::make_shared<v8::Gather>(shape, dims_const, zero);
+}
+
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node> & node, const std::vector<int> & dims) {
+    return get_dimensions(std::make_shared<ov::op::v3::ShapeOf>(node), dims);
+}
+
+OutputVector rename_outputs_with_suffix(const OutputVector & outputs, const std::string & suffix) {
+    for (const auto & output : outputs) {
+        auto node = output.get_node_shared_ptr();
+        std::string name = node->get_friendly_name();
+        name += "_";
+        name += suffix;
+        node->set_friendly_name(name);
+        // std::cout << name << "  " << output.get_partial_shape() << std::endl;
+    }
+    return outputs;
+}
+
+namespace {
+ov::Output<ov::Node> rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], float ext_factor) {
+    int half_n_dims = n_dims / 2;
+    std::vector<float> dim_ids_vec(half_n_dims);
+    std::iota(dim_ids_vec.begin(), dim_ids_vec.end(), 0);
+    auto dim_ids = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, (size_t) half_n_dims}, dim_ids_vec);
+    auto corr_low = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {corr_dims[0]});
+    auto corr_high = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {corr_dims[1]});
+    auto denom = std::make_shared<ov::op::v1::Maximum>(
+        std::make_shared<ov::op::v1::Subtract>(corr_high, corr_low),
+        ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {0.001f}));
+    auto ramp_y =
+        std::make_shared<ov::op::v1::Divide>(std::make_shared<ov::op::v1::Subtract>(dim_ids, corr_low), denom);
+    auto ramp_clamped = std::make_shared<ov::op::v0::Clamp>(ramp_y, 0.0f, 1.0f);
+    auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor});
+    auto ramp_mix = std::make_shared<ov::op::v1::Multiply>(ramp_clamped, ext_factor_node);
+    return ramp_mix;
+}
+
+float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
+#ifndef M_PI
+#    define M_PI 3.14159265358979323846
+#endif
+    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(int n_dims,
+                              int n_ctx_orig,
+                              float freq_base,
+                              float beta_fast,
+                              float beta_slow,
+                              float dims[2]) {
+    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
+    float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
+    dims[0] = std::max(0.0f, start);
+    dims[1] = std::min(static_cast(n_dims - 1), end);
+}
+}  // namespace
+
+std::pair<ov::Output<ov::Node>, ov::Output<ov::Node>> make_sin_cos(int32_t * rope_params,
+                                                                   std::shared_ptr<ov::Node> inp_pos,
+                                                                   std::shared_ptr<ov::Node> rope_freqs_weight,
+                                                                   bool stateful) {
+    if (stateful) {
+        inp_pos = std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
+        inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
+        auto pos_perm =
+            std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{2, 1, 0});
+        inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
+    } else {
+        inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
+        auto pos_perm =
+            std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
+        inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
+    }
+
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+    const int n_dims = rope_params[1];
+    const int n_ctx_orig = rope_params[4];
+    memcpy(&freq_base, rope_params + 5, sizeof(float));
+    memcpy(&freq_scale, rope_params + 6, sizeof(float));
+    memcpy(&ext_factor, rope_params + 7, sizeof(float));
+    memcpy(&attn_factor, rope_params + 8, sizeof(float));
+    memcpy(&beta_fast, rope_params + 9, sizeof(float));
+    memcpy(&beta_slow, rope_params + 10, sizeof(float));
+
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    std::vector factor(n_dims / 2);
+    factor[0] = 1.0f;
+    for (size_t i = 1; i < factor.size(); i++) {
+        factor[i] = theta_scale * factor[i - 1];
+    }
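+    // factor[i] = freq_base^(-2i/n_dims); multiplied by the position it gives the per-dimension
+    // rotation angle theta. freq_scale interpolates the angle (YaRN), and a non-zero ext_factor
+    // mixes the extrapolated and interpolated angles using rope_yarn_ramp_mix.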
+
+    Output<ov::Node> freq_factors;
+    if (stateful) {
+        freq_factors =
+            std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor);
+    } else {
+        freq_factors =
+            std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);
+    }
+    if (rope_freqs_weight) {
+        freq_factors = std::make_shared(freq_factors, rope_freqs_weight);
+    }
+
+    auto theta_extrap = std::make_shared<ov::op::v1::Multiply>(freq_factors, inp_pos);
+    auto theta_interp = std::make_shared<ov::op::v1::Multiply>(
+        theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale}));
+
+    Output<ov::Node> theta;
+    float mscale = attn_factor;
+    if (ext_factor == 0.0f) {
+        theta = theta_interp;
+    } else {
+        auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor);
+        Output<ov::Node> one;
+        if (stateful) {
+            one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f});
+        } else {
+            one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});
+        }
+        auto one_minus_ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_mix);
+
+        theta = std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(theta_interp, one_minus_ramp),
+                                                  std::make_shared<ov::op::v1::Multiply>(theta_extrap, ramp_mix));
+        mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale));
+    }
+
+    Output<ov::Node> cos_theta = std::make_shared<ov::op::v0::Cos>(theta);
+    Output<ov::Node> sin_theta = std::make_shared<ov::op::v0::Sin>(theta);
+
+    auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale});
+
+    cos_theta = std::make_shared<ov::op::v1::Multiply>(cos_theta, mscale_node);
+    sin_theta = std::make_shared<ov::op::v1::Multiply>(sin_theta, mscale_node);
+    return std::make_pair(sin_theta, cos_theta);
+}
+
+ov::Output process_view_input(const NodeContext & context, int input_index, int slice_len) {
+    // Only works for VIEW operations that slice at the lowest dimension
+    // If the VIEW also reshape the result, `slice_len` should be provided
+    auto input = context.get_input(input_index);
+    auto * op_params = (size_t *) context.get_input_op_params(input_index);
+    auto src1_stride = context.get_input_stride(input_index);
+
+    int64_t split_addr = op_params[0] / src1_stride[3];
+    if (slice_len == 0) {
+        slice_len = context.get_input_shape(input_index)[3].get_length();
+    }
+    int64_t slice_end = split_addr + slice_len;
+
+    auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
+    auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
+    auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+    auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3});
+    auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
+    return sliced;
+}
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/openvino/utils.h b/ggml/src/ggml-openvino/openvino/utils.h
new file mode 100644
index 00000000..88dcad4c
--- /dev/null
+++ b/ggml/src/ggml-openvino/openvino/utils.h
@@ -0,0 +1,85 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include "node_context.h"
+
+namespace ov {
+namespace frontend {
+namespace ggml {
+
+std::string getCurrentTime();
+
+void dump_ov_model(std::shared_ptr<ov::Model> model);
+
+void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs);
+
+int non_cont_dim(std::vector ne, std::vector nb);
+
+template <typename T>
+std::vector<int> argsort_descend(const std::vector<T>& v) {
+    std::vector<int> idx(v.size());
+    std::iota(idx.begin(), idx.end(), 0);
+    std::sort(idx.begin(), idx.end(), [&v](int i1, int i2) {
+        return v[i1] > v[i2];
+    });
+    return idx;
+}
+
+template <typename T>
+std::vector<T> sorted_descend(std::vector<T> v) {
+    std::sort(v.begin(), v.end(), [](T a, T b) {
+        return a > b;
+    });
+    return v;
+}
+
+template <typename T>
+bool is_permuted(const std::vector<T>& strides) {
+    for (size_t i = 0; i < strides.size() - 1; ++i) {
+        if (strides[i] < strides[i + 1]) {
+            return true;
+        }
+    }
+    return false;
+}
+
+template <typename T>
+std::vector<T> permute(const std::vector<T>& x, const std::vector<int>& perm) {
+    std::vector<T> result;
+    result.reserve(perm.size());
+    for (int i : perm) {
+        result.push_back(x[i]);
+    }
+    return result;
+}
+
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf>& shape,
+                                         const std::vector<int>& dims);
+std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims);
+
+OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::string& suffix);
+
+std::pair<ov::Output<ov::Node>, ov::Output<ov::Node>> make_sin_cos(int32_t* rope_params,
+                                                                   std::shared_ptr<ov::Node> inp_pos,
+                                                                   std::shared_ptr<ov::Node> rope_freqs_weight = nullptr,
+                                                                   bool stateful = false);
+
+ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0);
+
+namespace op {
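+// Shared helper for binary GGML ops that map 1:1 onto a single OpenVINO element-wise op,
+// e.g. translate_1to1_match_2_inputs<ov::op::v1::Add> for GGML_OP_ADD (see op_table.cpp).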
+template <typename T>
+OutputVector translate_1to1_match_2_inputs(const NodeContext& context) {
+    num_inputs_check(context, 2, 2);
+    auto res = std::make_shared<T>(context.get_input(0), context.get_input(1));
+    return rename_outputs_with_suffix({res}, context.get_name());
+}
+}  // namespace op
+
+}  // namespace ggml
+}  // namespace frontend
+}  // namespace ov
diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp
new file mode 100644
index 00000000..1b553a0d
--- /dev/null
+++ b/ggml/src/ggml-openvino/utils.cpp
@@ -0,0 +1,823 @@
+#include "utils.h"
+
+#include "ggml-impl.h"
+#include "ggml-openvino-extra.h"
+#include "ggml-openvino/ggml-decoder.h"
+#include "ggml.h"
+#include "openvino/frontend.h"
+#include "openvino/input_model.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// Suppress  deprecation warning for ov::Tensor::data()
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+
+enum ggml_status ov_graph_compute(ggml_cgraph * cgraph, ggml_backend_t backend) {
+    ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context;
+    try {
+        if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
+            std::string filename = "cgraph_ov.txt";
+            GgmlOvDecoder::dump_cgraph(cgraph, filename);
+        }
+
+        const auto is_static = ggml_openvino_is_npu();
+
+        GGML_ASSERT(ctx->runtime_context != nullptr);
+        std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context);
+
+        return is_static ? ov_graph_compute_static(cgraph, r_ctx) : ov_graph_compute_dynamic(cgraph, r_ctx);
+    } catch (const ov::Exception & e) {
+        GGML_LOG_ERROR("GGML OpenVINO backend ov::Exception: %s\n", e.what());
+        return GGML_STATUS_FAILED;
+    } catch (const std::exception & e) {
+        GGML_LOG_ERROR("GGML OpenVINO backend std::exception: %s\n", e.what());
+        return GGML_STATUS_FAILED;
+    } catch (...) {
+        GGML_LOG_ERROR("GGML OpenVINO backend unknown exception\n");
+        return GGML_STATUS_FAILED;
+    }
+}
+
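+// Wrap the destination ggml tensor's buffer in an ov::Tensor so inference writes the result in
+// place (no copy). For the static (NPU) path the shape comes from the compiled model's output;
+// otherwise it is derived from the ggml tensor itself.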
+ov::Tensor create_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
+                                   std::shared_ptr<ov::InferRequest> infer_request,
+                                   int output_index,
+                                   const ggml_tensor * ggml_tensor) {
+    auto output_type = ggml_decoder->get_ov_type(ggml_tensor);
+    ov::Shape output_shape;
+    if (ggml_decoder->is_static()) {
+        output_shape = infer_request->get_output_tensor(output_index).get_shape();
+    } else {
+        output_shape = ggml_decoder->get_shape(ggml_tensor);
+    }
+
+    ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);
+    return output_tensor;
+}
+
+enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr r_ctx) {
+    auto & core = ov_singleton_core();
+    const auto & config = ggml_openvino_get_compile_config();
+    auto device = r_ctx->device;
+    bool stateful = r_ctx->stateful;
+    static auto is_static = false;
+
+    if (is_naive(cgraph)) {
+        return naive_compute(cgraph, core, device, config);
+    }
+
+    auto start_time = ggml_time_us();
+
+    std::shared_ptr<GgmlOvDecoder> ggml_decoder;
+    std::shared_ptr<ov::InferRequest> infer_request;
+    ModelParams m_params;
+    ComputeParams c_params;
+    std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static);
+
+    graph_key key(cgraph);
+    bool cache_hit;
+
+    int64_t decoder_end_time;
+    int64_t conversion_end_time;
+    int64_t compile_end_time;
+    int64_t infer_end_time;
+
+    {
+        std::lock_guard<std::mutex> lock(r_ctx->ov_compute_mutex);
+
+        auto it = r_ctx->decoder_cache.find(key);
+
+        cache_hit = it != r_ctx->decoder_cache.end();
+        ModelParams old_m_params;
+        if (cache_hit) {
+            ggml_decoder = it->second;
+            old_m_params = ggml_decoder->get_model_params();
+            cache_hit = old_m_params.can_reuse_dynamically(m_params);
+        }
+
+        if (cache_hit) {
+            std::map> model_weights;
+            ggml_decoder->set_compute_params(c_params);
+            ggml_decoder->set_model_params(m_params);
+            if (old_m_params.kv_buffer_changed(m_params)) {
+                ggml_decoder->update_io(cgraph);
+            }
+            ggml_decoder->add_extra_inputs();
+            infer_request = r_ctx->infer_request_cache.at(key);
+
+            if (stateful) {
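+                // Stateful KV-cache bookkeeping: position 0 resets the internal states (new
+                // sequence); a position equal to the tracked kv size is a plain append; anything
+                // else (e.g. after a context rewind) trims each state tensor down to pos_data[0]
+                // entries before continuing.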
+                const auto * inp_pos = get_inp_pos_tensor(cgraph);
+                int32_t * pos_data = (int32_t *) inp_pos->data;
+                auto pos_shape = ggml_decoder->get_shape(inp_pos);
+                if (pos_data[0] == 0) {
+                    infer_request->reset_state();
+                    r_ctx->stateful_kv_size = pos_shape[3];
+                } else if (r_ctx->stateful_kv_size == static_cast(pos_data[0])) {
+                    r_ctx->stateful_kv_size += pos_shape[3];
+                } else {
+                    auto states = infer_request->query_state();
+                    for (auto state : states) {
+                        auto state_tensor = state.get_state();
+                        auto state_tensor_shape = state_tensor.get_shape();
+                        if (static_cast<size_t>(pos_data[0]) > r_ctx->stateful_kv_size) {
+                            std::string state_name;
+                            try {
+                                state_name = r_ctx->kv_state_input_name_map.at(state.get_name());
+                            } catch (...) {
+                                GGML_LOG_ERROR("GGML OpenVINO backend stateful inference failed: no input found for the state\n");
+                                return GGML_STATUS_FAILED;
+                            }
+                            auto kv_tensor = get_ov_input_tensor(ggml_decoder, state_name);
+                            kv_tensor.set_shape({state_tensor_shape[0], kv_tensor.get_shape()[2],
+                                                 state_tensor_shape[2], state_tensor_shape[3]});
+                            state_tensor = kv_tensor;
+                            state_tensor_shape = state_tensor.get_shape();
+                        }
+                        ov::Coordinate begin = {0, 0, 0, 0};
+                        ov::Coordinate end = {state_tensor_shape[0], static_cast<size_t>(pos_data[0]),
+                                              state_tensor_shape[2], state_tensor_shape[3]};
+                        ov::Tensor new_state_tensor(state_tensor, begin, end);
+                        state.set_state(new_state_tensor);
+                    }
+                    r_ctx->stateful_kv_size = pos_data[0] + 1;
+                }
+            }
+
+            decoder_end_time = ggml_time_us();
+            conversion_end_time = decoder_end_time;
+            compile_end_time = decoder_end_time;
+        } else {
+            r_ctx->infer_request_cache.erase(key);
+
+            std::shared_ptr<ov::Model> model;
+            auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
+
+            ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static, stateful);
+            decoder_end_time = ggml_time_us();
+
+            auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
+            model = ov::frontend::ggml::FrontEnd::convert(input_model);
+            ggml_decoder->clear_model_weights();
+            conversion_end_time = ggml_time_us();
+
+            if (getenv("GGML_OPENVINO_DUMP_IR")) {
+                char timestamped_filename[64];
+                auto timestamp = (long long) ggml_time_us();
+                snprintf(timestamped_filename, sizeof(timestamped_filename), "model_%lld.xml", timestamp);
+                ov::serialize(model, timestamped_filename);
+            }
+
+            ov::CompiledModel compiled_model;
+            auto remote_context = ggml_openvino_get_remote_context();
+            if (remote_context.has_value()) {
+                compiled_model = core.compile_model(model, remote_context.value(), config);
+            } else {
+                compiled_model = core.compile_model(model, device, config);
+            }
+            compile_end_time = ggml_time_us();
+            infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request());
+            r_ctx->infer_request_cache[key] = infer_request;
+            r_ctx->decoder_cache[key] = ggml_decoder;
+
+            std::vector<std::string> ov_input_names;
+            std::vector<std::string> ov_output_names;
+            for (const auto & ov_param : model->get_parameters()) {
+                ov_input_names.push_back(ov_param->get_friendly_name());
+            }
+            for (const auto & ov_output : model->get_results()) {
+                ov_output_names.push_back(ov_output->get_friendly_name());
+            }
+            r_ctx->ov_input_names_cache[key] = std::move(ov_input_names);
+            r_ctx->ov_output_names_cache[key] = std::move(ov_output_names);
+
+            if (stateful) {
+                const auto * inp_pos = get_inp_pos_tensor(cgraph);
+                auto pos_shape = ggml_decoder->get_shape(inp_pos);
+                r_ctx->stateful_kv_size = pos_shape[3];
+                const auto kv_param_res_names = ggml_decoder->get_kv_param_res_names();
+                for (const auto & pair : kv_param_res_names) {
+                    r_ctx->kv_state_input_name_map[pair.first + pair.second] = pair.first;
+                }
+            }
+        }
+
+        auto ov_input_names = r_ctx->ov_input_names_cache[key];
+        auto ov_output_names = r_ctx->ov_output_names_cache[key];
+
+        for (size_t i = 0; i < ov_input_names.size(); i++) {
+            auto param_name = ov_input_names[i];
+            auto input_tensor = get_ov_input_tensor(ggml_decoder, param_name);
+            infer_request->set_input_tensor(i, input_tensor);
+
+            if (getenv("GGML_OPENVINO_DEBUG_INPUT")) {
+                print_input_tensor_info(param_name, input_tensor);
+            }
+        }
+
+        for (size_t i = 0; i < ov_output_names.size(); i++) {
+            auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]);
+            auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor);
+            infer_request->set_output_tensor(i, output_tensor);
+        }
+
+        infer_request->infer();
+        infer_end_time = ggml_time_us();
+
+        if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) {
+            for (size_t i = 0; i < ov_output_names.size(); i++) {
+                const auto output_tensor = infer_request->get_output_tensor(i);
+                print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data());
+            }
+        }
+
+        if (getenv("GGML_OPENVINO_PROFILING")) {
+            GGML_LOG_INFO("\nGGML OpenVINO Backend: \n");
+            GGML_LOG_INFO("  - Graph decoder time: %ld ms \n", (decoder_end_time - start_time) / 1000);
+            if (!cache_hit) {
+                GGML_LOG_INFO("  - Graph conversion time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000);
+                GGML_LOG_INFO("  - Graph compile time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000);
+            }
+            GGML_LOG_INFO("  - Graph inference time: %ld ms \n", (infer_end_time - compile_end_time) / 1000);
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx) {
+    auto & core = ov_singleton_core();
+
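+    // Static path for NPU-like devices that need fixed shapes: the prompt is processed in
+    // fixed-size chunks (GGML_OPENVINO_PREFILL_CHUNK_SIZE, default 256) by a dedicated prefill
+    // model, while a separate single-token model is compiled for decoding.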
+    auto get_prefill_chunk_size = [] {
+        const char * chunk_size_str = getenv("GGML_OPENVINO_PREFILL_CHUNK_SIZE");
+        if (chunk_size_str && atoi(chunk_size_str) > 0) {
+            return atoi(chunk_size_str);
+        }
+        return 256;
+    };
+
+    static std::string device = "NPU";
+    static auto is_static = true;
+    static auto stateful = false;
+    static auto prefill_chunk_size = get_prefill_chunk_size();
+    const auto & config = ggml_openvino_get_compile_config();
+
+    if (is_naive(cgraph)) {
+        return naive_compute(cgraph, core, device, config);
+    }
+
+    auto start_time = ggml_time_us();
+
+    std::shared_ptr<GgmlOvDecoder> ggml_decoder;
+    std::shared_ptr<ov::InferRequest> infer_request;
+    ModelParams m_params;
+    ComputeParams c_params;
+    std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static);
+
+    const auto * inp_pos = get_inp_pos_tensor(cgraph);
+    const auto is_prefill = get_is_prefill(inp_pos);
+    graph_key key(cgraph);
+    bool cache_hit;
+
+    int64_t decoder_end_time;
+    int64_t conversion_end_time;
+    int64_t compile_end_time;
+    int64_t infer_end_time;
+
+    auto it = r_ctx->decoder_cache.find(key);
+
+    cache_hit = it != r_ctx->decoder_cache.end();
+    ModelParams old_m_params;
+    if (cache_hit) {
+        ggml_decoder = it->second;
+        old_m_params = ggml_decoder->get_model_params();
+        cache_hit = old_m_params.can_reuse_statically(m_params);
+    }
+
+    if (cache_hit) {
+        std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
+        ggml_decoder->m_is_prefill = is_prefill;
+        ggml_decoder->set_model_params(m_params);
+        ggml_decoder->set_compute_params(c_params);
+        if (old_m_params.kv_buffer_changed(m_params)) {
+            ggml_decoder->update_io(cgraph);
+        }
+        ggml_decoder->add_extra_inputs();
+        infer_request = is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key);
+
+        decoder_end_time = ggml_time_us();
+        conversion_end_time = decoder_end_time;
+        compile_end_time = decoder_end_time;
+    } else {
+        r_ctx->infer_request_cache.erase(key);
+        r_ctx->infer_request_cache_prefill.erase(key);
+
+        std::shared_ptr<ov::Model> model;
+        auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
+
+        auto ggml_decoder_prefill = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,
+                                                                    is_static, stateful, true, prefill_chunk_size);
+        auto ggml_decoder_decode = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static,
+                                                                   stateful, false, prefill_chunk_size);
+        decoder_end_time = ggml_time_us();
+
+        auto input_model_prefill = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_prefill);
+        auto input_model_decode = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_decode);
+
+        auto model_prefill = ov::frontend::ggml::FrontEnd::convert(input_model_prefill);
+        ggml_decoder_prefill->clear_model_weights();
+        auto model_decode = ov::frontend::ggml::FrontEnd::convert(input_model_decode);
+        ggml_decoder_decode->clear_model_weights();
+        conversion_end_time = ggml_time_us();
+
+        if (getenv("GGML_OPENVINO_DUMP_IR")) {
+            char timestamped_filename[64];
+            auto timestamp = (long long) ggml_time_us();
+            snprintf(timestamped_filename, sizeof(timestamped_filename), "model_prefill_%lld.xml", timestamp);
+            ov::serialize(model_prefill, timestamped_filename);
+            snprintf(timestamped_filename, sizeof(timestamped_filename), "model_decode_%lld.xml", timestamp);
+            ov::serialize(model_decode, timestamped_filename);
+        }
+
+        ov::CompiledModel compiled_model_prefill;
+        ov::CompiledModel compiled_model_decode;
+        auto remote_context = ggml_openvino_get_remote_context();
+        if (remote_context.has_value()) {
+            compiled_model_prefill = core.compile_model(model_prefill, remote_context.value(), config);
+            compiled_model_decode = core.compile_model(model_decode, remote_context.value(), config);
+        } else {
+            compiled_model_prefill = core.compile_model(model_prefill, device, config);
+            compiled_model_decode = core.compile_model(model_decode, device, config);
+        }
+
+        r_ctx->infer_request_cache_prefill[key] =
+            std::make_shared<ov::InferRequest>(compiled_model_prefill.create_infer_request());
+        r_ctx->infer_request_cache[key] =
+            std::make_shared<ov::InferRequest>(compiled_model_decode.create_infer_request());
+        compile_end_time = ggml_time_us();
+
+        model = is_prefill ? model_prefill : model_decode;
+        ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode;
+        infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key];
+        r_ctx->decoder_cache[key] = ggml_decoder;
+
+        std::vector<std::string> ov_input_names;
+        std::vector<std::string> ov_output_names;
+        for (const auto & ov_param : model->get_parameters()) {
+            ov_input_names.push_back(ov_param->get_friendly_name());
+        }
+        for (const auto & ov_output : model->get_results()) {
+            ov_output_names.push_back(ov_output->get_friendly_name());
+        }
+        r_ctx->ov_input_names_cache[key] = std::move(ov_input_names);
+        r_ctx->ov_output_names_cache[key] = std::move(ov_output_names);
+    }
+
+    auto ov_input_names = r_ctx->ov_input_names_cache[key];
+    auto ov_output_names = r_ctx->ov_output_names_cache[key];
+
+    if (is_prefill) {
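+        // Prompt processing: run the prefill model once per chunk; the last chunk is padded up to
+        // prefill_chunk_size inside get_ov_input_tensor_static_prefill.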
+        auto inp_len = inp_pos->ne[0];
+        for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) {
+            for (size_t i = 0; i < ov_input_names.size(); i++) {
+                auto param_name = ov_input_names[i];
+                auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index);
+                infer_request->set_input_tensor(i, input_tensor);
+
+                if (getenv("GGML_OPENVINO_DEBUG_INPUT")) {
+                    const auto input_tensor = infer_request->get_input_tensor(i);
+                    print_input_tensor_info(param_name, input_tensor);
+                }
+            }
+
+            for (size_t i = 0; i < ov_output_names.size(); i++) {
+                auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]);
+                auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor);
+                infer_request->set_output_tensor(i, output_tensor);
+            }
+
+            infer_request->infer();
+
+            if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) {
+                for (size_t i = 0; i < ov_output_names.size(); i++) {
+                    const auto output_tensor = infer_request->get_output_tensor(i);
+                    print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data());
+                }
+            }
+        }
+        infer_end_time = ggml_time_us();
+    } else {
+        for (size_t i = 0; i < ov_input_names.size(); i++) {
+            auto param_name = ov_input_names[i];
+            auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name);
+            infer_request->set_input_tensor(i, input_tensor);
+
+            if (getenv("GGML_OPENVINO_DEBUG_INPUT")) {
+                const auto input_tensor = infer_request->get_input_tensor(i);
+                print_input_tensor_info(param_name, input_tensor);
+            }
+        }
+
+        for (size_t i = 0; i < ov_output_names.size(); i++) {
+            auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]);
+            auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor);
+            infer_request->set_output_tensor(i, output_tensor);
+        }
+
+        infer_request->infer();
+        infer_end_time = ggml_time_us();
+
+        if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) {
+            for (size_t i = 0; i < ov_output_names.size(); i++) {
+                const auto output_tensor = infer_request->get_output_tensor(i);
+                print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data());
+            }
+        }
+    }
+
+    if (getenv("GGML_OPENVINO_PROFILING")) {
+        GGML_LOG_INFO("\nGGML OpenVINO Backend: \n");
+        GGML_LOG_INFO("  - Graph decoder time: %ld ms \n", (decoder_end_time - start_time) / 1000);
+        if (!cache_hit) {
+            GGML_LOG_INFO("  - Graph conversion time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000);
+            GGML_LOG_INFO("  - Graph compile time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000);
+        }
+        GGML_LOG_INFO("  - Graph inference time: %ld ms \n", (infer_end_time - compile_end_time) / 1000);
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
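+// Heuristic: graphs with fewer than 20 real ops (e.g. the single-op graphs used by
+// test-backend-ops) skip the LLM caching machinery and are converted and run directly.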
+bool is_naive(ggml_cgraph * cgraph) {
+    constexpr int naive_graph_size_threshold = 20;
+    int count = 0;
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (cgraph->nodes[i]->op != GGML_OP_NONE) {
+            count++;
+        }
+    }
+    return count < naive_graph_size_threshold;
+}
+
+enum ggml_status naive_compute(ggml_cgraph * cgraph,
+                               ov::Core & core,
+                               const std::string & device,
+                               const ov::AnyMap & config) {
+    if (cgraph->n_nodes == 1 && (cgraph->nodes[0]->op == GGML_OP_NONE || cgraph->nodes[0]->op == GGML_OP_VIEW)) {
+        return GGML_STATUS_SUCCESS;
+    }
+
+    bool naive = true;
+    auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, naive);
+    auto decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights);
+    auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(decoder);
+    auto model = ov::frontend::ggml::FrontEnd::convert(input_model, naive);
+    if (getenv("GGML_OPENVINO_DUMP_IR")) {
+        ov::serialize(model, "IR_naive.xml");
+    }
+
+    std::shared_ptr<ov::InferRequest> infer_request;
+    auto remote_context = ggml_openvino_get_remote_context();
+    if (cgraph->nodes[0]->op == GGML_OP_MUL_MAT) {
+        // TODO ACCURACY hint triggers a bug in GPU plugin/driver on Lunar Lake. Remove once CVS-182166 is resolved
+        core.set_property(device, ov::hint::execution_mode(ov::hint::ExecutionMode::PERFORMANCE));
+    } else {
+        core.set_property(device, ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY));
+    }
+    if (remote_context.has_value()) {
+        infer_request = std::make_shared<ov::InferRequest>(
+            core.compile_model(model, remote_context.value(), config).create_infer_request());
+    } else {
+        infer_request =
+            std::make_shared<ov::InferRequest>(core.compile_model(model, device, config).create_infer_request());
+    }
+
+    auto ov_params = model->get_parameters();
+    for (size_t i = 0; i < ov_params.size(); i++) {
+        auto param_name = ov_params[i]->get_friendly_name();
+        auto input_tensor = get_ov_input_tensor(decoder, param_name);
+        infer_request->set_input_tensor(i, input_tensor);
+    }
+
+    auto ov_results = model->get_results();
+    for (size_t i = 0; i < ov_results.size(); i++) {
+        auto * ggml_tensor = decoder->get_model_outputs().at(ov_results[i]->get_friendly_name());
+        auto output_tensor = create_ov_output_tensor(decoder, infer_request, i, ggml_tensor);
+        infer_request->set_output_tensor(i, output_tensor);
+    }
+
+    infer_request->infer();
+    return GGML_STATUS_SUCCESS;
+}
+
+namespace {
+ov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & name) {
+    const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(name);
+
+    if (ggml_tensor->extra != nullptr) {
+        // GGML_LOG_DEBUG("Using ggml_tensor->extra as ov::Tensor for input: %s\n", name.c_str());
+        auto * extra_base = static_cast<ggml_openvino_extra_base *>(ggml_tensor->extra);
+        if (extra_base->type != ggml_openvino_extra_base::Type::TENSOR) {
+            throw std::runtime_error("ggml tensor extra is not of type TENSOR for input: " + name);
+        }
+        auto * tensor_extra = static_cast(extra_base);
+        return *tensor_extra->tensor;
+    }
+
+    // GGML_LOG_DEBUG("Converting ggml tensor to ov::Tensor for input: %s\n", name.c_str());
+    auto * input_data = ggml_tensor->data;
+    ov::Shape input_shape;
+    if (ggml_tensor->op == GGML_OP_VIEW) {
+        // This case is added to make test-backend-ops work
+        input_shape = ggml_decoder->get_shape(ggml_tensor->view_src);
+    } else {
+        input_shape = ggml_decoder->get_shape(ggml_tensor);
+    }
+    auto input_tensor = ov::Tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape, input_data);
+    return input_tensor;
+}
+}  // namespace
+
+ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & param_name) {
+    ov::Tensor input_tensor;
+    if (ggml_decoder->get_model_extra_inputs().find(param_name) != ggml_decoder->get_model_extra_inputs().end()) {
+        input_tensor = *ggml_decoder->get_model_extra_input_values().at(param_name);
+    } else {
+        input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
+    }
+    return input_tensor;
+}
+
+ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
+                                             const std::string & param_name) {
+    // NPU decoding stage
+    const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
+    const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
+
+    if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) ||
+        GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) {
+        assert(ggml_tensor->ne[0] == 1);
+        ov::Shape input_shape = {1, 1, 1, 1};
+        ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape);
+        if (ggml_tensor->type == GGML_TYPE_I32) {
+            *input_tensor.data<int32_t>() = *((int32_t *) ggml_tensor->data);
+        } else if (ggml_tensor->type == GGML_TYPE_I64) {
+            *input_tensor.data<int64_t>() = *((int64_t *) ggml_tensor->data);
+        } else {
+            throw std::runtime_error("Unexpected tensor type for " + param_name);
+        }
+        return input_tensor;
+    }
+
+    if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) {
+        ov::Shape input_shape = {1, 1, 1, 1};
+        ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape);
+        int32_t inp_out_id = *((int32_t *) ggml_tensor->data);
+        assert(ggml_tensor->ne[0] == 1);
+        assert(inp_out_id == 0);
+        *input_tensor.data<int32_t>() = inp_out_id;
+        return input_tensor;
+    }
+
+    if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) {
+        size_t context_size = ggml_decoder->get_ctx_size();
+        std::vector<float> padded_data = pad_input<float>(ggml_tensor, 1, context_size, -INFINITY);
+        ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, 1, context_size});
+        auto * data_ptr = input_tensor.data<float>();
+        std::copy(padded_data.begin(), padded_data.begin() + context_size, data_ptr);
+        return input_tensor;
+    }
+
+    return get_ov_input_tensor(ggml_decoder, param_name);
+}
+
+ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr ggml_decoder,
+                                              const std::string & param_name,
+                                              int chunk_index) {
+    // NPU prompt processing stage
+    const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
+    const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
+
+    const size_t input_len = ggml_decoder->get_input_len();
+    const size_t chunk_size = ggml_decoder->m_prefill_chunk_size;
+    const size_t chunk_valid_size = std::min(chunk_size, input_len - chunk_index * chunk_size);
+    const size_t chunk_pad_size = chunk_size - chunk_valid_size;
+
+    if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) ||
+        GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) {
+        ov::Shape input_shape = {1, 1, 1, chunk_size};
+        ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape);
+        // copy the chunk_index-th chunk from ggml_tensor
+        size_t element_size = ggml_type_size(ggml_tensor->type);
+        void * input_data = (char *) ggml_tensor->data + chunk_index * chunk_size * element_size;
+        std::memcpy(input_tensor.data(), input_data, chunk_valid_size * element_size);
+        // pad the rest with last_value + 1, so that the KVs of the padded positions are written to
+        // the rows after the valid rows in the KV cache
+        if (chunk_pad_size > 0) {
+            if (ggml_tensor->type == GGML_TYPE_I32) {
+                int32_t last_value =
+                    *((int32_t *) ggml_tensor->data + (chunk_index * chunk_size + chunk_valid_size - 1));
+                int32_t * output_data = input_tensor.data<int32_t>();
+                std::fill(output_data + chunk_valid_size, output_data + chunk_size, last_value + 1);
+            } else if (ggml_tensor->type == GGML_TYPE_I64) {
+                int64_t last_value =
+                    *((int64_t *) ggml_tensor->data + (chunk_index * chunk_size + chunk_valid_size - 1));
+                int64_t * output_data = input_tensor.data<int64_t>();
+                std::fill(output_data + chunk_valid_size, output_data + chunk_size, last_value + 1);
+            } else {
+                throw std::runtime_error("Unexpected tensor type for " + param_name);
+            }
+        }
+        return input_tensor;
+    }
+
+    if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) {
+        size_t output_len = ggml_decoder->get_compute_params().output_len;
+        ov::Shape input_shape = {1, 1, 1, output_len};
+        ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape);
+        if (ggml_tensor->ne[0] == 0) {
+            *input_tensor.data<int32_t>() = 0;
+        } else {
+            auto * data_addr = input_tensor.data<int32_t>();
+            for (size_t i = 0; i < output_len; i++) {
+                data_addr[i] = ((int32_t *) ggml_tensor->data)[i] % chunk_size;
+            }
+        }
+        return input_tensor;
+    }
+
+    if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) {
+        size_t cols = ggml_tensor->ne[0];
+        size_t rows = ggml_tensor->ne[1];
+        float * ggml_data = (float *) ggml_tensor->data + chunk_index * chunk_size * cols;
+        size_t chunk_valid_rows = std::min(chunk_size, rows - chunk_index * chunk_size);
+        size_t context_size = ggml_decoder->get_ctx_size();
+        std::vector<float> padded_data =
+            pad_input<float>(ggml_data, chunk_valid_rows, cols, chunk_size, context_size, -INFINITY);
+        set_zero_diagonal(padded_data, chunk_size, context_size);
+        ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, chunk_size, context_size});
+        auto * data_ptr = input_tensor.data<float>();
+        std::copy(padded_data.begin(), padded_data.begin() + chunk_size * context_size, data_ptr);
+        return input_tensor;
+    }
+
+    return get_ov_input_tensor(ggml_decoder, param_name);
+}
+
+size_t checksum(const void * data, size_t size) {
+    const uint8_t * bytes = static_cast<const uint8_t *>(data);
+    size_t sum = 0;
+    for (size_t i = 0; i < size; ++i) {
+        sum += (uint8_t) i;
+        sum += bytes[i];
+    }
+    return sum;
+}
+
+void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor) {
+    std::cout << "Input name: " << name << ", Input shape: " << tensor.get_shape() << ", Address: " << tensor.data()
+              << std::endl;
+    switch (tensor.get_element_type()) {
+    case ov::element::f32: {
+        if (name.find("self_kq_mask") == std::string::npos) {
+            std::cout << *(tensor.data<float>()) << std::endl;
+        } else {
+            size_t rows = tensor.get_shape()[2];
+            size_t cols = tensor.get_shape()[3];
+            auto * data = tensor.data<float>();
+            for (size_t i = 0; i < rows; ++i) {
+                for (size_t j = 0; j < cols; ++j) {
+                    float val = data[i * cols + j];
+                    if (std::isinf(val) && val < 0) {
+                        std::cout << std::setw(5) << "-inf";
+                    } else {
+                        std::cout << std::setw(5) << val;
+                    }
+                }
+                std::cout << std::endl;
+            }
+        }
+
+        break;
+    }
+    case ov::element::f16:
+        std::cout << *(tensor.data<ov::float16>()) << std::endl;
+        break;
+    case ov::element::i32:
+        for (size_t i = 0; i < tensor.get_size(); ++i) {
+            std::cout << tensor.data()[i] << " ";
+        }
+        std::cout << std::endl;
+        break;
+    case ov::element::i64:
+        for (size_t i = 0; i < tensor.get_size(); ++i) {
+            std::cout << tensor.data()[i] << " ";
+        }
+        std::cout << std::endl;
+        break;
+    default:
+        break;
+    }
+}
+
+void print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, const void * output_dst) {
+    std::cout << "Output name: " << name << ", Output shape: " << tensor.get_shape() << ", Address: " << output_dst
+              << std::endl;
+
+    auto print_float_stats = [](const std::string & type_name, size_t size, auto get_value) {
+        if (size == 0) {
+            return;
+        }
+
+        float first = get_value(0);
+        float min = first;
+        float max = first;
+        double sum = first;
+
+        for (size_t i = 1; i < size; ++i) {
+            float v = get_value(i);
+            if (v < min) {
+                min = v;
+            }
+            if (v > max) {
+                max = v;
+            }
+            sum += v;
+        }
+        double mean = sum / size;
+
+        std::cout << std::right << std::setw(6) << type_name << std::right << std::setw(12) << "First" << std::setw(12)
+                  << "Min" << std::setw(12) << "Max" << std::setw(12) << "Mean" << std::endl;
+        std::cout << std::right << std::setw(6) << "" << std::right << std::setw(12) << first << std::setw(12) << min
+                  << std::setw(12) << max << std::setw(12) << mean << std::endl;
+    };
+
+    switch (tensor.get_element_type()) {
+    case ov::element::f32: {
+        const float * data = tensor.data<float>();
+        size_t size = tensor.get_size();
+        print_float_stats("[f32]", size, [data](size_t i) { return data[i]; });
+        break;
+    }
+    case ov::element::f16: {
+        const ov::float16 * data = tensor.data<ov::float16>();
+        size_t size = tensor.get_size();
+        print_float_stats("[f16]", size, [data](size_t i) { return static_cast<float>(data[i]); });
+        break;
+    }
+    default:
+        break;
+    }
+}
+
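+// Zero one entry per row of the padded mask so no row is entirely -inf; a fully masked row
+// would otherwise turn the attention softmax into NaNs for the padded positions.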
+void set_zero_diagonal(std::vector<float> & matrix, size_t rows, size_t cols) {
+    for (size_t i = 0; i < rows; ++i) {
+        size_t diag_col = std::min(i, cols - 1);
+        matrix[i * cols + diag_col] = 0.0f;
+    }
+}
+
+const ggml_tensor * get_inp_pos_tensor(ggml_cgraph * cgraph) {
+    for (int i = 0; i < cgraph->n_nodes; ++i) {
+        auto * op = cgraph->nodes[i];
+        for (int j = 0; j < GGML_MAX_SRC; ++j) {
+            auto * src = op->src[j];
+            if (src == nullptr) {
+                break;
+            }
+            if (GgmlOvDecoder::is_inp_pos(src, op)) {
+                return src;
+            }
+        }
+    }
+    GGML_LOG_ERROR("get_inp_pos_tensor: inp_pos not found in cgraph\n");
+    throw std::runtime_error("get_inp_pos_tensor: inp_pos not found in cgraph");
+}
+
+bool get_is_prefill(const ggml_tensor * inp_pos) {
+    return inp_pos->ne[0] > 1;
+}
+
+#pragma GCC diagnostic pop
diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h
new file mode 100644
index 00000000..656573d1
--- /dev/null
+++ b/ggml/src/ggml-openvino/utils.h
@@ -0,0 +1,123 @@
+#include "ggml-backend-impl.h"
+#include "ggml-decoder.h"
+#include "ggml-impl.h"
+
+#include <cstddef>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include <openvino/openvino.hpp>
+
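+// Cheap identity for a ggml graph: node count plus first and last node names. Used as the key
+// for caching converted models, infer requests and their input/output names between calls.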
+struct graph_key {
+    int n_nodes;
+    std::string first_node_name;
+    std::string last_node_name;
+
+    graph_key(const ggml_cgraph * cgraph) : n_nodes(cgraph->n_nodes) {
+        if (n_nodes > 0) {
+            first_node_name = cgraph->nodes[0]->name;
+            last_node_name = cgraph->nodes[n_nodes - 1]->name;
+        }
+    }
+
+    bool operator==(const graph_key & other) const {
+        return n_nodes == other.n_nodes && first_node_name == other.first_node_name &&
+               last_node_name == other.last_node_name;
+    }
+};
+
+struct graph_key_hash {
+    size_t operator()(const graph_key & key) const {
+        size_t h = std::hash<int>{}(key.n_nodes);
+        if (key.n_nodes > 0) {
+            h ^= std::hash<std::string>{}(key.first_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2);
+            h ^= std::hash<std::string>{}(key.last_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2);
+        }
+        return h;
+    }
+};
+
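+// Per-backend runtime state: target device, stateful flag and, keyed by graph_key, caches of
+// decoders, infer requests and the parameter/result names of the compiled models.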
+struct ov_runtime_context {
+    std::mutex ov_compute_mutex;
+    std::string device;
+    bool stateful;
+    std::unordered_map<graph_key, std::shared_ptr<GgmlOvDecoder>, graph_key_hash> decoder_cache;
+    std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache;
+    std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill;
+    std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache;
+    std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_output_names_cache;
+    // TODO: Stateful mode currently supports only a single request at a time;
+    //       simultaneous stateful inference requests are not yet supported.
+    size_t stateful_kv_size;
+    std::map<std::string, std::string> kv_state_input_name_map;
+
+    ov_runtime_context() :
+        device("CPU"),
+        stateful(false),
+        stateful_kv_size(0) {}
+};
+
+enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend);
+
+enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx);
+enum ggml_status ov_graph_compute_static(struct ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx);
+
+size_t checksum(const void * data, size_t size);
+
+void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor);
+
+void print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, const void * output_dst);
+
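+// Copy a rows x cols matrix into a padded_rows x padded_cols buffer pre-filled with pad_value.
+// Used to pad attention mask inputs up to the fixed shapes required by the static (NPU) path.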
+template <typename T>
+std::vector<T> pad_input(const T * data,
+                         size_t rows,
+                         size_t cols,
+                         size_t padded_rows,
+                         size_t padded_cols,
+                         T pad_value) {
+    std::vector<T> padded(padded_rows * padded_cols, pad_value);
+
+    for (size_t i = 0; i < std::min(rows, padded_rows); ++i) {
+        for (size_t j = 0; j < std::min(cols, padded_cols); ++j) {
+            padded[i * padded_cols + j] = data[i * cols + j];
+        }
+    }
+
+    return padded;
+}
+
+template <typename T>
+std::vector<T> pad_input(const ggml_tensor * tensor, size_t padded_rows, size_t padded_cols, T pad_value) {
+    return pad_input(reinterpret_cast<const T *>(tensor->data),
+                        static_cast<size_t>(tensor->ne[1]),  // rows
+                        static_cast<size_t>(tensor->ne[0]),  // cols
+                        padded_rows, padded_cols, pad_value);
+}
+
+void set_zero_diagonal(std::vector<float> & matrix, size_t rows, size_t cols);
+
+const ggml_tensor * get_inp_pos_tensor(struct ggml_cgraph * cgraph);
+
+bool get_is_prefill(const ggml_tensor * inp_pos);
+
+ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & param_name);
+ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
+                                             const std::string & param_name);
+ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
+                                              const std::string & param_name,
+                                              int chunk_index);
+
+ov::Tensor create_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
+                                   std::shared_ptr<ov::InferRequest> infer_request,
+                                   int output_index,
+                                   const ggml_tensor * ggml_tensor);
+
+bool is_naive(struct ggml_cgraph * cgraph);
+
+enum ggml_status naive_compute(struct ggml_cgraph * cgraph,
+                               ov::Core & core,
+                               const std::string & device,
+                               const ov::AnyMap & config);
diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index de5cbd75..48695a61 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -304,6 +304,41 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
     }
 }
 
+void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k) {
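+    // NVFP4 layout as used below: each block of QK_NVFP4 values is split into
+    // QK_NVFP4/QK_NVFP4_SUB sub-blocks; every sub-block stores one UE4M3 (FP8) scale in d[s] and
+    // QK_NVFP4_SUB E2M1 (FP4) values packed two per byte in qs, reusing the MXFP4 value table.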
+    static const int qk = QK_NVFP4;
+    static const int qk_sub = QK_NVFP4_SUB;
+    static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        for (int s = 0; s < n_sub; s++) {
+            const float * xb = x + i*qk + s*qk_sub;
+
+            float amax = 0.0f;
+            for (int j = 0; j < qk_sub; j++) {
+                if (amax < fabsf(xb[j])) {
+                    amax = fabsf(xb[j]);
+                }
+            }
+
+            // UE4M3 scale: amax / 6.0 maps the max E2M1 value (6.0) to amax
+            const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f);
+            y[i].d[s] = ue;
+            const float d = ggml_ue4m3_to_fp32(ue);
+
+            for (int j = 0; j < qk_sub/2; ++j) {
+                const uint8_t x0 = best_index_mxfp4(xb[0        + j], d);
+                const uint8_t x1 = best_index_mxfp4(xb[qk_sub/2 + j], d);
+
+                y[i].qs[s*(qk_sub/2) + j] = x0 | (x1 << 4);
+            }
+        }
+    }
+}
+
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
@@ -434,6 +469,31 @@ void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_REST
     }
 }
 
+void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_NVFP4;
+    static const int qk_sub = QK_NVFP4_SUB;
+    static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        for (int s = 0; s < n_sub; s++) {
+            const float d = ggml_ue4m3_to_fp32(x[i].d[s]);
+            float * yb = y + i*qk + s*qk_sub;
+
+            for (int j = 0; j < qk_sub/2; ++j) {
+                const int8_t v0 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] & 0x0F];
+                const int8_t v1 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] >>   4];
+
+                yb[j + 0       ] = v0*d;
+                yb[j + qk_sub/2] = v1*d;
+            }
+        }
+    }
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -2098,6 +2158,12 @@ size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
     return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
 }
 
+size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_UNUSED(quant_weights);
+    quantize_row_nvfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);
+}
+
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
 void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -3104,6 +3170,11 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML
             }
             float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
             float eff_max = scale*kMaxQ;
+            if (eff_max <= 0) {
+                scales[ib] = 0;
+                memset(L, 0, 32);
+                continue;
+            }
             float best = 0;
             for (int is = -6; is <= 6; ++is) {
                 float id = (2*kMaxQ-1+is*0.1f)/eff_max;
@@ -3273,9 +3344,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
             }
             float max = xval[0];
             for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+            memset(L, 0, 16);
             if (max < GROUP_MAX_EPS) {
                 scales[ib] = 0;
-                memset(L, 0, 16);
                 continue;
             }
             float best = 0;
@@ -3714,9 +3785,9 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT
             }
             float max = xval[0];
             for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            memset(L, 0, 32);
             if (max < GROUP_MAX_EPS_IQ3_XXS) {
                 scales[ib] = 0;
-                memset(L, 0, 32);
                 continue;
             }
             float best = 0;
@@ -3922,6 +3993,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT
             }
             float max = xval[0];
             for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
+            memset(L, 0, block_size);
             if (!max) {
                 scales[ib] = 0;
                 continue;
@@ -4245,6 +4317,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R
             for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
             if (max < GROUP_MAX_EPS_IQ1_S) {
                 scales[ib] = 0;
+                shifts[ib] = 1;
                 memset(L, 1, block_size);
                 continue;
             }
@@ -4285,7 +4358,12 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R
                     }
                 }
             }
-            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0);
+            if (besti1 < 0 || besti2 < 0 || best_shift == 0) {
+                scales[ib] = 0;
+                shifts[ib] = 1;
+                memset(L, 1, block_size);
+                continue;
+            }
             for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
             for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
             for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
@@ -4429,6 +4507,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
             for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
             if (max < GROUP_MAX_EPS_IQ1_M) {
                 scales[ib] = 0;
+                shifts[ib] = 0;
                 memset(L, 1, block_size);
                 continue;
             }
@@ -4527,7 +4606,12 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
                     }
                 }
             }
-            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0);
+            if (besti1 < 0 || besti2 < 0 || best_k < 0) {
+                scales[ib] = 0;
+                shifts[ib] = 0;
+                memset(L, 1, block_size);
+                continue;
+            }
             for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
             for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
             for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
@@ -4683,7 +4767,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
             sumqx += w*q*xb[j];
             sumq2 += w*q*q;
         }
-        d = sumqx/sumq2;
+        d = sumq2 > 0 ? sumqx/sumq2 : 0.f;
         float best = d*sumqx;
         for (int itry = -ntry; itry <= ntry; ++itry) {
             id = (itry + values[0])/max;
@@ -4874,6 +4958,7 @@ static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_R
             }
             float max = xval[0];
             for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+            memset(L, 0, 16);
             if (max < GROUP_MAX_EPS_IQ2_S) {
                 scales[ib] = 0;
                 continue;
@@ -5225,6 +5310,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
             {
                 VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
             } break;
+        case GGML_TYPE_NVFP4:
+            {
+                // UE4M3 scales are uint8_t — all byte values are valid
+                GGML_UNUSED(data);
+                GGML_UNUSED(nb);
+            } break;
         case GGML_TYPE_Q2_K:
             {
                 VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
index 3b688f31..00604f75 100644
--- a/ggml/src/ggml-quants.h
+++ b/ggml/src/ggml-quants.h
@@ -22,6 +22,7 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 *
 GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
 
 GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
 
 GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
@@ -48,6 +49,7 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
 //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
 GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
 GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -95,6 +97,7 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR
 GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
 GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
 GGML_API void iq2xs_init_impl(enum ggml_type type);
 GGML_API void iq2xs_free_impl(enum ggml_type type);
diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt
index 5a89d8dd..7b07b227 100644
--- a/ggml/src/ggml-sycl/CMakeLists.txt
+++ b/ggml/src/ggml-sycl/CMakeLists.txt
@@ -1,7 +1,7 @@
 message(STATUS  "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}")
 
-if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
-    message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
+if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL)$")
+    message(FATAL_ERROR "GGML_SYCL_TARGET: Invalid target, the supported options are [INTEL]")
 endif()
 
 check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
@@ -25,6 +25,11 @@ ggml_add_backend_library(ggml-sycl
 
 file(GLOB   GGML_HEADERS_SYCL "*.hpp")
 file(GLOB   GGML_SOURCES_SYCL "*.cpp")
+file(GLOB   SRCS "template-instances/fattn-tile*.cpp")
+list(APPEND GGML_SOURCES_SYCL ${SRCS})
+file(GLOB   SRCS "template-instances/fattn-vec*.cpp")
+list(APPEND GGML_SOURCES_SYCL ${SRCS})
+
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
 
 if (WIN32)
@@ -125,106 +130,28 @@ endif()
 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
 
 if (GGML_SYCL_F16)
-    if (GGML_SYCL_TARGET STREQUAL "AMD")
-        message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
-    endif()
     add_compile_definitions(GGML_SYCL_F16)
 endif()
 
 if (GGML_SYCL_TARGET STREQUAL "INTEL")
     add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
     target_link_options(ggml-sycl PRIVATE  -Xs   -ze-intel-greater-than-4GB-buffer-required)
-elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
-    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
-elseif (GGML_SYCL_TARGET STREQUAL "AMD")
-    # INFO: Allowed Sub_group_sizes are not consistent through all
-    # hip targets. For example, 64 is used for certain models, but the backend
-    # does not support it.
-    # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
-    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
-else()
-    # default for other target
-    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
-endif()
 
-if (GGML_SYCL_GRAPH)
-    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
-endif()
-
-# Link against Intel oneMKL or oneMath
-if (GGML_SYCL_TARGET STREQUAL "INTEL")
-    # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically
-    # See https://github.com/uxlfoundation/oneMath/issues/654
+    # Link against Intel oneMKL
     if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
         set(SYCL_COMPILER ON)
     endif()
     find_package(MKL REQUIRED)
     target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)
-    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL)
 else()
-    find_package(oneMath QUIET)
-    if (NOT oneMath_FOUND)
-        message(STATUS "oneMath not found: oneMath will be automatically downloaded")
-        # Use FetchContent to automatically pull and build oneMath
-        include(FetchContent)
-        set(BUILD_FUNCTIONAL_TESTS False)
-        set(BUILD_EXAMPLES False)
-        set(TARGET_DOMAINS blas)
-        if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
-            set(ENABLE_MKLCPU_BACKEND False)
-            set(ENABLE_MKLGPU_BACKEND False)
-            set(ENABLE_CUBLAS_BACKEND True)
-        elseif (GGML_SYCL_TARGET STREQUAL "AMD")
-            set(ENABLE_MKLCPU_BACKEND False)
-            set(ENABLE_MKLGPU_BACKEND False)
-            set(ENABLE_ROCBLAS_BACKEND True)
-            # Ensure setting a string variable here is not overriden by oneMath CACHE variables
-            cmake_policy(SET CMP0126 NEW)
-            # Setting the device architecture is only needed and useful for AMD devices in oneMath
-            set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE)
-        endif()
-        FetchContent_Declare(
-            ONEMATH
-            GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git
-            GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142
-        )
-        FetchContent_MakeAvailable(ONEMATH)
-        # Create alias to match with find_package targets name
-        function(onemath_alias target)
-            if (TARGET ${target}_obj)
-                # Silence verbose warnings from external libraries
-                target_compile_options(${target}_obj PRIVATE -w)
-            endif()
-            if (TARGET ${target})
-                add_library(ONEMATH::${target} ALIAS ${target})
-            endif()
-        endfunction()
-        onemath_alias(onemath)
-        onemath_alias(onemath_blas_mklcpu)
-        onemath_alias(onemath_blas_mklgpu)
-        onemath_alias(onemath_blas_cublas)
-        onemath_alias(onemath_blas_rocblas)
-    endif()
+    # any other SYCL target is unsupported in this build
+    message(FATAL_ERROR "GGML_SYCL_TARGET is not supported")
+    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+endif()
 
-    # Below oneMath compile-time dispatching is used for better performance
-    if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
-        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas)
-        target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda")
-        target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda")
-        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA)
-    elseif (GGML_SYCL_TARGET STREQUAL "AMD")
-        if (NOT GGML_SYCL_DEVICE_ARCH)
-            message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
-        endif()
-        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas)
-        target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
-        target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
-        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD)
-    else()
-        # Fallback to oneMath runtime dispatcher
-        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath)
-        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC)
-    endif()
+if (GGML_SYCL_GRAPH)
+    message(STATUS "GGML_SYCL_GRAPH is enabled")
+    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
 endif()
 
 if (GGML_SYCL_DEVICE_ARCH)
diff --git a/ggml/src/ggml-sycl/add-id.cpp b/ggml/src/ggml-sycl/add-id.cpp
index 00c073cf..8929017a 100644
--- a/ggml/src/ggml-sycl/add-id.cpp
+++ b/ggml/src/ggml-sycl/add-id.cpp
@@ -55,7 +55,11 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
   const int32_t* src2_d = (const int32_t*)src2->data;
   float* dst_d = (float*)dst->data;
 
-  int threads = std::min((int)ne00, 768);  // cols
+  const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
+  assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+
+  int threads = std::min((unsigned int)ne00, max_work_group_size);  // cols
+
   ctx.stream()->parallel_for(
       sycl::nd_range<3>(
           sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),
diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
index 75657f3f..b30b7f2b 100644
--- a/ggml/src/ggml-sycl/backend.hpp
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -23,6 +23,7 @@
 #include "dequantize.hpp"
 #include "dmmv.hpp"
 #include "element_wise.hpp"
+#include "fattn.hpp"
 #include "gla.hpp"
 #include "im2col.hpp"
 #include "mmq.hpp"
diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp
index 0a3883ae..92dd1888 100644
--- a/ggml/src/ggml-sycl/binbcast.cpp
+++ b/ggml/src/ggml-sycl/binbcast.cpp
@@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1,  int s2,  int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13,
+        int s00, int s01, int s02, int s03,
+        int s10, int s11, int s12, int s13,
         const sycl::nd_item<3> &item_ct1) {
     const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                     item_ct1.get_local_id(2);
@@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
     for (int i0 = i0s; i0 < ne0;
          i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
         const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
     }
 }
 
@@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1,  int s2,  int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13,
+        int s00, int s01, int s02, int s03,
+        int s10, int s11, int s12, int s13,
         const sycl::nd_item<3> &item_ct1) {
 
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
     dst_t * dst_row = dst + i_dst;
 
     const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]);
 }
 
 
@@ -95,7 +95,8 @@ struct bin_bcast_sycl {
                     const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
                     const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
                     const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
-                    const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
+                    const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted,
+                    queue_ptr stream) {
         int nr0 = ne10 / ne0;
         int nr1 = ne11/ne1;
         int nr2 = ne12/ne2;
@@ -123,7 +124,7 @@ struct bin_bcast_sycl {
             cnb[3] *= cne[3];
         };
 
-        if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
+        if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted) {
             for (int i = 0; i < 4; i++) {
                 if (nr[i] != 1) {
                     break;
@@ -164,7 +165,7 @@ struct bin_bcast_sycl {
             size_t nb12 = cnb1[2];
             size_t nb13 = cnb1[3];
 
-            size_t s0 = nb0 / sizeof(dst_t);
+            // size_t s0 = nb0 / sizeof(dst_t);
             size_t s1 = nb1 / sizeof(dst_t);
             size_t s2 = nb2 / sizeof(dst_t);
             size_t s3 = nb3 / sizeof(dst_t);
@@ -196,9 +197,6 @@ struct bin_bcast_sycl {
             GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
             GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
 
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s10 == 1);
-
             const int block_size = 128;
 
             int64_t hne0 = std::max(ne0/2LL, 1LL);
@@ -232,8 +230,8 @@ struct bin_bcast_sycl {
                         [=](sycl::nd_item<3> item_ct1) {
                             k_bin_bcast_unravel(
                                 src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
-                                s03, s11, s12, s13, item_ct1);
+                                ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02,
+                                s03, s10, s11, s12, s13, item_ct1);
                         });
                 }
             } else {
@@ -251,7 +249,7 @@ struct bin_bcast_sycl {
                     [=](sycl::nd_item<3> item_ct1) {
                         k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1,
                                             ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s01, s02, s03, s11, s12, s13,
+                                            s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13,
                                             item_ct1);
                     });
             }
@@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,
              ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
-             ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
+             ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,
              ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
-             nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst),
+             nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
              main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
         op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,
              ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
-             nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
+             nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
+             main_stream);
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
         op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,
              ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
-             nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
+             nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
+             main_stream);
     } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
         op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,
              ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
-             nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
+             nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1),
+             main_stream);
     } else {
         fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
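
The hunks above generalize the SYCL binary-broadcast kernels: the element strides of src0 and src1 along dimension 0 (s00 = nb00/sizeof(src0_t), s10 = nb10/sizeof(src1_t)) are now passed in and applied when indexing each row, and the collapse-to-contiguous fast path is gated on the inputs not being permuted instead of on dst contiguity. The following host-side sketch (illustrative values only, not the kernel itself) shows how the strided indexing reads a permuted src0 row while broadcasting a contiguous src1 row:

```cpp
// Sketch of the per-element stride indexing used by k_bin_bcast after this change.
// src0 is viewed with an element stride of 2 (a permuted view); src1 is contiguous.
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> src0 = {0, -1, 1, -1, 2, -1, 3, -1, 4, -1, 5, -1, 6, -1, 7, -1};
    std::vector<float> src1 = {10, 20, 30, 40, 50, 60, 70, 80};
    std::vector<float> dst(8);

    const int ne0 = 8, ne10 = 8;
    const int s00 = 2;  // element stride of src0 along dim 0 (nb00 / sizeof(float))
    const int s10 = 1;  // element stride of src1 along dim 0 (nb10 / sizeof(float))

    for (int i0 = 0; i0 < ne0; ++i0) {
        const int i10 = i0 % ne10;                   // broadcast index, as in the kernel
        dst[i0] = src0[i0 * s00] + src1[i10 * s10];  // bin_op is addition here
    }

    for (float v : dst) {
        printf("%g ", v);                            // prints: 10 21 32 43 54 65 76 87
    }
    printf("\n");
    return 0;
}
```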
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 519638fd..fcb0db99 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -19,10 +19,13 @@
 #include 
 
 #include "dpct/helper.hpp"
+#include "ggml.h"
+#include "ggml-impl.h"
 #include "ggml-sycl.h"
 #include "presets.hpp"
 #include "sycl_hw.hpp"
 
+namespace syclexp = sycl::ext::oneapi::experimental;
 
 #if GGML_SYCL_DNNL
 #include "dnnl.hpp"
@@ -31,6 +34,9 @@
 
 #define GGML_COMMON_DECL_SYCL
 #define GGML_COMMON_IMPL_SYCL
+#define SYCL_FLASH_ATTN // remove this to disable FLASH_ATTENTION in the build.
+#define SYCL_FAST_FP16  // do not remove: removing it breaks the fattn-tile.hpp build.
+
 /* suppress warning spam */
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wnested-anon-types"
@@ -45,6 +51,8 @@ void ggml_sycl_host_free(void* ptr);
 extern int g_ggml_sycl_debug;
 extern int g_ggml_sycl_disable_optimize;
 extern int g_ggml_sycl_prioritize_dmmv;
+extern int g_ggml_sycl_enable_flash_attention;
+
 
 #if defined(__clang__) && __has_builtin(__builtin_expect)
 // Hint the optimizer to pipeline the more likely following instruction in branches
@@ -76,10 +84,10 @@ extern int g_ggml_sycl_prioritize_dmmv;
 
 
 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC 610 // todo for hardward optimize.
-#define VER_GEN9 700 // todo for hardward optimize.
-#define VER_GEN12 1000000 // todo for hardward optimize.
-#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardward optimize.
+#define VER_4VEC 610 // todo for hardware optimize.
+#define VER_GEN9 700 // todo for hardware optimize.
+#define VER_GEN12 1000000 // todo for hardware optimize.
+#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardware optimize.
 
 #define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardwares
 
@@ -170,6 +178,10 @@ static size_t g_scratch_offset = 0;
 
 int get_current_device_id();
 
+inline int ggml_sycl_get_device() {
+    return get_current_device_id();
+}
+
 inline dpct::err0 ggml_sycl_set_device(const int device) try {
   int current_device_id;
   SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
@@ -194,11 +206,14 @@ struct optimize_feature {
 };
 
 struct sycl_device_info {
-    int     cc;                 // compute capability
+    int cc;  // compute capability
     int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
              // number of compute units on a SYCL device.
     // size_t  smpb;               // max. shared memory per block
     size_t  smpbo;              // max. shared memory per block (with opt-in)
+    int warp_size;     // WARP_SIZE(16)|WARP_32_SIZE(32)|WARP_16_SIZE(16). For Intel GPUs, 16 is better in most cases; some ops support only 32.
+    int max_wg_per_cu; // max work groups per compute unit - refer to
+                       // cudaOccupancyMaxActiveBlocksPerMultiprocessor
     bool    vmm;                // virtual memory support
     size_t  total_vram;
     //sycl_hw_info hw_info;     \\ device id and aarch, currently not used
@@ -435,13 +450,15 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
     return a;
 }
 
-template 
+/* use WARP_SIZE or WARP_32_SIZE*/
+template 
 static __dpct_inline__ int warp_reduce_sum(int x) {
   return sycl::reduce_over_group(
       sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
 }
 
-template 
+/* use WARP_SIZE or WARP_32_SIZE*/
+template 
 static __dpct_inline__ float warp_reduce_sum(float x) {
 #pragma unroll
   for (int offset = width / 2; offset > 0; offset >>= 1) {
@@ -451,7 +468,19 @@ static __dpct_inline__ float warp_reduce_sum(float x) {
   return x;
 }
 
-template 
+/* use WARP_SIZE or WARP_32_SIZE*/
+template 
+static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) {
+#pragma unroll
+  for (int offset = width / 2; offset > 0; offset >>= 1) {
+    x += dpct::permute_sub_group_by_xor(
+        item_ct1.get_sub_group(), x, offset);
+  }
+  return x;
+}
+
+/* use WARP_SIZE or WARP_32_SIZE*/
+template 
 static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
 #pragma unroll
   for (int offset = width / 2; offset > 0; offset >>= 1) {
@@ -465,7 +494,8 @@ static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
   return a;
 }
 
-template 
+/* use WARP_SIZE or WARP_32_SIZE*/
+template 
 static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
 #pragma unroll
   for (int offset = width / 2; offset > 0; offset >>= 1) {
@@ -481,7 +511,52 @@ static constexpr int ggml_sycl_get_physical_warp_size() {
   return WARP_SIZE;
 }
 
-template 
+/* use WARP_SIZE or WARP_32_SIZE*/
+template 
+static __dpct_inline__ int warp_reduce_all(int x) {
+    if (width == ggml_sycl_get_physical_warp_size()) {
+        return sycl::all_of_group(
+            sycl::ext::oneapi::this_work_item::get_sub_group(),
+            (~0xffffffff &
+             (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()
+                         .get_local_linear_id())) ||
+                x);
+    } else {
+#pragma unroll
+        for (int offset = width / 2; offset > 0; offset >>= 1) {
+            x = dpct::permute_sub_group_by_xor(
+                    sycl::ext::oneapi::this_work_item::get_sub_group(), x,
+                    offset, width) &&
+                x;
+        }
+        return x;
+    }
+}
+
+/* use WARP_SIZE or WARP_32_SIZE*/
+template 
+static __dpct_inline__ int warp_reduce_any(int x) {
+    if (width == ggml_sycl_get_physical_warp_size()) {
+        return sycl::any_of_group(
+            sycl::ext::oneapi::this_work_item::get_sub_group(),
+            (0xffffffff &
+             (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()
+                         .get_local_linear_id())) &&
+                x);
+    } else {
+#pragma unroll
+        for (int offset = width / 2; offset > 0; offset >>= 1) {
+            x = dpct::permute_sub_group_by_xor(
+                    sycl::ext::oneapi::this_work_item::get_sub_group(), x,
+                    offset, width) ||
+                x;
+        }
+        return x;
+    }
+}
+
+/* use WARP_SIZE or WARP_32_SIZE*/
+template 
 static __dpct_inline__ float warp_reduce_max(float x) {
 #pragma unroll
   for (int offset = width / 2; offset > 0; offset >>= 1) {
@@ -629,6 +704,42 @@ static const sycl::uint3 init_fastdiv_values(uint32_t d) {
     return sycl::uint3(mp, L, d);
 }
 
+// Maximum number of bytes that can be copied in a single instruction.
+// The value was chosen based on test results.
+static constexpr int ggml_sycl_get_max_cpy_bytes() {
+    return 16;
+}
+
+// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes.
+template <int nbytes, int alignment = 0>
+static __dpct_inline__ void ggml_sycl_memcpy_1(void * dst, const void * src) {
+    if constexpr (alignment != 0) {
+        static_assert(nbytes % alignment == 0, "bad alignment");
+    }
+    constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
+
+#pragma unroll
+    for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
+        if constexpr (nb_per_cpy == 1) {
+            ((char *) dst)[i] = ((const char *) src)[i];
+        } else if constexpr (nb_per_cpy == 2) {
+            ((short *) dst)[i] = ((const short *) src)[i];
+        } else if constexpr (nb_per_cpy == 4) {
+            ((int *) dst)[i] = ((const int *) src)[i];
+        } else if constexpr (nb_per_cpy == 8) {
+            ((sycl::int2 *) dst)[i] = ((const sycl::int2 *) src)[i];
+        } else if constexpr (nb_per_cpy == 16) {
+            ((sycl::int4 *) dst)[i] = ((const sycl::int4 *) src)[i];
+        } else {
+            static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
+        }
+    }
+}
+template <typename T>
+sycl::half2 __dpct_inline__ make_half2(T x, T y) {
+    sycl::half2 res(static_cast<sycl::half>(x), static_cast<sycl::half>(y));
+    return res;
+}
 
 static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
     const uint32_t hi = sycl::mul_hi(n, fastdiv_values.x());
@@ -636,6 +747,17 @@ static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_va
 }
 
 
+template <typename T>
+sycl::float2 __dpct_inline__ make_float2(T x, T y) {
+    sycl::float2 res(static_cast<float>(x), static_cast<float>(y));
+    return res;
+}
+
+sycl::float2 __dpct_inline__ __half22float2(sycl::half2 &H) {
+    sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));
+    return float2_value;
+}
+
 static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
     const uint32_t div_val = fastdiv(n, fastdiv_values);
     const uint32_t mod_val = n - div_val * fastdiv_values.z();
@@ -659,5 +781,188 @@ static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {
     return result;
 }
 
+sycl::float2 __dpct_inline__ __half22float2(const sycl::half2 &H) {
+    sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));
+    return float2_value;
+}
+
+float __dpct_inline__ __half2float(sycl::half H) {
+    return static_cast<float>(H);
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const float v, const float u) {
+    acc += v*u;
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::float2 v, const sycl::float2 u) {
+    acc += v.x() * u.x();
+    acc += v.y() * u.y();
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::half2 v, const sycl::half2 u) {
+#ifdef GGML_SYCL_F16
+    const sycl::float2 tmp = (v * u).template convert<float>();
+    acc += tmp.x() + tmp.y();
+#else
+    const sycl::float2 tmpv = __half22float2(v);
+    const sycl::float2 tmpu = __half22float2(u);
+    acc += tmpv.x() * tmpu.x();
+    acc += tmpv.y() * tmpu.y();
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void ggml_sycl_mad(sycl::half2 & acc, const sycl::half2 v, const sycl::half2 u) {
+#ifdef GGML_SYCL_F16
+    acc += v*u;
+#else
+    const sycl::float2 tmpv = __half22float2(v);
+    const sycl::float2 tmpu = __half22float2(u);
+    sycl::float2 tmpacc = __half22float2(acc);
+    // tmpacc.x += tmpv.x() * tmpu.x();
+    // tmpacc.y += tmpv.y() * tmpu.y();
+    sycl::float2 tmp1(tmpacc.x() + tmpv.x() * tmpu.x(), tmpacc.y() + tmpv.y() * tmpu.y());
+    acc = make_half2(tmp1.x(), tmp1.y());
+#endif // GGML_SYCL_F16
+}
+
+template <int n>
+struct ggml_sycl_unroll {
+    template <typename Func, typename... Args>
+    void operator()(const Func & f, Args... args) const {
+        f(n - 1, args...);
+        ggml_sycl_unroll<n - 1>{}(f, args...);
+    }
+};
+
+template <>
+struct ggml_sycl_unroll<1> {
+    template <typename Func, typename... Args>
+    void operator()(const Func & f, Args... args) const {
+        f(0, args...);
+    }
+};
+
+static __dpct_inline__ sycl::half2 ggml_sycl_hmax2(const sycl::half2 a, const sycl::half2 b) {
+    sycl::half2 ret;
+    reinterpret_cast(ret.x()) =
+        sycl::vec(sycl::fmax(a[0], b[0])).convert()[0];
+    reinterpret_cast(ret.y()) =
+        sycl::vec(sycl::fmax(a[1], b[1])).convert()[0];
+    return ret;
+}
+
+static __dpct_inline__ sycl::half ggml_sycl_hmax(const sycl::half a, const sycl::half b) {
+    return sycl::vec(
+               sycl::fmax(sycl::vec(a).convert()[0],
+                          sycl::vec(b).convert()[0]))
+        .convert()[0];
+}
+
+static __dpct_inline__ uint32_t __hgt2_mask(const sycl::half2 a, const sycl::half2 b) {
+    const uint32_t mask_low  = 0x0000FFFF * (float(a[0]) > float(b[0]));
+    const uint32_t mask_high = 0xFFFF0000 * (float(a[1]) > float(b[1]));
+    return mask_low | mask_high;
+}
+
+static __dpct_inline__ uint32_t fastmodulo(uint32_t n, const sycl::uint3 fastdiv_values) {
+    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+    return n - fastdiv(n, fastdiv_values) * fastdiv_values.z();
+}
+
+static bool fast_fp16_available(const int cc) {
+    GGML_UNUSED(cc);
+    return true;   //Intel GPUs always support FP16.
+}
+
+enum class block_reduce_method {
+    MAX,
+    SUM,
+};
+
+template
+struct block_reduce_policy;
+
+template 
+inline constexpr bool is_any = (std::is_same_v || ...);
+
+template
+inline constexpr bool ggml_sycl_dependent_false_v = false;
+
+#define WARP_32_SIZE 32
+
+template  struct block_reduce_policy {
+    static T reduce(T val) {
+        if constexpr (is_any) {
+            return warp_reduce_sum(val);
+        } else {
+            static_assert(ggml_sycl_dependent_false_v, "Unsupported type for block reduce sum");
+        }
+    }
+
+    static T sentinel() {
+        if constexpr (std::is_same_v) {
+            return 0.0f;
+        } else if constexpr (std::is_same_v) {
+            return sycl::float2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v) {
+            return sycl::half2(0.0f, 0.0f);
+        } else if constexpr (std::is_same_v) {
+            return 0;
+        } else {
+            static_assert(ggml_sycl_dependent_false_v, "Unsupported type for block reduce sum");
+        }
+    }
+};
+
+template  struct block_reduce_policy {
+    static T reduce(T val) {
+        if constexpr (is_any) {
+            return warp_reduce_max(val);
+        } else {
+            static_assert(ggml_sycl_dependent_false_v, "Unsupported type for block reduce max");
+        }
+    }
+
+    static T sentinel() {
+        if constexpr (std::is_same_v) {
+            return -INFINITY;
+        } else if constexpr (std::is_same_v) {
+            return sycl::half2(-INFINITY, -INFINITY);
+        } else {
+            static_assert(ggml_sycl_dependent_false_v, "Unsupported type for block reduce max");
+        }
+    }
+};
+
+
+template 
+static T block_reduce(T val, T * shared_vals, int block_size_template) {
+    auto item_ct1                 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    val                           = block_reduce_policy::reduce(val);
+    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
+    const int nthreads = item_ct1.get_local_range(2);
+    const int nwarps = nthreads / WARP_SIZE;
+
+    if (block_size > warp_size) {
+        assert((block_size <= 1024) && (block_size % warp_size) == 0);
+        const int warp_id = item_ct1.get_local_id(2) / warp_size;
+        const int lane_id = item_ct1.get_local_id(2) % warp_size;
+        if (lane_id == 0) {
+            shared_vals[warp_id] = val;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        size_t nreduce = nwarps / WARP_SIZE;
+        float tmp = 0.f;
+        if (lane_id < (static_cast(block_size) / warp_size)) {
+            for (size_t i = 0; i < nreduce; i += 1)
+            {
+                tmp += shared_vals[lane_id + i * WARP_SIZE];
+            }
+        }
+        return block_reduce_policy::reduce(tmp);
+    }
+    return val;
+}
 
 #endif // GGML_SYCL_COMMON_HPP
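
Among the helpers added to common.hpp above, block_reduce applies a sub-group (warp) reduction first and, when the work-group spans several sub-groups, exchanges the per-warp partial results through shared memory before a final warp-sized reduction. A plain C++ sketch of that two-level pattern with illustrative sizes (warp_size = 16, block_size = 64); it models the data flow only, not the SYCL kernel:

```cpp
// Two-level reduction: each "warp" reduces its own lanes, lane 0 publishes the
// partial sum, then the partials are reduced once more to get the block total.
#include <cstdio>
#include <vector>

int main() {
    const int warp_size  = 16;                  // sub-group size typical for Intel GPUs
    const int block_size = 64;                  // 4 warps per work-group
    std::vector<float> vals(block_size, 1.0f);  // every work item contributes 1.0

    // level 1: each warp sums its own lanes (warp_reduce_sum in the kernel)
    std::vector<float> partials(block_size / warp_size, 0.0f);
    for (int w = 0; w < block_size / warp_size; ++w) {
        for (int lane = 0; lane < warp_size; ++lane) {
            partials[w] += vals[w * warp_size + lane];
        }
    }

    // level 2: the per-warp partials are reduced to the block-wide result
    float total = 0.0f;
    for (float p : partials) {
        total += p;
    }

    printf("block sum = %g\n", total);          // prints: block sum = 64
    return 0;
}
```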
diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp
index 8bdae364..d17aca2c 100644
--- a/ggml/src/ggml-sycl/convert.cpp
+++ b/ggml/src/ggml-sycl/convert.cpp
@@ -482,6 +482,63 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
         });
 }
 
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int64_t i00 = 2 * (int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2));
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int64_t i01 = item_ct1.get_group(1);
+    const int64_t i02 = item_ct1.get_group(0) % ne02;
+    const int64_t i03 = item_ct1.get_group(0) / ne02;
+
+    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+    const int64_t ib = ibx0 + i00/qk; // block index
+    const int64_t iqs = (i00%qk)/qr; // quant index
+    const int64_t iybs = i00 - i00%qk; // y block start index
+    const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    #ifdef GGML_SYCL_F16
+        sycl::half2 v;
+    #else
+        sycl::float2 v;
+    #endif
+
+    dequantize_kernel(vx, ib, iqs, v);
+
+    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+    y[iy0 + 0]        = ggml_sycl_cast<dst_t>(v.x());
+    y[iy0 + y_offset] = ggml_sycl_cast<dst_t>(v.y());
+}
+
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_nc_sycl(const void *    vx,
+                                  dst_t *         y,
+                                  const int64_t   ne00,
+                                  const int64_t   ne01,
+                                  const int64_t   ne02,
+                                  const int64_t   ne03,
+                                  const int64_t   s01,
+                                  const int64_t   s02,
+                                  const int64_t   s03,
+                                  dpct::queue_ptr stream) {
+    const dpct::dim3 num_blocks((ne00 + 2 * SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * SYCL_DEQUANTIZE_BLOCK_SIZE), ne01,
+                                ne02 * ne03);
+    stream->parallel_for(sycl::nd_range<3>(num_blocks * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
+                                           sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             GGML_UNUSED(item_ct1);
+                             dequantize_block_nc<qk, qr, dequantize_kernel>(vx, y, ne00, ne01, ne02, s01, s02, s03);
+                         });
+}
 template 
 static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
                           const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
@@ -662,7 +719,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
     }
 }
 
-to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
+
+to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_nc_sycl;
@@ -670,6 +728,16 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
         case GGML_TYPE_BF16:
             return convert_unary_nc_sycl;
 #endif
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_nc_sycl<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_nc_sycl<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_nc_sycl<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_nc_sycl<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_nc_sycl<QK8_0, QR8_0, dequantize_q8_0>;
         default:
             return nullptr;
     }
diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp
index f8cb573e..6e621f21 100644
--- a/ggml/src/ggml-sycl/convert.hpp
+++ b/ggml/src/ggml-sycl/convert.hpp
@@ -29,6 +29,27 @@ using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne0
                                    int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
 
 typedef to_t_nc_sycl_t to_fp16_nc_sycl_t;
-to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
+to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type);
+
+template
+ inline dst_t ggml_sycl_cast(src_t x) {
+    if constexpr (std::is_same_v) {
+        return x;
+    } else if constexpr (std::is_same_v) {
+        return sycl::ext::oneapi::bfloat16(float(x));
+    } else if constexpr (std::is_same_v) {
+        return static_cast(x);
+    } else if constexpr (std::is_same_v && std::is_same_v) {
+        return x.template convert();
+    } else if constexpr (std::is_same_v &&
+                         std::is_same_v>) {
+        return {x.x, x.y};
+    } else if constexpr(std::is_same_v) {
+        return int32_t(x);
+    } else {
+        return float(x);
+    }
+}
+
 
 #endif  // GGML_SYCL_CONVERT_HPP
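
ggml_sycl_cast above selects the conversion path at compile time with if constexpr on the destination type, so the renamed ggml_get_to_fp16_nc_sycl only needs to pick a templated kernel per ggml_type. A minimal sketch of that compile-time dispatch with ordinary scalar types; the helper name and branches here are illustrative, not the actual function:

```cpp
// if-constexpr dispatch: one template, conversion chosen per destination type.
#include <cstdint>
#include <cstdio>
#include <type_traits>

template <typename dst_t, typename src_t>
dst_t cast_like(src_t x) {
    if constexpr (std::is_same_v<dst_t, src_t>) {
        return x;                        // identical types: no conversion
    } else if constexpr (std::is_same_v<dst_t, int32_t>) {
        return static_cast<int32_t>(x);  // explicit truncation to int32
    } else {
        return static_cast<dst_t>(x);    // default: a plain static_cast
    }
}

int main() {
    printf("%d\n", cast_like<int32_t>(3.7f));  // prints 3
    printf("%g\n", cast_like<float>(7));       // prints 7
    return 0;
}
```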
diff --git a/ggml/src/ggml-sycl/count-equal.cpp b/ggml/src/ggml-sycl/count-equal.cpp
index b0a8b482..4580354c 100644
--- a/ggml/src/ggml-sycl/count-equal.cpp
+++ b/ggml/src/ggml-sycl/count-equal.cpp
@@ -18,7 +18,7 @@ static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
         nequal += xi == yi;
     }
 
-    nequal = warp_reduce_sum(nequal);
+    nequal = warp_reduce_sum(nequal);
 
     if (item_ct1.get_local_id(2) != 0) {
         return;
diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp
index 30ec1e8d..791d3cac 100644
--- a/ggml/src/ggml-sycl/dpct/helper.hpp
+++ b/ggml/src/ggml-sycl/dpct/helper.hpp
@@ -15,18 +15,9 @@
 
 #include 
 #include 
-#include 
-#include 
-
-#ifdef GGML_SYCL_USE_INTEL_ONEMKL
 #include 
-// Allow to use the same namespace for Intel oneMKL and oneMath
-namespace oneapi {
-    namespace math = mkl;
-}
-#else
-#include 
-#endif
+
+#include 
 
 #include "ggml.h"
 
@@ -92,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
 }
 
 template  struct matrix_info_t {
-    oneapi::math::transpose transpose_info[2];
+    oneapi::mkl::transpose transpose_info[2];
     Ts                     value_info[2];
     std::int64_t           size_info[3];
     std::int64_t           ld_info[3];
     std::int64_t           groupsize_info;
 };
 
-inline auto get_onemath_backend(sycl::queue& queue)
-#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
-  -> sycl::queue&
-#endif
-{
-// If the backend is known at compile-time, use oneMath backend_selector to use
-// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
-// fallback to runtime dispatching.
-#if defined(GGML_SYCL_NVIDIA)
-    return oneapi::math::backend_selector{ queue };
-#elif defined(GGML_SYCL_AMD)
-    return oneapi::math::backend_selector{ queue };
-#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
-    return queue;
-#else
-    static_assert(false, "Unsupported backend");
-#endif
-}
-
 namespace dpct
 {
     typedef sycl::queue *queue_ptr;
@@ -1735,7 +1707,7 @@ namespace dpct
     namespace detail
     {
     template 
-    inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+    inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
                           int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
                           const void * beta, void * c, int ldc) {
         Ts   alpha_value = dpct::get_value(reinterpret_cast(alpha), q);
@@ -1743,7 +1715,7 @@ namespace dpct
         auto data_a      = get_memory(a);
         auto data_b      = get_memory(b);
         auto data_c      = get_memory(c);
-        oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
+        oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
                                                lda, data_b, ldb, beta_value, data_c, ldc);
     }
 
@@ -1775,7 +1747,7 @@ namespace dpct
         };
 
         template 
-        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
                                     int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
                                     int ldb, const void * beta, void ** c, int ldc, int batch_size,
                                     matrix_info_t * matrix_info) {
@@ -1794,8 +1766,8 @@ namespace dpct
             matrix_info->ld_info[2] = ldc;
             matrix_info->groupsize_info = batch_size;
 
-            sycl::event e = oneapi::math::blas::column_major::gemm_batch(
-                get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
+            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
+                q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
                 matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
                 reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info,
                 reinterpret_cast(b), matrix_info->ld_info + 1,
@@ -1804,7 +1776,7 @@ namespace dpct
         }
 
         template 
-        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
                                     int m, int n, int k, const void * alpha, const void * a, int lda,
                                     long long int stride_a, const void * b, int ldb, long long int stride_b,
                                     const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
@@ -1813,7 +1785,7 @@ namespace dpct
             auto data_a = get_memory(a);
             auto data_b = get_memory(b);
             auto data_c = get_memory(c);
-            oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
+            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
                                                          data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
                                                          data_c, ldc, stride_c, batch_size);
         }
@@ -2300,7 +2272,7 @@ namespace dpct
                            sycl::range<3>(x, y, 1), direction);
     }
 
-    inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
+    inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
                      int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
                      library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
                      library_data_t scaling_type) {
@@ -2367,7 +2339,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_impl(
+            detail::gemm_impl(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
             break;
         }
@@ -2406,7 +2378,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_impl(
+            detail::gemm_impl(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
             break;
         }
@@ -2448,7 +2420,7 @@ namespace dpct
     /// \param [in] ldc Leading dimension of C.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
                            int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
                            const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
                            library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2486,7 +2458,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl(
+            detail::gemm_batch_impl(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
@@ -2494,7 +2466,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl(
+            detail::gemm_batch_impl(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
@@ -2570,7 +2542,7 @@ namespace dpct
     /// \param [in] stride_c Stride between the different C matrices.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
                            int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
                            long long int stride_a, const void * b, library_data_t b_type, int ldb,
                            long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
@@ -2643,7 +2615,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl(
+            detail::gemm_batch_impl(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
                 batch_size);
             break;
@@ -2652,7 +2624,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl(
+            detail::gemm_batch_impl(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
                 batch_size);
             break;
@@ -3025,6 +2997,778 @@ namespace dpct
       return 0;
     }
 
+    template 
+    class args_selector;
+
+    /// args_selector is a helper class for extracting arguments from an
+    /// array of pointers to arguments or buffer of arguments to pass to a
+    /// kernel function.
+    ///
+    /// \param R(Ts...) The type of the kernel
+    /// \param n_nondefault_params The number of nondefault parameters of the
+    /// kernel (excluding parameters such as sycl::nd_item, etc.)
+    /// \param n_default_params The number of default parameters of the kernel
+    ///
+    /// Example usage:
+    /// With the following kernel:
+    ///   void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float
+    ///   f=.1) {}
+    /// and with the declaration:
+    ///   args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
+    /// we have:
+    ///   selector.get<0>() returns a reference to sycl::float*,
+    ///   selector.get<1>() returns a reference to int,
+    ///   selector.get<2>() returns a reference to float
+    template 
+    class args_selector {
+      private:
+        void **kernel_params;
+        char *args_buffer;
+
+        template  static constexpr int account_for_default_params() {
+            constexpr int n_total_params = sizeof...(Ts);
+            if constexpr (i >= n_nondefault_params) {
+                return n_total_params - n_default_params +
+                       (i - n_nondefault_params);
+            } else {
+                return i;
+            }
+        }
+
+      public:
+        /// Get the type of the ith argument of R(Ts...)
+        /// \param [in] i Index of parameter to get
+        /// \returns Type of ith parameter
+        template 
+        using arg_type = std::tuple_element_t(),
+                                              std::tuple>;
+        static constexpr int params_num = sizeof...(Ts);
+
+      private:
+        template  static constexpr int get_offset() {
+            if constexpr (i == 0) {
+                // we can assume args_buffer is properly aligned to the
+                // first argument
+                return 0;
+            } else {
+                constexpr int prev_off = get_offset();
+                constexpr int prev_past_end =
+                    prev_off + sizeof(arg_type);
+                using T = arg_type;
+                // is the past-the-end of the i-1st element properly aligned
+                // with the ith element's alignment?
+                if constexpr (prev_past_end % alignof(T) == 0) {
+                    return prev_past_end;
+                }
+                // otherwise bump prev_past_end to match alignment
+                else {
+                    return prev_past_end +
+                           (alignof(T) - (prev_past_end % alignof(T)));
+                }
+            }
+        }
+
+        static char *get_args_buffer(void **extra) {
+            if (!extra)
+                return nullptr;
+            for (; (std::size_t)*extra != 0; ++extra) {
+                if ((std::size_t)*extra == 1) {
+                    return static_cast(*(extra + 1));
+                }
+            }
+            return nullptr;
+        }
+
+      public:
+        /// If kernel_params is nonnull, then args_selector will
+        /// extract arguments from kernel_params. Otherwise, it
+        /// will extract them from extra.
+        /// \param [in] kernel_params Array of pointers to arguments
+        /// a or null pointer.
+        /// \param [in] extra Array containing pointer to argument buffer.
+        args_selector(void **kernel_params, void **extra)
+            : kernel_params(kernel_params),
+              args_buffer(get_args_buffer(extra)) {}
+
+        /// Get a reference to the ith argument extracted from kernel_params
+        /// or extra.
+        /// \param [in] i Index of argument to get
+        /// \returns Reference to the ith argument
+        template  arg_type &get() {
+            if (kernel_params) {
+                return *static_cast *>(kernel_params[i]);
+            } else {
+                return *reinterpret_cast *>(args_buffer +
+                                                        get_offset());
+            }
+        }
+    }; // Copied from the DPCT header file
+       // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
+
+    /// Utility class for launching SYCL kernels through kernel
+    /// function wrapper.
+    /// For example:
+    /// A SYCL kernel function:
+    ///   void kernel_func(int *ptr, sycl::nd_item<3> item);
+    /// Kernel function wrapper:
+    ///   void kernel_func_wrapper(int *ptr) {
+    ///     sycl::queue queue = *dpct::kernel_launcher::_que;
+    ///     unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
+    ///     sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
+    ///     queue.parallel_for(
+    ///       nr,
+    ///       [=](sycl::nd_item<3> item_ct1) {
+    ///         kernel_func(ptr, item_ct1);
+    ///       });
+    ///   }
+    /// Then launch the kernel through wrapper like:
+    ///   typedef void(*fpt)(int *);
+    ///   fpt fp = kernel_func_wrapper;
+    ///   dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
+    ///   device_ptr);
+    /// If the origin function type is erased, then need to register it first:
+    ///   void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
+    ///   dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
+    ///   0, 0);
+    class kernel_launcher {
+        template 
+        static void launch_helper(FuncT &&func, ArgSelector &selector,
+                                  std::index_sequence) {
+            func(selector.template get()...);
+        }
+        static void set_execution_config(dim3 group_range, dim3 local_range,
+                                         unsigned int local_mem_size,
+                                         queue_ptr que) {
+            if (que) {
+                _que = que;
+            } else {
+                _que = &get_default_queue();
+            }
+            _nr = sycl::nd_range<3>(
+                static_cast>(group_range * local_range),
+                static_cast>(local_range));
+            _local_mem_size = local_mem_size;
+
+
+        };
+        static inline std::mutex kernel_function_ptr_map_mutex;
+
+      public:
+        /// Variables for storing execution configuration.
+        static inline thread_local sycl::queue *_que = nullptr;
+        static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
+        static inline thread_local unsigned int _local_mem_size = 0;
+        /// Map for retrieving launchable functor from a raw pointer.
+        static inline std::map<
+            const void *,
+            std::function>
+            kernel_function_ptr_map = {};
+
+        /// Registers a kernel function pointer with a corresponding launchable
+        /// functor.
+        /// \param [in] func Pointer to the kernel function.
+        /// \param [in] launcher Functor to handle kernel invocation.
+        static void register_kernel_ptr(
+            const void *func,
+            std::function
+                launcher) {
+            std::lock_guard lock(kernel_function_ptr_map_mutex);
+            kernel_function_ptr_map[func] = std::move(launcher);
+        }
+        /// Launches a kernel function with arguments provided directly through
+        /// kernel function wrapper.
+        /// \tparam FuncT Type of the kernel function wrapper.
+        /// \tparam ArgsT Types of kernel arguments.
+        /// \param [in] func Pointer to the kernel function wrapper.
+        /// \param [in] group_range SYCL group range.
+        /// \param [in] local_range SYCL local range.
+        /// \param [in] local_mem_size The size of local memory required by the
+        /// kernel function. \param [in] que SYCL queue used to execute kernel.
+        /// \param [in] args Kernel arguments.
+        template 
+        static std::enable_if_t, void>
+        launch(FuncT *func, dim3 group_range, dim3 local_range,
+               unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
+            set_execution_config(group_range, local_range, local_mem_size, que);
+            func(args...);
+        }
+        /// Launches a kernel function through registered kernel function
+        /// wrapper. \param [in] func Pointer to the registered kernel function
+        /// wrapper. \param [in] group_range SYCL group range. \param [in]
+        /// local_range SYCL local range. \param [in] args Array of pointers to
+        /// kernel arguments. \param [in] local_mem_size The size of local
+        /// memory required by the kernel function. \param [in] que SYCL queue
+        /// used to execute kernel.
+        static void launch(const void *func, dim3 group_range, dim3 local_range,
+                           void **args, unsigned int local_mem_size,
+                           queue_ptr que) {
+            std::lock_guard lock(kernel_function_ptr_map_mutex);
+            auto Iter = kernel_function_ptr_map.find(func);
+            if (Iter == kernel_function_ptr_map.end()) {
+                throw std::runtime_error("dpct::launch() : no registered "
+                                         "kernel function wrapper found.");
+            }
+            (Iter->second)(group_range, local_range, args, local_mem_size, que);
+        }
+        /// Launches a kernel function with packed arguments through kernel
+        /// function wrapper.
+        /// \tparam FuncT Type of the kernel function wrapper.
+        /// \param [in] func Pointer to the kernel function wrapper.
+        /// \param [in] group_range SYCL group range.
+        /// \param [in] local_range SYCL local range.
+        /// \param [in] args Array of pointers to kernel arguments.
+        /// \param [in] local_mem_size The size of local memory required by the
+        /// kernel function. \param [in] que SYCL queue used to execute kernel.
+        template 
+        static std::enable_if_t, void>
+        launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
+               unsigned int local_mem_size, queue_ptr que) {
+            constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
+            set_execution_config(group_range, local_range, local_mem_size, que);
+            args_selector selector(args, nullptr);
+            launch_helper(func, selector, std::make_index_sequence{});
+        }
+    }; // Copied from the DPCT header file
+       // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
+
+    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
+    template 
+    T select_from_sub_group(
+        sycl::sub_group g,
+        T x,
+        int remote_local_id,
+        int logical_sub_group_size = 32) {
+      unsigned int start_index = g.get_local_linear_id() /
+                                 logical_sub_group_size *
+                                 logical_sub_group_size;
+      return sycl::select_from_group(
+          g, x, start_index + remote_local_id % logical_sub_group_size);
+    }
+
+    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
+    template 
+    void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
+      auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
+      int lane = sg.get_local_linear_id();
+
+      int lane_group8_row = lane / 8;
+      int lane_group8_col = lane % 8;
+
+      if (!trans) {
+        // calculate the source lane
+        int src_lane = 2 * lane_group8_row;
+        if (lane_group8_col >= 4)
+          src_lane += 1;
+
+        // Broadcast the address from the source lane
+        auto recv_addr_uintp =
+            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
+
+        // Cast the received address from uintptr_t to the type of 'm'
+        auto recv_addr = reinterpret_cast(recv_addr_uintp);
+
+        // Non-transposed load
+        *m = recv_addr[lane_group8_col % 4];
+      } else {
+        // calculate the source lane
+        int src_lane = (lane % 4) * 2;
+
+        // Broadcast the address from the source lane
+        auto recv_addr_uintp_1 =
+            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
+        auto recv_addr_uintp_2 =
+            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
+
+        // Cast the received address from uintptr_t to 'half *'
+        auto recv_addr_1 = reinterpret_cast(recv_addr_uintp_1);
+        auto recv_addr_2 = reinterpret_cast(recv_addr_uintp_2);
+
+        // Transposed load
+        int index = lane / 4;
+        sycl::half val0 = recv_addr_1[index];
+        sycl::half val1 = recv_addr_2[index];
+
+        // Combine the two 16-bits into one 32-bit value
+        sycl::half2 val = sycl::half2(val0, val1);
+        *m = *reinterpret_cast(&val);
+      }
+    }
+
+    template 
+    void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
+      // Load 1st matrix
+      ldmatrix(addr, m1, trans, 0);
+      // Load 2nd matrix
+      ldmatrix(addr, m2, trans, 1);
+    }
+
+    template 
+    void ldmatrix(
+        uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
+      // Load 1st matrix
+      ldmatrix(addr, m1, trans, 0);
+      // Load 2nd matrix
+      ldmatrix(addr, m2, trans, 1);
+      // Load 3rd matrix
+      ldmatrix(addr, m3, trans, 2);
+      // Load 4th matrix
+      ldmatrix(addr, m4, trans, 3);
+    }
+
+    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
+
+    /// A helper struct that defines the pack type for the input matrix
+    /// fragments
+    /// of mma() function based on the type of input matrix fragments.
+    /// The MMAType struct is specialized for different types of input matrices.
+    /// Currently, the specialization for f16, bf16 and s8 types is defined
+    /// below. \tparam [in] T The type of the input matrix fragments
+    template 
+    struct MMAType {
+      using PackType = uint32_t;
+    };
+
+    /// Each work item of a sub-group (limited to size 32) calling this function
+    /// calculates a subset fragment for the output matrix D using MAD operation
+    /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
+    /// types:
+    /// - m8n8k4 (f32.f16.f16.f32)
+    /// - m8n8k16 (s32.s8.s8.s32)
+    /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
+    /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
+    /// - m16n8k32 (s32.s8.s8.s32)
+    /// Here, m, n & k define the shapes of A, B & C matrices respectively
+    /// (A = [m x k], B = [k x n], C = [m x n]).
+    /// \tparam [in] M The rows of A, C & D matrices
+    /// \tparam [in] N The columns of B, C, D matrices
+    /// \tparam [in] K The columns & rows of A & B matrices respectively
+    /// \tparam [in] ABType The type of the input matrix (A & B) fragment
+    /// \tparam [in] CDType The type of the output matrix (C & D) fragment
+    /// \param [out] d_mat_frag The fragment of the output matrix D to store the
+    /// result of A * B + C
+    /// \param [in] a_mat_frag The fragment of the input matrix A to be
+    /// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of
+    /// the input matrix B to be multiplied with A matrix fragment \param [in]
+    /// c_mat_frag The fragment of the input matrix C to be added with the
+    /// result of A * B fragments
+    template 
+    void mma(
+        volatile void** d_mat_frag,
+        void* a_mat_frag,
+        void* b_mat_frag,
+        void* c_mat_frag) {
+      auto d = reinterpret_cast(d_mat_frag);
+      auto a =
+          reinterpret_cast::PackType*>(a_mat_frag);
+      auto b =
+          reinterpret_cast::PackType*>(b_mat_frag);
+      auto c = reinterpret_cast(c_mat_frag);
+
+      auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
+      int lane = sg.get_local_linear_id();
+
+      static_assert(
+          (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
+              (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
+              (M == 16 && N == 8 && K == 32),
+          "Unsupported MMA shape!");
+
+      short row_load_offset = 4 * (lane >> 2);
+      short col_load_offset = 8 * (lane % 4);
+
+      if constexpr (M == 8 && N == 8 && K == 4) {
+        if constexpr (std::is_floating_point_v) {
+          col_load_offset = row_load_offset % 16;
+
+          // Init D matrix with fragments of C matrix
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+          *d[4] = c[4];
+          *d[5] = c[5];
+          *d[6] = c[6];
+          *d[7] = c[7];
+
+          // Calculate the row and col offset indices to iterate through the row
+          // & col fragments of A & B matrices
+          int r_ind = (lane % 2) ? 1 : 0;
+          int c_ind = ((lane % 4) / 2) ? 2 : 0;
+
+          // Each sub-group is responsible for computing a fragment size of 8*8
+          // elements of matrix D for each of 4 MMA computations.
+          // Each work item computes 8 elements of matrix D by gathering
+          // their corresponding col & row matrix fragments of length k (4)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (i % 4) if (lane < 16) else (i % 4) + 4
+          // col0 = (lane % 4)
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          typename MMAType::PackType recv_a[2], recv_b[2];
+
+          for (int i = 0; i < 4; i++) {
+            // Load partial fragment from col0 of matrix A ({a0, a1})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from col0 of matrix A ({a2, a3})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+
+            // Load partial fragment from row0 of matrix B ({b0, b1})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from row0 of matrix B ({b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+
+            auto ra = reinterpret_cast(recv_a);
+            auto rb = reinterpret_cast(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment (for
+            // even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{
+            // a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }
+            // * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{
+            // b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }
+            // d3 += col1{ a3 } * row0{ b3 }
+            *d[0] +=
+                static_cast(ra[r_ind]) * static_cast(rb[c_ind]);
+            *d[1] += static_cast(ra[r_ind]) *
+                     static_cast(rb[c_ind + 1]);
+            *d[2] += static_cast(ra[r_ind + 2]) *
+                     static_cast(rb[c_ind]);
+            *d[3] += static_cast(ra[r_ind + 2]) *
+                     static_cast(rb[c_ind + 1]);
+
+            // Load partial fragment from row1 of matrix B ({b0, b1})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
+            // Load partial fragment from row1 of matrix B ({b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
+
+            // (for even work item indices)
+            // d0 += col0{ a0 } * row1{ b0 }
+            // d1 += col0{ a0 } * row1{ b1 }
+            // d2 += col1{ a2 } * row1{ b0 }
+            // d3 += col1{ a2 } * row1{ b1 }
+            // (for odd work item indices)
+            // d0 += col0{ a1 } * row1{ b2 }
+            // d1 += col0{ a1 } * row1{ b3 }
+            // d2 += col1{ a3 } * row1{ b2 }
+            // d3 += col1{ a3 } * row1{ b3 }
+            *d[4] +=
+                static_cast(ra[r_ind]) * static_cast(rb[c_ind]);
+            *d[5] += static_cast(ra[r_ind]) *
+                     static_cast(rb[c_ind + 1]);
+            *d[6] += static_cast(ra[r_ind + 2]) *
+                     static_cast(rb[c_ind]);
+            *d[7] += static_cast(ra[r_ind + 2]) *
+                     static_cast(rb[c_ind + 1]);
+          }
+        }
+      } else if constexpr (M == 8 && N == 8 && K == 16) {
+        if constexpr (std::is_integral_v) {
+          // Init D matrix with fragments of C matrix
+          *d[0] = c[0];
+          *d[1] = c[1];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 2 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (16)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = ((lane % 4) * 4) + i
+          // col0 = (lane >> 2)
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType::PackType recv_a, recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+            recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+            auto a = reinterpret_cast(&recv_a);
+            auto b = reinterpret_cast(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment.
+            // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+            // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+            // d2 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+            // d3 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += a[j] * b[j];
+              *d[1] += a[j] * b[j + 4];
+            }
+          }
+        }
+      } else if constexpr (M == 16 && N == 8 && K == 8) {
+        if constexpr (std::is_floating_point_v) {
+          // Init D matrix fragment with C matrix fragment
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 4 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (8)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
+          // col0 = (lane % 4) * 2 + (i & 0x1)
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType::PackType recv_a[2], recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a2, a3})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b0, b1})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b0, b1})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+            auto ra = reinterpret_cast(recv_a);
+            auto rb = reinterpret_cast(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment.
+            // d0 += row0{ a0, a1 } * col0{ b0, b1 }
+            // d1 += row0{ a0, a1 } * col1{ b0, b1 }
+            // d2 += row1{ a2, a3 } * col0{ b0, b1 }
+            // d3 += row1{ a2, a3 } * col1{ b0, b1 }
+            for (int j = 0; j < 2; j++) {
+              *d[0] += static_cast(ra[j]) * static_cast(rb[j]);
+              *d[1] +=
+                  static_cast(ra[j]) * static_cast(rb[j + 2]);
+              *d[2] +=
+                  static_cast(ra[j + 2]) * static_cast(rb[j]);
+              *d[3] +=
+                  static_cast(ra[j + 2]) * static_cast(rb[j + 2]);
+            }
+          }
+        }
+      } else if constexpr (M == 16 && N == 8 && K == 16) {
+        if constexpr (std::is_floating_point_v) {
+          // Init D matrix fragment with C matrix fragment
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 4 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (8)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8
+          // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType::PackType recv_a[4], recv_b[4];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from row0 of matrix A ({a2, a3})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a0, a1})
+            recv_a[2] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a2, a3})
+            recv_a[3] =
+                dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
+
+            // Load partial fragment from col0 of matrix B ({b0, b1})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b0, b1})
+            recv_b[2] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
+            // Load partial fragment from col1 of matrix B ({b2, b3})
+            recv_b[3] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
+
+            auto ra = reinterpret_cast(recv_a);
+            auto rb = reinterpret_cast(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment.
+            // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+            // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+            // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+            // d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += static_cast(ra[j]) * static_cast(rb[j]);
+              *d[1] +=
+                  static_cast(ra[j]) * static_cast(rb[j + 4]);
+              *d[2] +=
+                  static_cast(ra[j + 4]) * static_cast(rb[j]);
+              *d[3] += static_cast(ra[j + 4]) *
+                       static_cast(rb[j + 4]);
+            }
+          }
+        } else if constexpr (std::is_integral_v) {
+          // Init D matrix with fragments of C matrix
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 4 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (8)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8
+          // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType::PackType recv_a[2], recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+            auto ra = reinterpret_cast(recv_a);
+            auto rb = reinterpret_cast(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment.
+            // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+            // d1 += row0{ a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 }
+            // d2 += row1{ a4, a5, a6, a7 } * col0{ b0, b1, b2, b3 }
+            // d3 += row1{ a4, a5, a6, a7 } * col1{ b4, b5, b6, b7 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += ra[j] * rb[j];
+              *d[1] += ra[j] * rb[j + 4];
+              *d[2] += ra[j + 4] * rb[j];
+              *d[3] += ra[j + 4] * rb[j + 4];
+            }
+          }
+        }
+      } else if constexpr (M == 16 && N == 8 && K == 32) {
+        if constexpr (std::is_integral_v) {
+          // Init D matrix with fragments of C matrix
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 4 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (32)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8
+          // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i & 0x3)
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType::PackType recv_a[2], recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+            auto a = reinterpret_cast(recv_a);
+            auto b = reinterpret_cast(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment.
+            // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+            // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+            // d2 += row1{ a4, a5, a6, a7 } * col0{ b0, b1, b2, b3 }
+            // d3 += row1{ a4, a5, a6, a7 } * col1{ b0, b1, b2, b3 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += a[j] * b[j];
+              *d[1] += a[j] * b[j + 4];
+              *d[2] += a[j + 4] * b[j];
+              *d[3] += a[j + 4] * b[j + 4];
+            }
+          }
+
+          for (int i = 0; i < 4; i++) {
+            typename MMAType::PackType recv_a[2], recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a12, a13, a14,
+            // a15})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
+
+            auto a = reinterpret_cast(recv_a);
+            auto b = reinterpret_cast(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment.
+            // d0 += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 }
+            // d1 += row0{ a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 }
+            // d2 += row1{ a12, a13, a14, a15 } * col0{ b4, b5, b6, b7 }
+            // d3 += row1{ a12, a13, a14, a15 } * col1{ b4, b5, b6, b7 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += a[j] * b[j];
+              *d[1] += a[j] * b[j + 4];
+              *d[2] += a[j + 4] * b[j];
+              *d[3] += a[j + 4] * b[j + 4];
+            }
+          }
+        }
+      }
+    }
 } // COPY from DPCT head files
 
 #endif // GGML_SYCL_DPCT_HELPER_HPP
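The sub-group MMA shapes emulated above (m8n8k16, m16n8k8, m16n8k16, m16n8k32) all compute the same thing: each element of the 8x8 or 16x8 D tile is a K-length dot product of one A row with one B column accumulated on top of the matching C element; the branches differ only in how those rows and columns are scattered across the 32 work items of a sub-group. A minimal host-side reference that such tiles can be validated against follows; it is only a sketch, and the row-major/column-major layouts here are illustrative assumptions, not the exact fragment layout used by the kernels.

#include <cstdio>
#include <vector>

// Reference D = A * B + C for one M x N x K tile.
// A is M x K row-major, B is K x N column-major, C and D are M x N row-major.
template <typename TA, typename TB, typename TC>
void mma_reference(int M, int N, int K,
                   const std::vector<TA> & A, const std::vector<TB> & B,
                   const std::vector<TC> & C, std::vector<TC> & D) {
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            TC acc = C[m*N + n];
            for (int k = 0; k < K; ++k) {
                acc += (TC) A[m*K + k] * (TC) B[n*K + k];
            }
            D[m*N + n] = acc;
        }
    }
}

int main() {
    const int M = 16, N = 8, K = 16;
    std::vector<float> A(M*K, 1.0f), B(N*K, 2.0f), C(M*N, 0.5f), D(M*N);
    mma_reference(M, N, K, A, B, C, D);
    std::printf("D[0] = %g\n", D[0]); // 16 * 1 * 2 + 0.5 = 32.5
}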
diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp
index 8d83b244..acd51bf4 100644
--- a/ggml/src/ggml-sycl/element_wise.cpp
+++ b/ggml/src/ggml-sycl/element_wise.cpp
@@ -9,23 +9,32 @@
 #define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
     (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
 
+static void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
+        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+        const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2);
 
-static void acc_f32(const float * x, const float * y, float * dst, const int ne,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
-    const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
     if (i >= ne) {
         return;
     }
-    int src1_idx = i - offset;
-    int oz = src1_idx / nb2;
-    int oy = (src1_idx - (oz * nb2)) / nb1;
-    int ox = src1_idx % nb1;
-    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
-        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
-    } else {
-        dst[i] = x[i];
+
+    int64_t src1_idx = i - offset;
+
+    int64_t tmp = src1_idx;
+    const int64_t i13 = tmp / s13;
+    tmp -= i13 * s13;
+    const int64_t i12 = tmp / s12;
+    tmp -= i12 * s12;
+    const int64_t i11 = tmp / s11;
+    tmp -= i11 * s11;
+    const int64_t i10 = tmp;
+
+    float val = x[i];
+    if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
+        val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
     }
+    dst[i] = val;
 }
 
 /* Unary OP funcs */
@@ -123,6 +132,15 @@ static __dpct_inline__ T op_log(T x) {
     return sycl::log(x);
 }
 
+template<typename T>

+static __dpct_inline__ T op_softplus(T x) {
+    const float xf = (float) x;
+    const float ax = sycl::fabs(xf);
+    const float m  = sycl::fmax(xf, 0.0f);
+    const float y  = m + sycl::log1p(sycl::exp(-ax));
+    return (T) y;
+}
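op_softplus computes log(1 + exp(x)) in the numerically stable form max(x, 0) + log1p(exp(-|x|)): for large positive x the naive exp would overflow, and for very negative x the naive form rounds to log(1) = 0 too early. A plain C++ sketch of the same identity, independent of SYCL:

#include <cmath>
#include <cstdio>

// Numerically stable softplus: log(1 + exp(x)) == max(x, 0) + log1p(exp(-|x|)).
static float softplus_ref(float x) {
    return std::fmax(x, 0.0f) + std::log1p(std::exp(-std::fabs(x)));
}

int main() {
    // Large |x| would overflow/underflow the naive form, but stays finite here.
    const float xs[] = { -100.0f, -1.0f, 0.0f, 1.0f, 100.0f };
    for (float x : xs) {
        std::printf("softplus(%g) = %g\n", x, softplus_ref(x));
    }
}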
+
 template<typename T>
 static __dpct_inline__ T op_neg(T x) {
     return -x;
@@ -355,18 +373,15 @@ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const
 
 namespace ggml_sycl_detail {
 static void acc_f32_sycl(const float *x, const float *y, float *dst,
-                         const int n_elements, const int ne10, const int ne11,
-                         const int ne12, const int nb1, const int nb2,
-                         const int offset, queue_ptr stream) {
-    int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
-    stream->parallel_for(
-        sycl::nd_range<1>(sycl::range<1>(num_blocks) *
-                              sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
-                          sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
-        [=](sycl::nd_item<1> item_ct1) {
-            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
-                    item_ct1);
-        });
+                         const int64_t n_elements, const int64_t ne10, const int64_t ne11,
+                         const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3,
+                         const int64_t offset, queue_ptr stream) {
+    const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
+                                           sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
+                         });
 }
 
 template<typename T>
@@ -393,25 +408,19 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
 
 template<typename KernelInvoker, typename... Args>
 static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
-#if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
     GGML_ASSERT(dst->src[0]->type == dst->type);
+
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     switch (dst->type) {
-#if defined (GGML_SYCL_F16)
         case GGML_TYPE_F16:
             {
                 auto data_pts = cast_data<sycl::half>(dst);
                 kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
                 break;
             }
-#endif
         case GGML_TYPE_F32:
             {
                 auto data_pts = cast_data<float>(dst);
@@ -425,14 +434,10 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx,
 
 template<typename KernelInvoker, typename... Args>
 static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
-#if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
     GGML_ASSERT(dst->src[0]->type == dst->type);
+
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     const ggml_tensor * src0 = dst->src[0];
@@ -454,7 +459,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
         GGML_ASSERT(src0->type == src1->type);
     }
     switch (dst->type) {
-#if defined (GGML_SYCL_F16)
         case GGML_TYPE_F16:
             {
                 sycl::half * src0_p = (sycl::half *) src0_d;
@@ -475,7 +479,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
                                std::forward<Args>(args)...);
                 break;
             }
-#endif
         case GGML_TYPE_F32:
             {
                 float * src0_p = (float *) src0_d;
@@ -504,13 +507,9 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c
 
 template<typename KernelInvoker, typename... Args>
 static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
-#if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
+
     GGML_ASSERT(dst->src[0]->type == dst->type);
 
     dpct::queue_ptr main_stream = ctx.stream();
@@ -521,7 +520,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
     const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
     const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
     switch (dst->type) {
-#if defined (GGML_SYCL_F16)
         case GGML_TYPE_F16:
             {
                 auto data_pts = cast_data<sycl::half>(dst);
@@ -530,7 +528,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
                                main_stream, std::forward<Args>(args)...);
                 break;
             }
-#endif
         case GGML_TYPE_F32:
             {
                 auto data_pts = cast_data<float>(dst);
@@ -695,6 +692,12 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
         });
 }
 
+static inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_softplus(x);
+    });
+}
+
 static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
         return op_neg(x);
@@ -821,16 +824,9 @@ static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tens
 }
 
 static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
-        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
-            const int num_blocks = ceil_div(k_elements, 256);
-            stream->parallel_for(
-                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
-                                  sycl::range<1>(256)),
-                [=](sycl::nd_item<1> item_ct1) {
-                    unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
-                });
-        });
+    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
+        return op_ceil(x);
+    });
 }
 
 static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -860,22 +856,31 @@ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tens
 }
 
 static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *)  dst->data;
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    const float * src0_dd = static_cast(dst->src[0]->data);
-    const float * src1_dd = static_cast(dst->src[1]->data);
-    float *       dst_dd  = static_cast(dst->data);
 
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
+    GGML_ASSERT(ggml_is_contiguously_allocated(dst));
 
-    ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+    const int64_t s1     = dst->op_params[0] / sizeof(float);
+    const int64_t s2     = dst->op_params[1] / sizeof(float);
+    const int64_t s3     = dst->op_params[2] / sizeof(float);
+    const int64_t offset = dst->op_params[3] / sizeof(float);
+
+    ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst),
+        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+        s1, s2, s3, offset, stream);
 }
 
 static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -1101,6 +1106,11 @@ void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_log(ctx, dst);
 }
 
+void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_softplus(ctx, dst);
+}
+
 void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_neg(ctx, dst);
diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp
index 0913a2e5..7c719746 100644
--- a/ggml/src/ggml-sycl/element_wise.hpp
+++ b/ggml/src/ggml-sycl/element_wise.hpp
@@ -61,6 +61,8 @@ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
+void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
 void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp
new file mode 100644
index 00000000..ed00d03c
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn-common.hpp
@@ -0,0 +1,1179 @@
+#pragma once
+
+#include 
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "convert.hpp"
+#include "vecdotq.hpp"
+
+#include "ggml.h"
+
+#include 
+#include 
+#include 
+
+
+#define FATTN_KQ_STRIDE       256
+#define HALF_MAX_HALF         sycl::half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
+
+typedef void (*fattn_kernel_t)(
+    const char* Q,
+    const char* K,
+    const char* V,
+    const char* mask,
+    const char* sinks,
+    const int* KV_max,
+    float* dst,
+    sycl::float2* dst_meta,
+    const float scale,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const uint32_t n_head_log2,
+    const float logit_softcap,
+    const int32_t ne00,
+    const sycl::uint3 ne01,
+    const int32_t ne02,
+    const int32_t ne03,
+    const int32_t nb01,
+    const int32_t nb02,
+    const int32_t nb03,
+    const int32_t ne10,
+    const int32_t ne11,
+    const int32_t ne12,
+    const int32_t ne13,
+    const int32_t nb11,
+    const int32_t nb12,
+    const int64_t nb13,
+    const int32_t nb21,
+    const int32_t nb22,
+    const int64_t nb23,
+    const int32_t ne31,
+    const int32_t ne32,
+    const int32_t ne33,
+    const int32_t nb31,
+    const int32_t nb32,
+    const int64_t nb33);
+
+typedef float (*vec_dot_KQ_t)(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+
+template 
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_f16(const char * __restrict__ K_c,
+                                                      const void * __restrict__ Q_v,
+                                                      const int * __restrict__ Q_q8,
+                                                      const void * __restrict__ Q_ds_v) {
+    const sycl::half2 * K_h2 = (const sycl::half2 *) K_c;
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
+        sycl::half2 tmp[cpy_ne];
+        ggml_sycl_memcpy_1(
+            tmp,
+            K_h2 + k_KQ_0 + (sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2) % nthreads) * cpy_ne);
+#pragma unroll
+        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
+#ifdef GGML_SYCL_F16
+            ggml_sycl_mad(sum,                tmp[k_KQ_1] , ((const sycl::half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#else
+            ggml_sycl_mad(sum, __half22float2(tmp[k_KQ_1]), ((const sycl::float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    return sum;
+}
+
+template 
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_0(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+
+    const block_q4_0 * K_q4_0   = (const block_q4_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_0;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v;
+        ggml_sycl_memcpy_1(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/nthreads];
+
+        const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+        sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x() - (8/QI8_1)*Q_ds.y());
+    }
+
+    return sum;
+}
+
+template 
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_1(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const block_q4_1 * K_q4_1   = (const block_q4_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v;
+        ggml_sycl_memcpy_1(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/nthreads];
+
+        const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+        const sycl::float2 K_dm = (K_q4_1[ib].dm).template convert();
+        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+        sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
+    }
+
+    return sum;
+}
+
+template 
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_0(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const block_q5_0 * K_q5_0   = (const block_q5_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_0;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v;
+        ggml_sycl_memcpy_1(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
+
+        {
+            int vh;
+            ggml_sycl_memcpy_1(&vh, K_q5_0[ib].qh);
+            vh >>= iqs8 * QI5_0;
+
+            v |= (vh <<  4) & 0x00000010; // 0 ->  4
+            v |= (vh << 11) & 0x00001000; // 1 -> 12
+            v |= (vh << 18) & 0x00100000; // 2 -> 20
+            v |= (vh << 25) & 0x10000000; // 3 -> 28
+        }
+
+        const int u = Q_q8[k_KQ_0/nthreads];
+
+        const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+        sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x() - (16/QI8_1)*Q_ds.y());
+    }
+
+    return sum;
+}
+
+template 
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_1(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const block_q5_1 * K_q5_1   = (const block_q5_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_1;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v;
+        ggml_sycl_memcpy_1(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
+
+        {
+            int vh;
+            ggml_sycl_memcpy_1(&vh, K_q5_1[ib].qh);
+            vh >>= iqs8 * QI5_0;
+
+            v |= (vh <<  4) & 0x00000010; // 0 ->  4
+            v |= (vh << 11) & 0x00001000; // 1 -> 12
+            v |= (vh << 18) & 0x00100000; // 2 -> 20
+            v |= (vh << 25) & 0x10000000; // 3 -> 28
+        }
+
+        const int u = Q_q8[k_KQ_0/nthreads];
+
+        const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+        const sycl::float2 K_dm = (K_q5_1[ib].dm).template convert();
+        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+        sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
+    }
+
+    return sum;
+}
+
+template 
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q8_0(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const block_q8_0 * K_q8_0   = (const block_q8_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib  = k_KQ / QI8_0;
+        const int iqs = k_KQ % QI8_0;
+
+        int v;
+        ggml_sycl_memcpy_1(&v, K_q8_0[ib].qs + 4*iqs);
+
+        const sycl::float2 * Q_ds = (const sycl::float2 *) Q_ds_v;
+        const float          Q_d  = Q_ds[k_KQ_0 / nthreads].x();
+
+        sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
+    }
+
+    return sum;
+}
+
+template 
+static __dpct_inline__ void quantize_q8_1_to_shared(const float * __restrict__ x,
+                                                    const float scale,
+                                                    int * __restrict__ yq32,
+                                                    void * __restrict__ yds) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+
+    float vals[sizeof(int)] = { 0.0f };
+#pragma unroll
+    for (int l = 0; l < int(sizeof(int)); ++l) {
+        vals[l] =
+            (ni == warp_size || item_ct1.get_local_id(2) < ni) ? scale * x[4 * item_ct1.get_local_id(2) + l] : 0.0f;
+    }
+
+    float amax = sycl::fabs(vals[0]);
+    float sum  = vals[0];
+#pragma unroll
+    for (int l = 1; l < int(sizeof(int)); ++l) {
+        amax = sycl::fmax(amax, sycl::fabs(vals[l]));
+        sum += vals[l];
+    }
+#pragma unroll
+    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
+        amax = sycl::fmax(
+            amax, dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), amax, mask));
+        sum += dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), sum, mask);
+    }
+
+    const float d = amax / 127;
+    int q32 = 0;
+    int8_t * q8 = (int8_t *) &q32;
+
+    if (d != 0.0f) {
+#pragma unroll
+        for (int l = 0; l < int(sizeof(int)); ++l) {
+            q8[l] = sycl::round(vals[l] / d);
+        }
+    }
+
+    yq32[item_ct1.get_local_id(2)] = q32;
+    if (item_ct1.get_local_id(2) % QI8_1 == 0 && (ni == warp_size || item_ct1.get_local_id(2) < ni)) {
+        if (std::is_same::value) {
+            ((sycl::half2  *) yds)[item_ct1.get_local_id(2)/QI8_1] =  make_half2(d, sum);
+        } else {
+            ((sycl::float2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_float2(d, sum);
+        }
+    }
+}
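quantize_q8_1_to_shared packs each group of 32 floats into 8-bit integers with a per-block scale d = amax/127 and also records the block's sum, which the q*_q8_1 dot products above use to fold the quantization offsets back in. A serial sketch of the per-block math; the kernel additionally applies a caller-supplied pre-scale and performs the reductions across the sub-group, both omitted here.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Quantize one 32-value block to int8 plus (d, sum) metadata, mirroring the
// per-sub-group reduction in quantize_q8_1_to_shared.
static void quantize_block_q8_1(const float * x, int8_t * q, float & d, float & sum) {
    float amax = 0.0f;
    sum = 0.0f;
    for (int i = 0; i < 32; ++i) {
        amax = std::max(amax, std::fabs(x[i]));
        sum += x[i];
    }
    d = amax / 127.0f;
    for (int i = 0; i < 32; ++i) {
        q[i] = d != 0.0f ? (int8_t) std::lround(x[i] / d) : (int8_t) 0;
    }
}

int main() {
    float  x[32];
    int8_t q[32];
    for (int i = 0; i < 32; ++i) { x[i] = 0.1f * (i - 16); }
    float d, sum;
    quantize_block_q8_1(x, q, d, sum);
    std::printf("d = %g, sum = %g, q[0] = %d\n", d, sum, q[0]);
}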
+
+typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
+
+template 
+static __dpct_inline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    if constexpr (std::is_same_v) {
+        ggml_sycl_memcpy_1(dst, (const sycl::half *) vx + i0);
+    } else if constexpr (std::is_same_v) {
+        static_assert(ne % 2 == 0, "bad ne");
+        sycl::half2 tmp[ne / 2];
+        ggml_sycl_memcpy_1(tmp, (const sycl::half *) vx + i0);
+        sycl::float2 * dst_f2 = (sycl::float2 *) dst;
+#pragma unroll
+        for (int l = 0; l < ne/2; ++l) {
+            dst_f2[l] = tmp[l].template convert();
+        }
+    } else {
+        static_assert(std::is_same_v, "unsupported type");
+    }
+}
+
+template 
+static __dpct_inline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int64_t ib    =  i0          /  QK4_0;
+    const int     iqs   =  i0          % (QK4_0/2);
+    const int     shift = (i0 % QK4_0) / (QK4_0/2);
+
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_sycl_memcpy_1(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+    q = dpct::vectorized_binary(q, 0x08080808, dpct::sub_sat());
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same_v) {
+        const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same_v) {
+        const float d = x[ib].d;
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * q8[l];
+        }
+    } else {
+        static_assert(std::is_same_v, "bad type");
+    }
+}
+
+template 
+static __dpct_inline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int64_t ib    =  i0          /  QK4_1;
+    const int     iqs   =  i0          % (QK4_1/2);
+    const int     shift = (i0 % QK4_1) / (QK4_1/2);
+
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_sycl_memcpy_1(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same_v) {
+        const sycl::half2 dm = x[ib].dm;
+        const sycl::half2 d  = sycl::half2(dm[0]);
+        const sycl::half2 m  = sycl::half2(dm[1]);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same_v) {
+        const sycl::float2 dm = (x[ib].dm).template convert();
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = dm.x() * q8[l] + dm.y();
+        }
+    } else {
+        static_assert(std::is_same_v, "bad type");
+    }
+}
+
+template 
+static __dpct_inline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const int64_t ib    =  i0          /  QK5_0;
+    const int     idq   =  i0          %  QK5_0;
+    const int     iqs   =  i0          % (QK5_0/2);
+    const int     shift = (i0 % QK5_0) / (QK5_0/2);
+
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_sycl_memcpy_1(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+
+    {
+        int qh;
+        ggml_sycl_memcpy_1(&qh, x[ib].qh);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+        }
+    }
+
+    q = dpct::vectorized_binary(q, 0x10101010, dpct::sub_sat());
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same_v) {
+        const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same_v) {
+        const float d = x[ib].d;
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * q8[l];
+        }
+    } else {
+        static_assert(std::is_same_v, "bad type");
+    }
+}
+
+template 
+static __dpct_inline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const int64_t ib    =  i0          /  QK5_1;
+    const int     idq   =  i0          %  QK5_1;
+    const int     iqs   =  i0          % (QK5_1/2);
+    const int     shift = (i0 % QK5_1) / (QK5_1/2);
+
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_sycl_memcpy_1(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+
+    {
+        int qh;
+        ggml_sycl_memcpy_1(&qh, x[ib].qh);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+        }
+    }
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same_v) {
+        const sycl::half2 dm = x[ib].dm;
+        const sycl::half2 d  = sycl::half2(dm[0]);
+        const sycl::half2 m  = sycl::half2(dm[1]);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same_v) {
+        const sycl::float2 dm = (x[ib].dm).template convert();
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = dm.x() * q8[l] + dm.y();
+        }
+    } else {
+        static_assert(std::is_same_v, "bad type");
+    }
+}
+
+template 
+static __dpct_inline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const int64_t ib  = i0 / QK8_0;
+    const int     iqs = i0 % QK8_0;
+
+    static_assert(ne % 2 == 0, "bad ne");
+    int8_t qs[ne];
+    ggml_sycl_memcpy_1(qs, x[ib].qs + iqs);
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same::value) {
+        const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same::value) {
+        const float d = x[ib].d;
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * qs[l];
+        }
+    } else {
+        static_assert(std::is_same_v, "unsupported type");
+    }
+}
+
+template 
+constexpr vec_dot_KQ_t get_vec_dot_KQ() {
+    if constexpr (type_K == GGML_TYPE_F16) {
+        return vec_dot_fattn_vec_KQ_f16;
+    } else if constexpr (type_K == GGML_TYPE_Q4_0) {
+        return vec_dot_fattn_vec_KQ_q4_0;
+    } else if constexpr (type_K == GGML_TYPE_Q4_1) {
+        return vec_dot_fattn_vec_KQ_q4_1;
+    } else if constexpr (type_K == GGML_TYPE_Q5_0) {
+        return vec_dot_fattn_vec_KQ_q5_0;
+    } else if constexpr (type_K == GGML_TYPE_Q5_1) {
+        return vec_dot_fattn_vec_KQ_q5_1;
+    } else if constexpr (type_K == GGML_TYPE_Q8_0) {
+        return vec_dot_fattn_vec_KQ_q8_0;
+    } else {
+        static_assert(type_K == -1, "bad type");
+        return nullptr;
+    }
+}
+
+template 
+constexpr dequantize_V_t get_dequantize_V() {
+    if constexpr (type_V == GGML_TYPE_F16) {
+        return dequantize_V_f16;
+    } else if constexpr (type_V == GGML_TYPE_Q4_0) {
+        return dequantize_V_q4_0;
+    } else if constexpr (type_V == GGML_TYPE_Q4_1) {
+        return dequantize_V_q4_1;
+    } else if constexpr (type_V == GGML_TYPE_Q5_0) {
+        return dequantize_V_q5_0;
+    } else if constexpr (type_V == GGML_TYPE_Q5_1) {
+        return dequantize_V_q5_1;
+    } else if constexpr (type_V == GGML_TYPE_Q8_0) {
+        return dequantize_V_q8_0;
+    } else {
+        static_assert(type_V == -1, "bad type");
+        return nullptr;
+    }
+}
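get_vec_dot_KQ and get_dequantize_V resolve the K/V tensor type to a concrete device function at compile time via if constexpr, so the attention kernels carry no runtime dispatch cost. A minimal standalone illustration of the pattern; the enum and functions here are made-up stand-ins, not ggml types.

#include <cstdio>

enum my_type { TYPE_F16, TYPE_Q8_0 };  // illustrative stand-ins, not ggml_type

static float dot_f16 (const void *, const void *) { return 1.0f; }
static float dot_q8_0(const void *, const void *) { return 2.0f; }

using dot_fn = float (*)(const void *, const void *);

// Select the implementation at compile time; unsupported types fail to compile.
template <my_type type>
constexpr dot_fn get_dot() {
    if constexpr (type == TYPE_F16) {
        return dot_f16;
    } else if constexpr (type == TYPE_Q8_0) {
        return dot_q8_0;
    } else {
        static_assert(type == TYPE_F16, "bad type");
        return nullptr;
    }
}

int main() {
    constexpr dot_fn f = get_dot<TYPE_Q8_0>();
    std::printf("%g\n", f(nullptr, nullptr));
}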
+
+template 
+static void flash_attn_mask_to_KV_max(const sycl::half2 * __restrict__ mask,
+                                      int * __restrict__ KV_max,
+                                      const int ne30,
+                                      const int s31,
+                                      const int s33,
+                                      int *     buf_iw) {
+    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int ne31     = item_ct1.get_group_range(2);
+    const int tid      = item_ct1.get_local_id(2);
+    const int sequence = item_ct1.get_group(1);
+    const int jt       = item_ct1.get_group(2);
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    if (tid < warp_size) {
+        buf_iw[tid] = 1;
+    }
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const sycl::float2 tmp =
+                mask[j * s31 + KV_max_sj / 2 + tid].template convert();
+            all_inf = all_inf && int(sycl::isinf((float) (tmp.x()))) && int(sycl::isinf((float) (tmp.y())));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % warp_size == 0) {
+            buf_iw[tid / warp_size] = all_inf;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        all_inf = buf_iw[tid % warp_size];
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+    // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+    // In either case, undo the last decrement by FATTN_KQ_STRIDE.
+    KV_max_sj += FATTN_KQ_STRIDE;
+
+    if (item_ct1.get_local_id(2) != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
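flash_attn_mask_to_KV_max walks the mask in FATTN_KQ_STRIDE-wide tiles from the back and records, per sequence and Q tile, the first KV offset beyond which every mask value is -inf, so the attention kernels can stop early on causal or padded masks. A single-row host sketch of the same scan; the real kernel reduces the all-infinite test over ncols1 rows and a full warp, and the mask length is assumed here to be a multiple of the tile width.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Return the number of leading KV positions that still contain finite mask
// entries, rounded up to a tile of `stride` values; everything at or beyond
// the returned offset is fully masked (-inf).
static int kv_max_for_row(const std::vector<float> & mask_row, int stride) {
    int kv_max = ((int) mask_row.size() / stride) * stride;
    for (; kv_max > 0; kv_max -= stride) {
        bool all_inf = true;
        for (int j = kv_max - stride; j < kv_max; ++j) {
            all_inf = all_inf && std::isinf(mask_row[j]);
        }
        if (!all_inf) {
            break;
        }
    }
    return kv_max;
}

int main() {
    std::vector<float> row(1024, -INFINITY);
    std::fill(row.begin(), row.begin() + 300, 0.0f); // first 300 positions usable
    // With a 256-wide tile this reports 512: tiles 0 and 1 contain finite values.
    std::printf("KV_max = %d\n", kv_max_for_row(row, 256));
}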
+
+template   // D == head size
+
+static void flash_attn_stream_k_fixup(float * __restrict__ dst,
+                                      const sycl::float2 * __restrict__ dst_fixup,
+                                      const int ne01,
+                                      const int ne02,
+                                      const int ne03,
+                                      const int ne11,
+                                      const int ne12,
+                                      const int nbatch_fa) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    constexpr int ncols    = ncols1 * ncols2;
+
+    const int bidx0 = item_ct1.get_group(2);
+    const int j     = item_ct1.get_group(1);
+    const int c     = item_ct1.get_group(0);
+    const int jc    = j*ncols2 + c;
+    const int tid   = item_ct1.get_local_id(2);
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + item_ct1.get_group_range(2) * (2 * 2 * ncols);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
+
+    const int kbc0 = int64_t(bidx0 + 0) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+    const int kbc0_stop =
+        int64_t(bidx0 + 1) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
+        return;
+    }
+
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
+
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const sycl::float2 tmp = dst_fixup[bidx0 * ncols + jc];
+        max_val                = tmp.x();
+        rowsum                 = tmp.y();
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All SYCL blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while(true) {
+        const int kbc = int64_t(bidx) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+        const sycl::float2 tmp = dst_fixup[(item_ct1.get_group_range(2) + bidx) * ncols + jc];
+
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = sycl::fmax(max_val, tmp.x());
+
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x() - max_val_new;
+
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_add) : 0.0f;
+
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val * rowsum + scale_add * tmp.y();
+
+        max_val = max_val_new;
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+    *dst = dst_val / rowsum;
+}
+
+template   // D == head size
+
+static void flash_attn_combine_results(const float * __restrict__ VKQ_parts,
+                                       const sycl::float2 * __restrict__ VKQ_meta,
+                                       float * __restrict__ dst,
+                                       const int parallel_blocks,
+                                       uint8_t * dpct_local) {
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int ne01     = item_ct1.get_group_range(2);
+    const int ne02     = item_ct1.get_group_range(1);
+
+    const int col      = item_ct1.get_group(2);
+    const int head     = item_ct1.get_group(1);
+    const int sequence = item_ct1.get_group(0);
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled *                 D;
+
+    const int tid = item_ct1.get_local_id(2);
+    __builtin_assume(tid < D);
+
+    auto meta = (sycl::float2 *) dpct_local;
+    for (int i = tid; i < 2*parallel_blocks; i += D) {
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
+    }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    float kqmax = meta[0].x();
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = sycl::max(kqmax, meta[l].x());
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float KQ_max_scale = sycl::native::exp(meta[l].x() - kqmax);
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y();
+    }
+
+    dst[tid] = VKQ_numerator / VKQ_denominator;
+}
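flash_attn_combine_results merges the partial outputs produced by parallel_blocks independent blocks: each block contributes a partial V*softmax(KQ) numerator together with its running (max, row sum), and all partials are rescaled to a common maximum before being summed, which is the usual online-softmax merge. A scalar sketch of that merge for a single output element, under the simplifying assumption of one value per block:

#include <cmath>
#include <cstdio>
#include <vector>

struct Partial {
    float vkq;     // partial numerator for one output element
    float kq_max;  // running maximum of the KQ logits in this block
    float rowsum;  // running sum of exp(KQ - kq_max) in this block
};

// Merge per-block partial results into the final attention output value,
// mirroring flash_attn_combine_results for one element of D.
static float combine_partials(const std::vector<Partial> & parts) {
    float kq_max = parts[0].kq_max;
    for (const Partial & p : parts) {
        kq_max = std::fmax(kq_max, p.kq_max);
    }
    float num = 0.0f, den = 0.0f;
    for (const Partial & p : parts) {
        const float scale = std::exp(p.kq_max - kq_max);
        num += scale * p.vkq;
        den += scale * p.rowsum;
    }
    return num / den;
}

int main() {
    std::vector<Partial> parts = { { 1.0f, 0.0f, 2.0f }, { 3.0f, 1.0f, 4.0f } };
    std::printf("%g\n", combine_partials(parts));
}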
+
+template 
+static void lauch_kernel(
+    dpct::dim3 group_range,
+    dpct::dim3 local_range,
+    queue_ptr q,
+    unsigned int local_mem_size,
+    const char* __restrict__ Q,
+    const char* __restrict__ K,
+    const char* __restrict__ V,
+    const char* __restrict__ mask,
+    const char* __restrict__ sinks,
+    const int* __restrict__ KV_max,
+    float* __restrict__ dst,
+    sycl::float2* __restrict__ dst_meta,
+    const float scale,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const uint32_t n_head_log2,
+    const float logit_softcap,
+    const int32_t ne00,
+    const sycl::uint3 ne01,
+    const int32_t ne02,
+    const int32_t ne03,
+    const int32_t nb01,
+    const int32_t nb02,
+    const int32_t nb03,
+    const int32_t ne10,
+    const int32_t ne11,
+    const int32_t ne12,
+    const int32_t ne13,
+    const int32_t nb11,
+    const int32_t nb12,
+    const int64_t nb13,
+    const int32_t nb21,
+    const int32_t nb22,
+    const int64_t nb23,
+    const int32_t ne31,
+    const int32_t ne32,
+    const int32_t ne33,
+    const int32_t nb31,
+    const int32_t nb32,
+    const int64_t nb33) {
+    GGML_UNUSED(local_mem_size);
+    q->submit([&](sycl::handler &cgh) {
+        cgh.parallel_for(
+            sycl::nd_range<3>(
+                static_cast<sycl::range<3>>(group_range * local_range),
+                static_cast<sycl::range<3>>(local_range)),
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
+                GGML_UNUSED(item_ct1);
+                fattn_kernel(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+                             max_bias, m0, m1, n_head_log2, logit_softcap, ne00,
+                             ne01, ne02, ne03, nb01, nb02, nb03, ne10, ne11,
+                             ne12, ne13, nb11, nb12, nb13, nb21, nb22, nb23,
+                             ne31, ne32, ne33, nb31, nb32, nb33);
+            });
+    });
+}
+
+template 
+void launch_fattn(
+    ggml_backend_sycl_context & ctx, ggml_tensor * dst, const int nwarps, const size_t nbytes_shared,
+    const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k) {
+
+    constexpr int ncols = ncols1 * ncols2;
+
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
+
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(V->nb[0] == ggml_element_size(V));
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+
+    ggml_sycl_pool & pool = ctx.pool();
+    dpct::queue_ptr  main_stream = ctx.stream();
+    const int id  = ggml_sycl_get_device();
+    const int nsm = ggml_sycl_info().devices[id].nsm;
+
+    ggml_sycl_pool_alloc<sycl::half>   K_f16(pool);
+    ggml_sycl_pool_alloc<sycl::half>   V_f16(pool);
+    ggml_sycl_pool_alloc<int>          KV_max(pool);
+    ggml_sycl_pool_alloc<float>        dst_tmp(pool);
+    ggml_sycl_pool_alloc<sycl::float2> dst_tmp_meta(pool);
+
+    const char * K_data = (const char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    const char * V_data = (const char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        K_f16.alloc(ggml_nelements(K));
+        if (ggml_is_contiguously_allocated(K)) {
+            to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(K->type, dst);
+            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+            nb11 = nb11 * bs * sizeof(sycl::half) / ts;
+            nb12 = nb12 * bs * sizeof(sycl::half) / ts;
+            nb13 = nb13 * bs * sizeof(sycl::half) / ts;
+        } else {
+            GGML_ASSERT(K->nb[0] == ts);
+            to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(K->type);
+            const int64_t s01 = nb11 / ts;
+            const int64_t s02 = nb12 / ts;
+            const int64_t s03 = nb13 / ts;
+            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+            nb11 = K->ne[0] * sizeof(sycl::half);
+            nb12 = K->ne[1] * nb11;
+            nb13 = K->ne[2] * nb12;
+        }
+        K_data = (char *) K_f16.ptr;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        if (V_is_K_view) {
+            V_data = K_data;
+            nb21   = nb11;
+            nb22   = nb12;
+            nb23   = nb13;
+        } else {
+            const size_t bs = ggml_blck_size(V->type);
+            const size_t ts = ggml_type_size(V->type);
+
+            V_f16.alloc(ggml_nelements(V));
+            if (ggml_is_contiguously_allocated(V)) {
+                to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(V->type, dst);
+                to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+                V_data = (char *) V_f16.ptr;
+
+                nb21 = nb21 * bs * sizeof(sycl::half) / ts;
+                nb22 = nb22 * bs * sizeof(sycl::half) / ts;
+                nb23 = nb23 * bs * sizeof(sycl::half) / ts;
+            } else {
+                GGML_ASSERT(V->nb[0] == ts);
+                to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(V->type);
+                const int64_t s01 = nb21 / ts;
+                const int64_t s02 = nb22 / ts;
+                const int64_t s03 = nb23 / ts;
+                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+                nb21 = V->ne[0] * sizeof(sycl::half);
+                nb22 = V->ne[1] * nb21;
+                nb23 = V->ne[2] * nb22;
+            }
+            V_data = (char *) V_f16.ptr;
+        }
+    }
+
+    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int gqa_ratio    = Q->ne[2] / K->ne[2];
+    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
+    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
+
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    //     multiple sequences of possibly different lengths.
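+    // The resulting KV_max buffer holds one entry per (Q column tile, sequence) pair which the kernels
+    // can use to stop their iteration over the KV cache early.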
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(sycl::half2);
+        const int s33 = mask->nb[3] / sizeof(sycl::half2);
+
+        const dpct::dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dpct::dim3 block_dim_KV_max(FATTN_KQ_STRIDE / 2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        {
+            dpct::has_capability_or_fail(main_stream->get_device(), { sycl::aspect::fp16 });
+
+            main_stream->submit([&](sycl::handler & cgh) {
+                sycl::local_accessor buf_iw_acc_ct1(sycl::range<1>(warp_size), cgh);
+
+                auto mask_data_ct0  = (const sycl::half2 *) mask->data;
+                auto KV_max_ptr_ct1 = KV_max.ptr;
+
+                cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),
+                                 [=](sycl::nd_item<3> item_ct1) {
+                                     GGML_UNUSED(item_ct1);
+                                     flash_attn_mask_to_KV_max(
+                                         mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,
+                                         buf_iw_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
+                                 });
+            });
+        }
+        SYCL_CHECK(0);
+    }
+
+    const dpct::dim3 block_dim(warp_size, nwarps, 1);
+
+    // Max. number of active blocks limited by occupancy.
+    int max_blocks_per_sm = ggml_sycl_info().devices[id].max_wg_per_cu;
+    int parallel_blocks = max_blocks_per_sm;
+    dpct::dim3 blocks_num;
+    if (stream_k) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int max_blocks = max_blocks_per_sm*nsm;
+        const int nblocks_stream_k = max_blocks;
+        const bool use_stream_k = true;
+
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
+        }
+    } else {
+        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
+
+        // parallel_blocks must not be larger than what the tensor size allows:
+        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+        // TODO: fix the hard-coded change
+        // parallel_blocks = ntiles_KQ;
+
+        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+        // Test whether parallel_blocks can be set to a higher value for better efficiency.
+        const int blocks_per_wave = nsm * max_blocks_per_sm;
+        int nwaves_best = 0;
+        int efficiency_percent_best = 0;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_total * parallel_blocks_test;
+            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+            if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
+                break;
+            }
+
+            if (efficiency_percent > efficiency_percent_best) {
+                nwaves_best = nwaves;
+                efficiency_percent_best = efficiency_percent;
+                parallel_blocks = parallel_blocks_test;
+            }
+        }
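+        // Illustrative (made-up) numbers: with nsm = 32 and max_blocks_per_sm = 2, blocks_per_wave = 64.
+        // For ntiles_total = 24 and ntiles_KQ = 8, parallel_blocks = 2 launches 48 blocks in 1 wave
+        // (75% efficiency) whereas parallel_blocks = 8 launches 192 blocks in 3 full waves (100%),
+        // so the search above settles on the larger value.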
+
+        blocks_num.x = ntiles_x;
+        blocks_num.y = parallel_blocks;
+        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    // TODO other tensor dimensions after removal of WMMA kernel:
+    const sycl::uint3 ne01 = init_fastdiv_values(Q->ne[1]);
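+    // ne01 is passed as a fastdiv/fastmodulo triple so the kernels can wrap Q column indices
+    // without an integer division per access (see fastmodulo() in the tile kernel).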
+
+    GGML_ASSERT(block_dim.x % warp_size == 0);
+
+    launch_kernel(
+        blocks_num, block_dim, main_stream, (unsigned int) nbytes_shared, (const char *) Q->data, K_data, V_data,
+        mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr,
+        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, (sycl::float2 *)dst_tmp_meta.ptr, scale, max_bias, m0, m1,
+        n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0],
+        K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0,
+        mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->nb[3] : 0);
+    SYCL_CHECK(0);
+
+    if (stream_k) {
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dpct::dim3 block_dim_combine(DV, 1, 1);
+            const dpct::dim3 blocks_num_combine = { blocks_num.x, ncols1, ncols2 };
+
+            main_stream->submit([&](sycl::handler & cgh) {
+                auto KQV_data_ct0         = (float *) KQV->data;
+                auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
+                auto Q_ne_ct2             = Q->ne[1];
+                auto Q_ne_ct3             = Q->ne[2];
+                auto Q_ne_ct4             = Q->ne[3];
+                auto K_ne_ct5             = K->ne[1];
+                auto K_ne_ct6             = K->ne[2];
+
+                cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
+                                 [=](sycl::nd_item<3> item_ct1) {
+                                     GGML_UNUSED(item_ct1);
+                                     flash_attn_stream_k_fixup(KQV_data_ct0, dst_tmp_meta_ptr_ct1,
+                                                                                   Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,
+                                                                                   K_ne_ct5, K_ne_ct6, nbatch_fa);
+                                 });
+            });
+        }
+    } else if (parallel_blocks > 1) {
+        const dpct::dim3 block_dim_combine(DV, 1, 1);
+        const dpct::dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
+        const size_t     nbytes_shared_combine = parallel_blocks * sizeof(sycl::float2);
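+        // One float2 (running max, row sum) per partial block is staged in local memory by the kernel
+        // before the final reduction.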
+        main_stream->submit([&](sycl::handler & cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(nbytes_shared_combine), cgh);
+
+            auto dst_tmp_ptr_ct0      = dst_tmp.ptr;
+            auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
+            auto KQV_data_ct2         = (float *) KQV->data;
+
+            cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 GGML_UNUSED(item_ct1);
+                                 flash_attn_combine_results<DV>(
+                                     dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,
+                                     dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
+                             });
+        });
+    }
+    SYCL_CHECK(0);
+}
diff --git a/ggml/src/ggml-sycl/fattn-tile.cpp b/ggml/src/ggml-sycl/fattn-tile.cpp
new file mode 100644
index 00000000..9d4f019c
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn-tile.cpp
@@ -0,0 +1,55 @@
+#include 
+#include 
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+#include "fattn-tile.hpp"
+#include 
+#include 
+namespace syclex = sycl::ext::oneapi::experimental;
+
+void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
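+    // All supported head sizes use DKQ == DV except 576/512, where K and V have different head sizes.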
+    switch (K->ne[0]) {
+        case  40: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 40,  40>(ctx, dst);
+        } break;
+        case  64: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 64,  64>(ctx, dst);
+        } break;
+        case  72: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 72,  72>(ctx, dst);
+        } break;
+        case  80: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 80,  80>(ctx, dst);
+        } break;
+        case  96: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 96,  96>(ctx, dst);
+        } break;
+        case 112: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst);
+        } break;
+        case 128: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst);
+        } break;
+        case 256: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
+        } break;
+        case 576: {
+            GGML_ASSERT(V->ne[0] == 512);
+            ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
+        } break;
+        default: {
+            GGML_ABORT("Unsupported head size");
+        } break;
+    }
+}
diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp
new file mode 100644
index 00000000..29fd0f8c
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn-tile.hpp
@@ -0,0 +1,1338 @@
+#include 
+#include 
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+
+#include 
+#include 
+
+namespace syclex = sycl::ext::oneapi::experimental;
+
+#define GGML_SYCL_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
+    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                          \
+        static_assert((nthreads)          <= 512, "bad nthreads");                                    \
+        static_assert((occupancy)         <=   8, "bad occupancy");                                   \
+        static_assert((nbatch_fa)         <= 256, "bad nbatch_fa");                                   \
+        static_assert((nbatch_K)          <= 256, "bad nbatch_K");                                    \
+        return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23);    \
+    }                                                                                                 \
+
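+// The four tuning parameters are packed into one uint32_t:
+//   bits 0-9 nthreads, bits 10-13 occupancy, bits 14-22 nbatch_fa, bits 23-31 nbatch_K.
+// For example, (nthreads = 256, occupancy = 2, nbatch_fa = 64, nbatch_K = 64) packs to
+//   256 | (2 << 10) | (64 << 14) | (64 << 23) == 0x20100900,
+// which the ggml_sycl_fattn_tile_get_* helpers below unpack with the matching shifts and masks.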
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, const int DV, const int ncols) {
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  64,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  64,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  64,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  64,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  64,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  64,  72)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  64,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  64,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  64,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  64,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  64,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  64,  48)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  64,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  64,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  64,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  64,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  64,  56)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
+
+    return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, const int DV, const int ncols) {
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 3,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 3,  32, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 3,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3,  32, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  32, 256)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)
+
+    return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 3,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 2,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2, 256, 2, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2,  64,  32)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)
+
+    return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 8,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4,  64, 8,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 5, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 5, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 8,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 8,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 8,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6,  32, 256)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6,  32, 256)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)
+
+    return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+    if(fast_fp16_available(cc))
+        return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
+    else
+        return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols);
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
+#ifdef SYCL_FAST_FP16
+    return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
+#else
+    return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols);
+#endif // SYCL_FAST_FP16
+}
+
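+// Each getter below has two flavors: a runtime one that picks the fp16 or fp32 table based on the
+// device (cc) and a constexpr one that picks it from the SYCL_FAST_FP16 compile-time flag.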
+static int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);
+}
+
+template 
+static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV,
+                                                      sycl::half2 * const __restrict__ tile_KV,
+                                                      const int stride_KV,
+                                                      const int i_sup) {
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
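+    // Each unroll step n uses stride_j = warp_size >> n lanes per tile row; step n handles the column
+    // range that the previous (larger-stride) step could not cover evenly, so all of J gets loaded.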
+    auto load = [&] (const int n) {
+        auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+        const int stride_j = warp_size >> n;
+
+        if (stride_j == 0) {
+            return;
+        }
+
+        const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
+        const int j0_stop  =                             ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
+        const int stride_i = warp_size / stride_j;
+
+        if (j0_start == j0_stop) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+            const int i = i0 + item_ct1.get_local_id(1) * stride_i +
+                          (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j);
+
+            if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+                    const int j = j0 * cpy_ne + (stride_j == warp_size ? item_ct1.get_local_id(2) :
+                                                                         item_ct1.get_local_id(2) % stride_j) *
+                                                    cpy_ne;
+
+                    const __dpct_align__(16) sycl::half2 zero[cpy_ne] = {
+                        { 0.0f, 0.0f }
+                    };
+                    ggml_sycl_memcpy_1(
+                        tile_KV + i*(J/2 + J_padding) + j,
+                        !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+                }
+            }
+        }
+    };
+    // 1: max 64*16=1024 bytes, 512 half
+    // 2: max 32*16=512 bytes, 256 half
+    // 3: max 16*16=256 bytes, 128 half
+    // 4: max  8*16=128 bytes,  64 half
+    // 5: max  4*16= 64 bytes,  32 half
+    // 6: max  2*16= 32 bytes,  16 half
+    // 7: max  1*16= 16 bytes,   8 half
+    static_assert(J % 8 == 0, "bad J");
+    static_assert((J/2) % cpy_ne == 0, "bad J");
+    ggml_sycl_unroll<7>{}(load);
+}
+
+template 
+static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV,
+                                                      float * const __restrict__ tile_KV,
+                                                      const int stride_KV,
+                                                      const int i_sup) {
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    auto load = [&] (const int n) {
+        auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+        const int stride_j = warp_size >> n;
+
+        if (stride_j == 0) {
+            return;
+        }
+
+        const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
+        const int j0_stop  =                             (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
+        const int stride_i = warp_size / stride_j;
+
+        if (j0_start == j0_stop) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+            const int i = i0 + item_ct1.get_local_id(1) * stride_i +
+                          (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j);
+
+            if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+                    const int j = j0 * (cpy_ne / 2) + (stride_j == warp_size ? item_ct1.get_local_id(2) :
+                                                                               item_ct1.get_local_id(2) % stride_j) *
+                                                          (cpy_ne / 2);
+
+                    const sycl::half2 zero[cpy_ne / 2] = {
+                        { 0.0f, 0.0f }
+                    };
+                    __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne / 2];
+                    ggml_sycl_memcpy_1(
+                        tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+
+                    __dpct_align__(16) sycl::float2 tmp_f2[cpy_ne / 2];
+#pragma unroll
+                    for (int l = 0; l < cpy_ne/2; ++l) {
+                        tmp_f2[l] = tmp_h2[l].template convert<float>();
+                    }
+                    ggml_sycl_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
+                }
+            }
+        }
+    };
+    // 1: max 32*16=512 bytes, 128 float
+    // 2: max 16*16=256 bytes,  64 float
+    // 3: max  8*16=128 bytes,  32 float
+    // 4: max  4*16= 64 bytes,  16 float
+    // 5: max  2*16= 32 bytes,   8 float
+    static_assert(J % 8 == 0, "bad J");
+    static_assert(J % cpy_ne == 0, "bad J");
+    ggml_sycl_unroll<5>{}(load);
+}
+
+// Function that performs a single iteration of the inner loop for the KQ matrix multiplication:
+template 
+static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
+                                                    const sycl::half2 * const __restrict__ K_h2,
+                                                    T_vec_dot * const KV_tmp,
+                                                    const int         stride_K2,
+                                                    const int         k_VKQ_0,
+                                                    const int         k_VKQ_sup,
+                                                    const int         k_KQ_0,
+                                                    float *           KQ_acc) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    constexpr int cpy_nb   = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int ncols = ncols1*ncols2;
+    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+    flash_attn_tile_load_tile
+        (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
+    item_ct1.barrier();
+
+#ifdef SYCL_FAST_FP16
+    static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
+        __dpct_align__(16) sycl::half2 K_k[nbatch_fa / (np * warp_size)][cpy_ne];
+        __dpct_align__(16) sycl::half2 Q_k[cpw][cpy_ne];
+#else
+    static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
+        __dpct_align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __dpct_align__(16) float Q_k[cpw][cpy_ne];
+#endif // SYCL_FAST_FP16
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+            const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+#ifdef SYCL_FAST_FP16
+            ggml_sycl_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
+#else
+            ggml_sycl_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K   + cpy_ne) + k_KQ_1]);
+#endif // SYCL_FAST_FP16
+        }
+#pragma unroll
+        for (int jc0 = 0; jc0 < cpw; ++jc0) {
+            const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+#ifdef SYCL_FAST_FP16
+            ggml_sycl_memcpy_1(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
+#else
+            ggml_sycl_memcpy_1(&Q_k[jc0], &Q_tmp[jc* DKQ    + k_KQ_0   + k_KQ_1]);
+#endif // SYCL_FAST_FP16
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+#pragma unroll
+            for (int jc0 = 0; jc0 < cpw; ++jc0) {
+#pragma unroll
+                for (int k = 0; k < cpy_ne; ++k) {
+                    ggml_sycl_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
+                }
+            }
+        }
+    }
+
+    if (k_KQ_0 + nbatch_K < DKQ) {
+        item_ct1.barrier();  // Sync not needed on last iteration.
+    }
+}
+
+// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
+template 
+/*
+The total declared local variable size in device function flash_attn_tile_iter exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure.
+*/
+static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
+                                                 const sycl::half2 * const __restrict__ K_h2,
+                                                 const sycl::half2 * const __restrict__ V_h2,
+                                                 const sycl::half * const __restrict__ mask,
+                                                 const sycl::uint3 ne01,
+                                                 const float       logit_softcap,
+                                                 const float       slope,
+                                                 T_KQ * const      KQ,
+                                                 T_vec_dot * const KV_tmp,
+                                                 const int         stride_K2,
+                                                 const int         stride_V2,
+                                                 const int         stride_mask,
+                                                 float * const     KQ_max,
+                                                 float * const     KQ_sum,
+                                                 T_acc * const     VKQ,
+                                                 const int         k_VKQ_0,
+                                                 const int         k_VKQ_max,
+                                                 const int         col_Q_0,
+                                                 float *           KQ_max_new_shared) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    constexpr int cpy_nb   = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int ncols = ncols1*ncols2;
+    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+    constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
+
+#ifdef SYCL_FAST_FP16
+    constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
+#else
+    constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
+#endif // SYCL_FAST_FP16
+    static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
+    const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data
+
+    float KQ_max_new[cpw];
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        KQ_max_new[jc0] = KQ_max[jc0];
+    }
+
+    float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.
+
+    // KQ = K @ Q matrix multiplication:
+    constexpr int nbatch_K_last = DKQ % nbatch_K;
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
+        flash_attn_tile_iter_KQ(
+            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+    }
+    if (nbatch_K_last > 0) {
+        constexpr int k_KQ_0 = DKQ - nbatch_K_last;
+        flash_attn_tile_iter_KQ(
+            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+    }
+
+    // Apply logit softcap + mask, update KQ_max:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int j = fastmodulo(col_Q_0 + (jc0 + (item_ct1.get_local_id(1) / np) * cpw) / ncols2, ne01);
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+            const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+            // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+            // Therefore, scale down the Q values and apply the inverse scale to the FP32 KQ values afterwards.
+            KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;
+#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+
+            if (use_logit_softcap) {
+                KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] =
+                    logit_softcap * sycl::tanh((float) KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0]);
+            }
+
+            if (!oob_check || i_KQ < k_VKQ_sup) {
+                KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] +=
+                    (ncols2 > 1 || mask) ? slope * sycl::vec<sycl::half, 1>(mask[j * stride_mask + k_VKQ_0 + i_KQ])
+                                                       .convert<float>()[0] :
+                                           0.0f;
+
+                KQ_max_new[jc0] =
+                    sycl::fmax((float) KQ_max_new[jc0],
+                               (float) (KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] + FATTN_KQ_MAX_OFFSET));
+            }
+        }
+
+        KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]);
+    }
+
+    if constexpr (np == 1) {
+        item_ct1.barrier();
+    } else {
+        static_assert(cpw == 1, "bad cpw");
+
+        if (item_ct1.get_local_id(2) == 0) {
+            KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
+        }
+        item_ct1.barrier();
+        KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
+        KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]);
+    }
+
+    // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
+#ifdef SYCL_FAST_FP16
+        __dpct_align__(16) sycl::half tmp[nbatch_fa / (np * warp_size)][KQ_cs];
+#else
+        __dpct_align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+#endif // SYCL_FAST_FP16
+
+#pragma unroll
+        for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
+            const int jc = jc0 + jc1;
+
+            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc] - KQ_max_new[jc]));
+            KQ_max[jc] = KQ_max_new[jc];
+
+            float KQ_sum_add = 0.0f;
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+                const float val =
+                    !oob_check || i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2) <
+                                      static_cast(k_VKQ_sup) ?
+                        sycl::native::exp((float) (KQ_acc[(i0 / (np * warp_size)) * cpw + jc] - KQ_max[jc])) :
+                        0.0f;
+                KQ_sum_add += val;
+                tmp[i0/(np*warp_size)][jc1] = val;
+            }
+            KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
+
+#ifdef SYCL_FAST_FP16
+            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale_h2.x();
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale_h2.y();
+            }
+#else
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale;
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale;
+            }
+#endif // SYCL_FAST_FP16
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+            const int i = i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+            ggml_sycl_memcpy_1(
+                KQ + (jc0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs)) * (nbatch_fa * KQ_cs) + i * KQ_cs,
+                tmp[i0 / (np * warp_size)]);
+        }
+    }
+
+    // VKQ = V @ KQ matrix multiplication:
+    static_assert(DV <= DKQ, "bad DV");
+    static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
+    constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
+    static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
+    static_assert(nbatch_V % np == 0, "bad nbatch_V");
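+    // E.g. for DV = 128, nbatch_K = 64, nbatch_fa = 64 this gives nbatch_V = 32: 32 full V rows take up
+    // the same SRAM as the nbatch_fa x nbatch_K K fragment used in the KQ step.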
+#pragma unroll
+    for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
+        flash_attn_tile_load_tile
+            (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
+        item_ct1.barrier();
+
+#ifdef SYCL_FAST_FP16
+#pragma unroll
+        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+            __dpct_align__(16) sycl::half2 V_k[(DVp / 2) / warp_size];
+            __dpct_align__(16) sycl::half2 KQ_k[cpw];
+
+            constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                ggml_sycl_memcpy_1(&V_k[i0 / warp_size],
+                                                 &KV_tmp[(k1 + item_ct1.get_local_id(1) % np) * (DV / 2) + i0 +
+                                                         item_ct1.get_local_id(2) * cpy_ne_D]);
+            }
+#pragma unroll
+            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+                const int jc_KQ = jc_VKQ_0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs);
+
+                __dpct_align__(16) sycl::half tmp[KQ_cs];
+                ggml_sycl_memcpy_1(
+                    &tmp, KQ + jc_KQ * (nbatch_fa * KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np) * KQ_cs);
+#pragma unroll
+                for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
+                    KQ_k[jc_VKQ_0 + jc_VKQ_1] = sycl::half2(tmp[jc_VKQ_1]);
+                }
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() +=
+                        V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0].x();
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() +=
+                        V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0].y();
+                }
+            }
+        }
+#else
+#pragma unroll
+        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+            __dpct_align__(16) sycl::float2 V_k[(DVp/2)/warp_size];
+            __dpct_align__(16) float  KQ_k[cpw];
+
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                ggml_sycl_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + item_ct1.get_local_id(1) % np)*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D]);
+            }
+#pragma unroll
+            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+                const int jc_KQ = jc_VKQ_0/KQ_cs + (item_ct1.get_local_id(1) / np)*(cpw/KQ_cs);
+
+                ggml_sycl_memcpy_1(
+                    &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np)*KQ_cs);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() += V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0];
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() += V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0];
+                }
+            }
+        }
+#endif // SYCL_FAST_FP16
+        item_ct1.barrier();
+    }
+}
+
+template   // D == head size
+/*
+The total declared local variable size in device function flash_attn_tile exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure.
+*/
+static void flash_attn_tile(const char *  Q,
+                            const char *  K,
+                            const char *  V,
+                            const char *  mask,
+                            const char *  sinks,
+                            const int *  KV_max,
+                            float *  dst,
+                            sycl::float2 *  dst_meta,
+                            const float          scale,
+                            const float          max_bias,
+                            const float          m0,
+                            const float          m1,
+                            const uint32_t       n_head_log2,
+                            const float          logit_softcap,
+                            const int32_t        ne00,
+                            const sycl::uint3    ne01,
+                            const int32_t        ne02,
+                            const int32_t        ne03,
+                            const int32_t        nb01,
+                            const int32_t        nb02,
+                            const int32_t        nb03,
+                            const int32_t        ne10,
+                            const int32_t        ne11,
+                            const int32_t        ne12,
+                            const int32_t        ne13,
+                            const int32_t        nb11,
+                            const int32_t        nb12,
+                            const int64_t        nb13,
+                            const int32_t        nb21,
+                            const int32_t        nb22,
+                            const int64_t        nb23,
+                            const int32_t        ne31,
+                            const int32_t        ne32,
+                            const int32_t        ne33,
+                            const int32_t        nb31,
+                            const int32_t        nb32,
+                            const int64_t        nb33) {
+#ifdef SYCL_FLASH_ATTN
+    // Skip unused kernel variants for faster compilation:
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    if ((use_logit_softcap && !(DV == 128 || DV == 256))) {
+        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+            max_bias, m0, m1, n_head_log2, logit_softcap,
+            ne00, ne01, ne02, ne03,
+                  nb01, nb02, nb03,
+            ne10, ne11, ne12, ne13,
+                  nb11, nb12, nb13,
+                  nb21, nb22, nb23,
+                  ne31, ne32, ne33,
+                  nb31, nb32, nb33);
+        return;
+    }
+
+    static_assert(ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
+
+    constexpr int ncols     = ncols1*ncols2;
+
+    constexpr int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
+    constexpr int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
+    constexpr int nbatch_K  = ggml_sycl_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int col_Q_0 = item_ct1.get_group(2) * ncols1;  // Index of the first Q column for this SYCL block to work on.
+
+    const int           sequence  = item_ct1.get_group(0) / (ne02 / ncols2);
+    const int           head0     = item_ct1.get_group(0) * ncols2 - sequence * ne02;  // == item_ct1.get_group(0) % (ne02/ncols2)
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0);
+    const sycl::half2 * K_h2      = (const sycl::half2 *) (K + nb13 * sequence + nb12 * (head0 / gqa_ratio));
+    const sycl::half2 * V_h2 =
+        (const sycl::half2 *) (V + nb23 * sequence + nb22 * (head0 / gqa_ratio));  // K and V have same shape
+
+    const sycl::half * maskh = mask ? (const sycl::half *) (mask + nb33 * (sequence % ne33)) : nullptr;
+
+    const int stride_K2   = nb11 / sizeof(sycl::half2);
+    const int stride_V2   = nb21 / sizeof(sycl::half2);
+    const int stride_mask = nb31 / sizeof(sycl::half);
+
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
+    constexpr int np  = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
+
+    static_assert(cpw == 1 || np == 1, "bad cpw / np");
+    static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");
+
+    constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
+    constexpr int DVp  = (DV  + 2*warp_size - 1) & ~(2*warp_size - 1); // DV  padded to multiple of 2*warp_size.
+
+    // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.
+    // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
+    //     KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
+    // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
+    // VKQ == Accumulators in registers for the final VKQ result.
+
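+    // All of these shared buffers are carved out of a single work_group_static allocation (lsm) below;
+    // the layout and element types differ between the FP16 (SYCL_FAST_FP16) and FP32 paths.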
+
+#ifdef SYCL_FAST_FP16
+    constexpr size_t lsm_size1 = ncols * DKQ/2 ;
+    constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV ;
+    constexpr size_t lsm_size3 = ncols * nbatch_fa;
+    constexpr size_t lsm_size4 = nwarps;
+
+    constexpr size_t local_share_mem_size = lsm_size1 * sizeof(sycl::half2) +
+                                            lsm_size2 * sizeof(sycl::half2) +
+                                            lsm_size3 * sizeof(sycl::half) +
+                                            lsm_size4 * sizeof(float);
+
+    syclex::work_group_static lsm;
+
+    sycl::half2 *Q_tmp = (sycl::half2 *)&lsm;
+    sycl::half2 *KV_tmp = (sycl::half2*)(Q_tmp +lsm_size1);
+    sycl::half *KQ = (sycl::half *)(KV_tmp+lsm_size2);
+    float *KQ_max_new_shared = (float *)(KQ+lsm_size3);
+
+    __dpct_align__(16) sycl::half2 VKQ[cpw * ((DVp / 2) / warp_size)] = {
+        { 0.0f, 0.0f }
+    };
+#else
+    constexpr size_t lsm_size1 = ncols * DKQ ;
+    constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV;
+    constexpr size_t lsm_size3 = ncols * nbatch_fa;
+    constexpr size_t lsm_size4 = nwarps;
+
+    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 +lsm_size3 + lsm_size4) * sizeof(float);
+
+    syclex::work_group_static lsm;
+
+    float *Q_tmp = (float *)&lsm;
+    float *KV_tmp = Q_tmp +lsm_size1;
+    float *KQ = KV_tmp+lsm_size2;
+    float *KQ_max_new_shared = KQ+lsm_size3;
+
+    __dpct_align__(16) sycl::float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+
+
+#endif // SYCL_FAST_FP16
+
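+    // Running per-column maxima and softmax denominators for the online (streaming) softmax over the KV cache: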
+    float KQ_max[cpw] = {};
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
+    }
+    float KQ_sum[cpw] = {0.0f};
+
+    // Load Q data and convert it to FP16 if fast FP16 math is available (SYCL_FAST_FP16):
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+        const int j = jc / ncols2;
+        const int c = jc % ncols2;
+
+        constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;
+
+#pragma unroll
+        for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
+            if (i0 + np * warp_size * cpy_ne_D <= DKQ ||
+                i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + item_ct1.get_local_id(2) * cpy_ne_D <
+                    DKQ) {
+                __dpct_align__(16) float tmp_f[cpy_ne_D] = { 0.0f };
+                ggml_sycl_memcpy_1(
+                    tmp_f, &Q_f[c * (nb02 / sizeof(float)) + fastmodulo(col_Q_0 + j, ne01) * (nb01 / sizeof(float)) +
+                                i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) +
+                                item_ct1.get_local_id(2) * cpy_ne_D]);
+
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    tmp_f[i1] *= scale;
+                }
+
+#ifdef SYCL_FAST_FP16
+                __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne_D / 2];
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
+                    tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+                    // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+                    // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.
+                    tmp_h2[i1 / 2] *= sycl::half2(0.25f, 0.25f);
+#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+                }
+                ggml_sycl_memcpy_1(
+                    &Q_tmp[jc * (DKQ / 2) + i0 / 2 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D / 2) +
+                           item_ct1.get_local_id(2) * (cpy_ne_D / 2)],
+                    tmp_h2);
+#else
+                ggml_sycl_memcpy_1(
+                    &Q_tmp[jc* DKQ    + i0   + (item_ct1.get_local_id(1) % np)*(warp_size*cpy_ne_D)   + item_ct1.get_local_id(2)* cpy_ne_D],
+                    tmp_f);
+#endif // SYCL_FAST_FP16
+            }
+        }
+    }
+
+    item_ct1.barrier();
+
+    // Main loop over KV cache:
+    const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
+    if (ncols2 == 1) {
+        // Branch with out-of-bounds checks.
+        int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa;
+        while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
+            constexpr bool oob_check = false;
+            flash_attn_tile_iter(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+                                            KQ_max_new_shared);
+            k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa;
+        }
+        if (k_VKQ_0 < k_VKQ_max) {
+            constexpr bool oob_check = true;
+            flash_attn_tile_iter(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+                                            KQ_max_new_shared);
+        }
+    } else {
+        // Branch without out-of-bounds checks.
+        for (int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa; k_VKQ_0 < k_VKQ_max;
+             k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa) {
+
+            constexpr bool oob_check = false;
+            flash_attn_tile_iter(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+                                            KQ_max_new_shared);
+        }
+    }
+
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        KQ_sum[jc0] = warp_reduce_sum(KQ_sum[jc0]);
+    }
+
+    if constexpr (np > 1) {
+        static_assert(cpw == 1, "bad cpw");
+        static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");
+
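+        // With np > 1, several warps worked on the same Q column; reuse KV_tmp/Q_tmp as scratch so the
+        // first warp of each group can accumulate the partial VKQ results and KQ sums of the others.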
+#ifdef SYCL_FAST_FP16
+        sycl::half2 * VKQ_combine = (sycl::half2 *) KV_tmp;
+#else
+        float * VKQ_combine    = (float *) KV_tmp;
+#endif // SYCL_FAST_FP16
+
+        float * KQ_sum_combine = (float *) Q_tmp;
+
+        if (item_ct1.get_local_id(1) % np != 0) {
+
+#ifdef SYCL_FAST_FP16
+            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                ggml_sycl_memcpy_1(
+                    &VKQ_combine[item_ct1.get_local_id(1) * (DVp / 2) + i0 + item_ct1.get_local_id(2) * cpy_ne_D],
+                    &VKQ[i0 / warp_size]);
+            }
+#else
+
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                ggml_sycl_memcpy_1(
+                    &VKQ_combine[item_ct1.get_local_id(1)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
+            }
+#endif // SYCL_FAST_FP16
+
+            if (item_ct1.get_local_id(2) == 0) {
+                KQ_sum_combine[item_ct1.get_local_id(1)] = KQ_sum[0];
+            }
+            return;
+        }
+
+        item_ct1.barrier();
+
+#pragma unroll
+        for (int ip = 1; ip < np; ++ip) {
+#ifdef SYCL_FAST_FP16
+            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                __dpct_align__(16) sycl::half2 tmp[cpy_ne_D];
+                ggml_sycl_memcpy_1(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip) * (DVp / 2) + i0 +
+                                                                   item_ct1.get_local_id(2) * cpy_ne_D]);
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    VKQ[i0/warp_size + i1] += tmp[i1];
+                }
+            }
+#else
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                __dpct_align__(16) float tmp[cpy_ne_D];
+                ggml_sycl_memcpy_1(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D]);
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
+                }
+            }
+#endif // SYCL_FAST_FP16
+
+            KQ_sum[0] += KQ_sum_combine[item_ct1.get_local_id(1) + ip];
+        }
+    }
+
+    // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
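+    // The sink acts as one extra logit per head: it can raise the softmax maximum and always contributes
+    // exp(sink - max) to the denominator, so the existing accumulators are rescaled accordingly.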
+    if (sinks && item_ct1.get_group(1) == 0) {
+#pragma unroll
+        for (int jc0 = 0; jc0 < cpw; ++jc0) {
+            const int   jc   = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+            const float sink = ((const float *) sinks)[head0 + jc % ncols2];
+
+            float       KQ_max_new_j = sycl::fmax((float) KQ_max[jc0], sink);
+            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc0] - KQ_max_new_j));
+            KQ_max[jc0] = KQ_max_new_j;
+
+            const float val = sycl::native::exp((float) (sink - KQ_max[jc0]));
+            KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;
+
+#ifdef SYCL_FAST_FP16
+            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale;
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale;
+            }
+#endif // SYCL_FAST_FP16
+        }
+    }
+
+    // Write back results:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+        const int j = jc / ncols2;
+        const int c = jc % ncols2;
+
+        if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z())) {
+            return;
+        }
+
+        const float scale = item_ct1.get_group_range(1) == 1 ? 1.0f / KQ_sum[jc0] : 1.0f;
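+        // With a single parallel block along the KV dimension the result can be normalized here; otherwise the
+        // per-block KQ max and sum written to dst_meta below are used to combine the partial results afterwards.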
+
+        const int j_dst_unrolled =
+            ((sequence * int(ne01.z()) + col_Q_0 + j) * ne02 + head0 + c) * item_ct1.get_group_range(1) +
+            item_ct1.get_group(1);
+
+#ifdef SYCL_FAST_FP16
+        constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+            __dpct_align__(16) sycl::float2 tmp[cpy_ne_D];
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                tmp[i1] = VKQ[jc0 * ((DVp / 2) / warp_size) + i0 / warp_size + i1]
+                              .template convert();
+                tmp[i1].x() *= scale;
+                tmp[i1].y() *= scale;
+            }
+            if (i0 + warp_size * cpy_ne_D <= DV / 2 || i0 + item_ct1.get_local_id(2) * cpy_ne_D < DV / 2) {
+                ggml_sycl_memcpy_1(
+                    &dst[j_dst_unrolled * DV + 2 * i0 + item_ct1.get_local_id(2) * (2 * cpy_ne_D)], tmp);
+            }
+        }
+#else
+        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+            if (i0 + warp_size*cpy_ne_D <= DV || i0 + item_ct1.get_local_id(2)*cpy_ne_D < DV) {
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x() *= scale;
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y() *= scale;
+                }
+                ggml_sycl_memcpy_1(
+                    &dst[j_dst_unrolled*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D],
+                    &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
+            }
+        }
+#endif // SYCL_FAST_FP16
+
+        if (item_ct1.get_group_range(1) != 1 && item_ct1.get_local_id(2) == 0) {
+            dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
+        }
+    }
+#else
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+              nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+              nb11, nb12, nb13,
+              nb21, nb22, nb23,
+              ne31, ne32, ne33,
+              nb31, nb32, nb33);
+#endif // SYCL_FLASH_ATTN
+}
+
+template 
+static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    const int id        = ggml_sycl_get_device();
+    const int cc        = ggml_sycl_info().devices[id].cc;
+    const int warp_size = WARP_32_SIZE;  // WARP_16_SIZE is not supported by this kernel
+
+    constexpr size_t nbytes_shared = 0;
+
+    if constexpr (DV <= 256) {
+        if (Q->ne[1] > 16/ncols2) {
+            constexpr int cols_per_block = 32;
+            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            launch_fattn, warp_size>
+                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+            return;
+        }
+    }
+
+    if (Q->ne[1] > 8/ncols2) {
+        constexpr int cols_per_block = 16;
+        const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+        const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+        launch_fattn, warp_size>
+            (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+        return;
+    }
+
+    if constexpr (ncols2 <= 8) {
+        if (Q->ne[1] > 4/ncols2) {
+            constexpr int cols_per_block = 8;
+            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            launch_fattn, warp_size>
+                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+            return;
+        }
+    }
+
+    if constexpr (ncols2 <= 4) {
+        if (Q->ne[1] > 2/ncols2) {
+            constexpr int cols_per_block = 4;
+            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            launch_fattn, warp_size>
+                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+            return;
+        }
+    }
+
+    if constexpr (ncols2 <= 2) {
+        constexpr int cols_per_block = 2;
+        const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+        const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+        launch_fattn, warp_size>
+            (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+        return;
+    }
+
+    GGML_ABORT("fatal error");
+}
+
+template 
+static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * mask = dst->src[3];
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+    // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
+    // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
+    //const bool nvidia = GGML_SYCL_CC_IS_NVIDIA(ggml_sycl_info().devices[ggml_sycl_get_device()].cc);
+    const int gqa_limit = gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
+    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
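+    // When the GQA optimization applies, the cascade below selects the packing factor from the largest
+    // power of two that divides gqa_ratio, so that several Q heads can share one K/V load.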
+
+    if constexpr (DV == 512) {
+        if (use_gqa_opt && gqa_ratio % 16 == 0) {
+            launch_fattn_tile_switch_ncols1(ctx, dst);
+            return;
+        }
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            launch_fattn_tile_switch_ncols1(ctx, dst);
+            return;
+        }
+    }
+
+    if constexpr (DV <= 256) {
+        if (use_gqa_opt && gqa_ratio % 8 == 0) {
+            launch_fattn_tile_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            launch_fattn_tile_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 2 == 0) {
+            launch_fattn_tile_switch_ncols1(ctx, dst);
+            return;
+        }
+
+        launch_fattn_tile_switch_ncols1(ctx, dst);
+        return;
+    }
+    GGML_ABORT("fatal error");
+}
+
+template 
+void ggml_sycl_flash_attn_ext_tile_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_switch_ncols2(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_switch_ncols2(ctx, dst);
+    }
+}
+
+void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#define DECL_FATTN_TILE_CASE(DKQ, DV)                             \
+    template void ggml_sycl_flash_attn_ext_tile_case              \
+    (ggml_backend_sycl_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_TILE_CASE( 40,  40);
+extern DECL_FATTN_TILE_CASE( 64,  64);
+extern DECL_FATTN_TILE_CASE( 72,  72);
+extern DECL_FATTN_TILE_CASE( 80,  80);
+extern DECL_FATTN_TILE_CASE( 96,  96);
+extern DECL_FATTN_TILE_CASE(112, 112);
+extern DECL_FATTN_TILE_CASE(128, 128);
+extern DECL_FATTN_TILE_CASE(256, 256);
+extern DECL_FATTN_TILE_CASE(576, 512);
+
diff --git a/ggml/src/ggml-sycl/fattn-vec.hpp b/ggml/src/ggml-sycl/fattn-vec.hpp
new file mode 100644
index 00000000..48c38905
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn-vec.hpp
@@ -0,0 +1,667 @@
+#ifndef GGML_SYCL_FATTN_VEC_HPP
+#define GGML_SYCL_FATTN_VEC_HPP
+
+#include 
+#include 
+#include 
+#include 
+
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "ggml.h"
+#include "fattn-common.hpp"
+#include 
+#include 
+
+namespace syclex = sycl::ext::oneapi::experimental;
+
+static int ggml_sycl_fattn_vec_get_nthreads_host(const int cc) {
+    return 128;
+    GGML_UNUSED(cc);
+}
+
+static constexpr int ggml_sycl_fattn_vec_get_nthreads_device() {
+    return 128;
+}
+
+// Currently LLVM with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+
+template   // D == head size
+static void flash_attn_ext_vec(const char* __restrict__ Q,
+                        const char* __restrict__ K,
+                        const char* __restrict__ V,
+                        const char* __restrict__ mask,
+                        const char* __restrict__ sinks,
+                        const int* __restrict__ KV_max,
+                        float* __restrict__ dst,
+                        sycl::float2* __restrict__ dst_meta,
+                        const float scale,
+                        const float max_bias,
+                        const float m0,
+                        const float m1,
+                        const uint32_t n_head_log2,
+                        const float logit_softcap,
+                        const int32_t ne00,
+                        const sycl::uint3 ne01,
+                        const int32_t ne02,
+                        const int32_t ne03,
+                        const int32_t nb01,
+                        const int32_t nb02,
+                        const int32_t nb03,
+                        const int32_t ne10,
+                        const int32_t ne11,
+                        const int32_t ne12,
+                        const int32_t ne13,
+                        const int32_t nb11,
+                        const int32_t nb12,
+                        const int64_t nb13,
+                        const int32_t nb21,
+                        const int32_t nb22,
+                        const int64_t nb23,
+                        const int32_t ne31,
+                        const int32_t ne32,
+                        const int32_t ne33,
+                        const int32_t nb31,
+                        const int32_t nb32,
+                        const int64_t nb33) {
+#ifdef SYCL_FLASH_ATTN
+    // Skip unused kernel variants for faster compilation:
+
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+            max_bias, m0, m1, n_head_log2, logit_softcap,
+            ne00, ne01, ne02, ne03,
+                  nb01, nb02, nb03,
+            ne10, ne11, ne12, ne13,
+                  nb11, nb12, nb13,
+                  nb21, nb22, nb23,
+                  ne31, ne32, ne33,
+                  nb31, nb32, nb33);
+        return;
+    }
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int nthreads_KQ_q = (D/4 < warp_size ? D/4 : warp_size);
+    constexpr int nthreads_V_q  = (D/4 < warp_size ? D/4 : warp_size);
+
+    constexpr int nthreads    = ggml_sycl_fattn_vec_get_nthreads_device();
+    constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
+    constexpr int nthreads_V  = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
+
+    static_assert(warp_size % nthreads_KQ == 0, "bad nthreads_K");
+    static_assert(warp_size % nthreads_V  == 0, "bad nthreads_V");
+
+    constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
+    constexpr int V_cols_per_iter   = warp_size / nthreads_V;
+
+    constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ();
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
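+    // For quantized K, Q is quantized to q8_1 below so that the integer vec_dot_KQ path can be used;
+    // for F16 K, Q is kept in float/half registers.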
+#ifdef GGML_SYCL_F16
+    constexpr dequantize_V_t dequantize_V = get_dequantize_V();
+#else
+    constexpr dequantize_V_t dequantize_V = get_dequantize_V();
+#endif // GGML_SYCL_F16
+
+    const int ic0 = item_ct1.get_group(2) * ncols;  // Index of the Q/QKV column to work on.
+
+    const int sequence  = item_ct1.get_group(0) / ne02;
+    const int head      = item_ct1.get_group(0) - sequence * ne02;
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
+
+    const sycl::half * maskh = (const sycl::half *) (mask + nb33 * (sequence % ne33) + nb31 * ic0);
+
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
+
+    static_assert(D % (2*warp_size) == 0, "D not divisible by 2*warp_size == 64.");
+    constexpr int nwarps = nthreads / warp_size;
+    const int     tid    = warp_size * item_ct1.get_local_id(1) + item_ct1.get_local_id(2);
+    __builtin_assume(tid < nthreads);
+
+    constexpr int ne_KQ      = ncols*D;
+    constexpr int ne_combine = nwarps*V_cols_per_iter*D;
+
+    constexpr size_t lsm_size1 = ncols * warp_size;
+    constexpr size_t lsm_size2 = ncols * warp_size;
+#ifdef GGML_SYCL_F16
+    sycl::half2 VKQ[ncols][(D / 2) / nthreads_V] = { { { 0.0f, 0.0f } } };
+    constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
+    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2)*sizeof(float) + lsm_size3*sizeof(sycl::half);
+
+    syclex::work_group_static lsm;
+
+    float *KQ_max_shared = (float *)&lsm;
+    float *KQ_sum_shared = KQ_max_shared+lsm_size1;
+    sycl::half* KQ = (sycl::half*)(KQ_sum_shared + lsm_size2);
+
+
+#else
+    sycl::float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
+
+    constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
+    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3)*sizeof(float);
+
+
+    syclex::work_group_static lsm;
+    float *KQ_max_shared = (float *)&lsm;
+    float *KQ_sum_shared = KQ_max_shared+lsm_size1;
+    float* KQ = KQ_sum_shared + lsm_size2;
+
+#endif // GGML_SYCL_F16
+
+    float KQ_max[ncols];
+    float KQ_sum[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ_max[j] = -FLT_MAX/2.0f;
+        KQ_sum[j] = 0.0f;
+    }
+
+    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
+#ifdef GGML_SYCL_F16
+    sycl::half2 Q_reg[ncols][(D / 2) / nthreads_KQ] = {{{0.0f, 0.0f}}};  // Will be initialized completely.
+#else
+    sycl::float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+#endif // GGML_SYCL_F16
+    int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
+    sycl::float2 Q_ds[ncols][1 > D / (sizeof(int) * nthreads_KQ) ? 1 : D / (sizeof(int) * nthreads_KQ)];
+    if constexpr (Q_q8_1) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + item_ct1.get_local_id(1);
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+
+            // Reuse KQ as temporary storage for converting Q to q8_1:
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            sycl::float2 * tmp_q_ds  = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
+
+            // Set memory to zero if out of bounds:
+            if (ncols > 1 && ic0 + j >= int(ne01.z())) {
+#pragma unroll
+                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += warp_size) {
+                    const int i = i0 + item_ct1.get_local_id(2);
+
+                    if (i0 + warp_size <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
+                        tmp_q_i32[i] = 0;
+                    }
+                }
+                if (item_ct1.get_local_id(2) < D/QK8_1) {
+                    tmp_q_ds[item_ct1.get_local_id(2)] = sycl::float2(0.0f, 0.0f);
+                }
+            } else {
+                const float * Q_f = (const float *) (Q + j*nb01);
+                constexpr int nthreads_quantize = D/sizeof(int) < warp_size ? D/sizeof(int) : warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
+                    quantize_q8_1_to_shared
+                        (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
+                }
+            }
+        }
+
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            sycl::float2 * tmp_q_ds  = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
+
+#pragma unroll
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
+                const int i =
+                    i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ);
+
+                Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
+                Q_ds[j][i0/nthreads_KQ]  = tmp_q_ds[i/QI8_1];
+            }
+        }
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    } else {
+#ifdef GGML_SYCL_F16
+        const sycl::half2 scale_h2 = sycl::half2(scale, scale);
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j * nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
+                const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) :
+                                                               item_ct1.get_local_id(2) % nthreads_KQ) *
+                                       cpy_ne;
+
+                sycl::float2 tmp[cpy_ne] = {
+                    { 0.0f, 0.0f }
+                };
+                if (ncols == 1 || ic0 + j < int(ne01.z())) {
+                    ggml_sycl_memcpy_1(tmp,            &Q_j[i]);
+                    ggml_sycl_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
+                }
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne; ++i1) {
+                    Q_reg[j][i0 / nthreads_KQ + i1] = sycl::half2(tmp[i1].x(), tmp[i1].y());
+                }
+            }
+#pragma unroll
+            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
+                Q_reg[j][k] *= scale_h2;
+            }
+        }
+#else
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
+                const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ)*cpy_ne;
+                if (ncols == 1 || ic0 + j < int(ne01.z())) {
+                    ggml_sycl_memcpy_1(&Q_reg[j][i0/nthreads_KQ],            &Q_j[i]);
+                    ggml_sycl_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
+                }
+            }
+#pragma unroll
+            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
+                Q_reg[j][k].x() *= scale;
+                Q_reg[j][k].y() *= scale;
+            }
+        }
+#endif // GGML_SYCL_F16
+    }
+
+    const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
+    K += item_ct1.get_group(1) * nthreads * nb11;
+    V += item_ct1.get_group(1) * nthreads * nb21;
+    maskh += item_ct1.get_group(1) * nthreads;
+    for (int k_VKQ_0 = item_ct1.get_group(1) * nthreads; k_VKQ_0 < k_VKQ_max;
+         k_VKQ_0 += item_ct1.get_group_range(1) * nthreads,
+             // Increment pointers after each loop:
+         K += item_ct1.get_group_range(1) * nthreads * nb11, V += item_ct1.get_group_range(1) * nthreads * nb21,
+             maskh += item_ct1.get_group_range(1) * nthreads) {
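+        // K, V and maskh now point at the current KV slice; all indexing below is relative to it.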
+        // Calculate KQ tile and keep track of new maximum KQ values:
+        float KQ_reg[ncols]={}; // KQ in registers.
+        float KQ_max_new[ncols]={};
+
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            KQ_max_new[j] = KQ_max[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
+            const int i_KQ = item_ct1.get_local_id(1) * warp_size +
+                             (nthreads_KQ == warp_size ? 0 : (item_ct1.get_local_id(2) & ~(nthreads_KQ - 1))) + i_KQ_0;
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum(sum);
+
+                if (use_logit_softcap) {
+                    sum = logit_softcap * sycl::tanh(sum);
+                }
+                if (mask) {
+                    sum += slope * sycl::vec(maskh[j * ne11 + i_KQ])
+                                       .convert()[0];
+                }
+
+                KQ_max_new[j] = sycl::fmax((float) KQ_max_new[j], sum);
+
+                if (int(nthreads_KQ == warp_size ? item_ct1.get_local_id(2)
+                                                 : item_ct1.get_local_id(2) %
+                                                       nthreads_KQ) == i_KQ_0) {
+                  KQ_reg[j] = sum;
+                }
+            }
+        }
+
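+        // Reduce the new maxima across the sub-group, rescale the running sum and the VKQ accumulators,
+        // and stash the exponentiated logits in shared memory for the V*KQ step: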
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+            for (int offset = nthreads_KQ; offset < warp_size; offset <<= 1) {
+               KQ_max_new[j] = sycl::fmax(
+                  (float)KQ_max_new[j],
+                  (float)dpct::permute_sub_group_by_xor(
+                      sycl::ext::oneapi::this_work_item::get_sub_group(),
+                      KQ_max_new[j],
+                      offset,
+                      warp_size));
+            }
+            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - KQ_max_new[j]));
+            KQ_max[j] = KQ_max_new[j];
+
+            KQ_reg[j]            = sycl::native::exp((float) (KQ_reg[j] - KQ_max[j]));
+            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
+            KQ[j*nthreads + tid] = KQ_reg[j];
+
+#ifdef GGML_SYCL_F16
+            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+                VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
+                VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
+            }
+#endif // GGML_SYCL_F16
+        }
+
+        sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_sub_group());
+
+#pragma unroll
+        for (int k0 = 0; k0 < warp_size; k0 += V_cols_per_iter) {
+            const int k = item_ct1.get_local_id(1) * warp_size + k0 +
+                          (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V);
+
+#ifdef GGML_SYCL_F16
+            sycl::half2 KQ_k[ncols];
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                KQ_k[j] = sycl::half2(KQ[j * nthreads + k]);
+            }
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                sycl::half2 tmp[V_rows_per_thread / 2];
+                dequantize_V(V + k * nb21, tmp,
+                             2 * i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) :
+                                                                      item_ct1.get_local_id(2) % nthreads_V) *
+                                               V_rows_per_thread);
+#pragma unroll
+                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+                    for (int j = 0; j < ncols; ++j) {
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
+                    }
+                }
+            }
+#else
+            float KQ_k[ncols];
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                KQ_k[j] = KQ[j*nthreads + k];
+            }
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                sycl::float2 tmp[V_rows_per_thread/2];
+                dequantize_V(V + k*nb21, tmp,
+                    2*i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+                    for (int j = 0; j < ncols; ++j) {
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x() += tmp[i_VKQ_1].x()*KQ_k[j];
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y() += tmp[i_VKQ_1].y()*KQ_k[j];
+                    }
+                }
+            }
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    if (sinks && item_ct1.get_group(1) == 0) {
+        const float sink = ((const float *) sinks)[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + item_ct1.get_local_id(1);
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+            const float kqmax_new_j  = sycl::fmax(sink, (float) KQ_max[j]);
+            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - kqmax_new_j));
+            KQ_max[j] = kqmax_new_j;
+
+            KQ_sum[j] = KQ_sum[j] * KQ_max_scale +
+                        (item_ct1.get_local_id(2) == 0 ? sycl::native::exp((float) (sink - KQ_max[j])) : 0.0f);
+#ifdef GGML_SYCL_F16
+            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+                VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
+                VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
+            }
+#endif // GGML_SYCL_F16
+        }
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (item_ct1.get_local_id(1) == 0) {
+            KQ_max_shared[j*warp_size+item_ct1.get_local_id(2)] = -FLT_MAX / 2.0f;
+            KQ_sum_shared[j*warp_size+item_ct1.get_local_id(2)] = 0.0f;
+        }
+    }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (item_ct1.get_local_id(2) == 0) {
+            KQ_max_shared[j*warp_size+item_ct1.get_local_id(1)] = KQ_max[j];
+        }
+    }
+
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z())) {
+            break;
+        }
+
+        float kqmax_new         = KQ_max_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
+        kqmax_new = warp_reduce_max(kqmax_new);
+        const float kqmax_scale = sycl::native::exp((float) (KQ_max[j_VKQ] - kqmax_new));
+        KQ_max[j_VKQ] = kqmax_new;
+
+#ifdef GGML_SYCL_F16
+        sycl::half2 * VKQ_tmp = (sycl::half2 *) KQ + item_ct1.get_local_id(1) * (V_cols_per_iter * D / 2) +
+                                (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V) * (D / 2);
+
+        const sycl::half2 kqmax_scale_h2 = sycl::half2(kqmax_scale, kqmax_scale);
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+            VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
+        }
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+            const int i_VKQ =
+                i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V) *
+                              (V_rows_per_thread / 2);
+
+            ggml_sycl_memcpy_1(VKQ_tmp + i_VKQ,
+                                                                       &VKQ[j_VKQ][i_VKQ_0 / nthreads_V]);
+        }
+#else
+        sycl::float2 * VKQ_tmp = (sycl::float2 *) KQ + item_ct1.get_local_id(1)*(V_cols_per_iter*D/2)
+            + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V)*(D/2);
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+            VKQ[j_VKQ][i_VKQ_0/nthreads_V].x() *= kqmax_scale;
+            VKQ[j_VKQ][i_VKQ_0/nthreads_V].y() *= kqmax_scale;
+        }
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+            const int i_VKQ = i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*(V_rows_per_thread/2);
+
+            ggml_sycl_memcpy_1(VKQ_tmp + i_VKQ,                       &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
+            ggml_sycl_memcpy_1(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
+        }
+#endif // GGML_SYCL_F16
+
+        KQ_sum[j_VKQ] *= kqmax_scale;
+        KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
+        if (item_ct1.get_local_id(2) == 0) {
+            KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(1)] = KQ_sum[j_VKQ];
+        }
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+
+        if (nthreads <= D || tid < D) {
+            KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
+            KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
+
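+            // Each thread sums the partial VKQ contributions that all warps wrote into the shared KQ buffer,
+            // then normalizes by the softmax denominator (unless multiple parallel blocks are combined later).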
+#pragma unroll
+            for (int i0 = 0; i0 < D; i0 += nthreads) {
+                float dst_val = 0;
+#pragma unroll
+                for (int w = 0; w < nwarps; ++w) {
+#pragma unroll
+                    for (int v = 0; v < V_cols_per_iter; ++v) {
+                        dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
+                    }
+                }
+                if (item_ct1.get_group_range(1) == 1) {
+                    dst_val /= KQ_sum[j_VKQ];
+                }
+                dst[(((sequence * int(ne01.z()) + ic0 + j_VKQ) * ne02 + head) * item_ct1.get_group_range(1) +
+                     item_ct1.get_group(1)) *
+                        D +
+                    i0 + tid] = dst_val;
+            }
+        }
+
+        if (j_VKQ < ncols-1) {
+            item_ct1.barrier(sycl::access::fence_space::local_space);
+        }
+
+    }
+
+    if (item_ct1.get_group_range(1) != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z()))) {
+        dst_meta[((sequence * int(ne01.z()) + ic0 + tid) * ne02 + head) * item_ct1.get_group_range(1) +
+                 item_ct1.get_group(1)] = make_float2(KQ_max[tid], KQ_sum[tid]);
+    }
+#else
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+              nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+              nb11, nb12, nb13,
+              nb21, nb22, nb23,
+              ne31, ne32, ne33,
+              nb31, nb32, nb33);
+
+#endif // SYCL_FLASH_ATTN
+}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
+
+template 
+void ggml_sycl_flash_attn_ext_vec_case_impl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+
+    const int warp_size = WARP_16_SIZE;  // gives better performance than WARP_32_SIZE here
+
+    const int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
+
+    const int nthreads = ggml_sycl_fattn_vec_get_nthreads_host(cc);
+    const int nwarps   = nthreads / warp_size;
+
+    const bool need_f16_K = type_K == GGML_TYPE_F16;
+    const bool need_f16_V = type_V == GGML_TYPE_F16;
+    constexpr size_t nbytes_shared = 0;
+
+    launch_fattn, warp_size>(
+        ctx, dst, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
+}
+
+template 
+void ggml_sycl_flash_attn_ext_vec_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block = 1;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_sycl_flash_attn_ext_vec_case_impl(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_sycl_flash_attn_ext_vec_case_impl(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 2;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_sycl_flash_attn_ext_vec_case_impl(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_sycl_flash_attn_ext_vec_case_impl(ctx, dst);
+    }
+}
+
+#define DECL_FATTN_VEC_CASE(D, type_K, type_V)                              \
+    template void ggml_sycl_flash_attn_ext_vec_case                         \
+    (ggml_backend_sycl_context & ctx, ggml_tensor * dst) \
+
+#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K)             \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16);  \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
+
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
+
+#endif // GGML_SYCL_FATTN_VEC_HPP
diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp
new file mode 100644
index 00000000..c276ed89
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn.cpp
@@ -0,0 +1,225 @@
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+
+#include 
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+#include "fattn-tile.hpp"
+#include "fattn-vec.hpp"
+#include "fattn.hpp"
+
+
+#define FATTN_VEC_CASE(D, type_K, type_V)                                                                        \
+    {                                                                                                            \
+        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                     \
+            ggml_sycl_flash_attn_ext_vec_case(ctx, dst);                                      \
+            return;                                                                                              \
+        }                                                                                                        \
+    }                                                                    \
+
+#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
+    FATTN_VEC_CASE( 64, type_K, type_V)       \
+    FATTN_VEC_CASE(128, type_K, type_V)       \
+    FATTN_VEC_CASE(256, type_K, type_V)       \
+
+static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_SYCL_FA_ALL_QUANTS
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#else
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#endif // GGML_SYCL_FA_ALL_QUANTS
+
+    GGML_ABORT("Not match KV type in vec");
+}
+
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+    BEST_FATTN_KERNEL_NONE     =   0,
+    BEST_FATTN_KERNEL_VEC      = 100,
+    BEST_FATTN_KERNEL_TILE     = 200,
+};
+
+static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+    GGML_UNUSED(device);
+#ifndef SYCL_FLASH_ATTN
+    GGML_UNUSED(dst);
+    return BEST_FATTN_KERNEL_NONE;
+#endif// SYCL_FLASH_ATTN
+
+    if (!g_ggml_sycl_enable_flash_attention) {
+        return BEST_FATTN_KERNEL_NONE;
+    }
+
+    const ggml_tensor * KQV   = dst;
+    const ggml_tensor * Q     = dst->src[0];
+    const ggml_tensor * K     = dst->src[1];
+    const ggml_tensor * V     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr || ggml_is_quantized(t->type)) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                gqa_opt_applies = false;
+                break;
+            }
+        }
+    }
+
+    switch (K->ne[0]) {
+        case  40:
+        case  64:
+        case  72:
+        case  80:
+        case  96:
+        case 112:
+        case 128:
+        case 256:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case 576:
+            if (V->ne[0] != 512) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!gqa_opt_applies) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+#ifndef GGML_SYCL_FA_ALL_QUANTS
+    if (K->type != V->type) {
+        return BEST_FATTN_KERNEL_NONE;
+    }
+#endif // GGML_SYCL_FA_ALL_QUANTS
+
+    switch (K->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+#ifndef GGML_SYCL_FA_ALL_QUANTS
+            return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_SYCL_FA_ALL_QUANTS
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    if (mask && mask->ne[2] != 1) {
+        return BEST_FATTN_KERNEL_NONE;
+    }
+
+    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+    // TODO: use the XMX kernel if possible.
+
+    // Otherwise choose between the vector kernel (small batches) and the generic tile kernel:
+    if (can_use_vector_kernel) {
+        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+            if (Q->ne[1] == 1) {
+                if (!gqa_opt_applies) {
+                    return BEST_FATTN_KERNEL_VEC;
+                }
+            }
+        } else {
+            if (Q->ne[1] <= 2) {
+                return BEST_FATTN_KERNEL_VEC;
+            }
+        }
+    }
+    return BEST_FATTN_KERNEL_TILE;
+}
+
+void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_set_device(ctx.device);
+    switch (ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst)) {
+        case BEST_FATTN_KERNEL_NONE:
+            GGML_ABORT("FlashAttention is not supported for this configuration");
+        case BEST_FATTN_KERNEL_TILE:
+            ggml_sycl_flash_attn_ext_tile(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC:
+            ggml_sycl_flash_attn_ext_vec(ctx, dst);
+            break;
+    }
+}
+
+bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+    return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
+}
diff --git a/ggml/src/ggml-sycl/fattn.hpp b/ggml/src/ggml-sycl/fattn.hpp
new file mode 100644
index 00000000..f2a8ffc9
--- /dev/null
+++ b/ggml/src/ggml-sycl/fattn.hpp
@@ -0,0 +1,22 @@
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_FATTN_HPP
+#define GGML_SYCL_FATTN_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst);
+
+#endif // GGML_SYCL_FATTN_HPP
diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp
new file mode 100644
index 00000000..8c76afbd
--- /dev/null
+++ b/ggml/src/ggml-sycl/gated_delta_net.cpp
@@ -0,0 +1,309 @@
+#include 
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "ggml.h"
+#include "gated_delta_net.hpp"
+#include 
+
+
+template <int S_v, bool KDA>
+void gated_delta_net_sycl(const float *     q,
+                          const float *     k,
+                          const float *     v,
+                          const float *     g,
+                          const float *     beta,
+                          const float *     curr_state,
+                          float *           dst,
+                          int64_t           H,
+                          int64_t           n_tokens,
+                          int64_t           n_seqs,
+                          int64_t           sq1,
+                          int64_t           sq2,
+                          int64_t           sq3,
+                          int64_t           sv1,
+                          int64_t           sv2,
+                          int64_t           sv3,
+                          int64_t           sb1,
+                          int64_t           sb2,
+                          int64_t           sb3,
+                          const sycl::uint3 neqk1_magic,
+                          const sycl::uint3 rq3_magic,
+                          float             scale) {
+    auto           item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const uint32_t h_idx    = item_ct1.get_group(2);
+    const uint32_t sequence = item_ct1.get_group(1);
+    // each warp owns one column, using warp-level primitives to reduce across rows
+    const int      lane     = item_ct1.get_local_id(2);
+    const int      col      = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
+
+    const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
+    const uint32_t iq3 = fastdiv(sequence, rq3_magic);
+
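+    // dst layout: attention scores [S_v x H x n_tokens x n_seqs] followed by the updated recurrent state [S_v x S_v x H x n_seqs]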
+    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+    float *       attn_data        = dst;
+    float *       state            = dst + attn_score_elems;
+
+    const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
+    state += state_offset;
+    curr_state += state_offset;
+    attn_data += (sequence * n_tokens * H + h_idx) * S_v;
+
+    constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;
+    static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
+    constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
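+    // each lane caches its share of the current state column in registers: rows lane, lane + warp_size, ... of column col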
+    float         s_shard[rows_per_lane];
+#pragma unroll
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i = r * warp_size + lane;
+        s_shard[r]  = curr_state[i * S_v + col];
+    }
+
+    for (int t = 0; t < n_tokens; t++) {
+        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
+
+        const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;
+        const float * beta_t = beta + gb_offset;
+        const float * g_t    = g    + gb_offset * (KDA ? S_v : 1);
+
+        const float beta_val = *beta_t;
+
+        if constexpr (!KDA) {
+            const float g_val = sycl::native::exp(*g_t);
+
+            // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
+            float kv_shard = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += s_shard[r] * k_t[i];
+            }
+            float kv_col = warp_reduce_sum(kv_shard);
+
+            // delta[col] = (v[col] - g * kv[col]) * beta
+            float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
+
+            // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
+            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+            float attn_partial = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = g_val * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
+            }
+
+            float attn_col = warp_reduce_sum(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
+        } else {
+            // kv[col] = sum_i g[i] * S[i][col] * k[i]
+            float kv_shard = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i];
+            }
+
+            float kv_col = warp_reduce_sum(kv_shard);
+
+            // delta[col] = (v[col] - kv[col]) * beta
+            float delta_col = (v_t[col] - kv_col) * beta_val;
+
+            // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
+            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+            float attn_partial = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
+            }
+
+            float attn_col = warp_reduce_sum(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
+        }
+
+        attn_data += S_v * H;
+    }
+
+    // Write state back to global memory
+#pragma unroll
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i          = r * warp_size + lane;
+        state[i * S_v + col] = s_shard[r];
+    }
+}
+
+template <bool KDA>
+static void launch_gated_delta_net(const float *   q_d,
+                                   const float *   k_d,
+                                   const float *   v_d,
+                                   const float *   g_d,
+                                   const float *   b_d,
+                                   const float *   s_d,
+                                   float *         dst_d,
+                                   int64_t         S_v,
+                                   int64_t         H,
+                                   int64_t         n_tokens,
+                                   int64_t         n_seqs,
+                                   int64_t         sq1,
+                                   int64_t         sq2,
+                                   int64_t         sq3,
+                                   int64_t         sv1,
+                                   int64_t         sv2,
+                                   int64_t         sv3,
+                                   int64_t         sb1,
+                                   int64_t         sb2,
+                                   int64_t         sb3,
+                                   int64_t         neqk1,
+                                   int64_t         rq3,
+                                   float           scale,
+                                   dpct::queue_ptr stream) {
+    // TODO: add a chunked kernel for faster prefill
+    const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;
+
+    const int num_warps = 4;
+    dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
+    dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
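+    // one work-group per (head, sequence, tile of num_warps columns); each sub-group owns a single output/state column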
+
+    const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);
+    const sycl::uint3 rq3_magic   = init_fastdiv_values(rq3);
+
+    int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
+
+    switch (S_v) {
+        case 16:
+            {
+                constexpr int sv = 16;
+                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+                                                                       sb3, neqk1_magic, rq3_magic, scale);
+                                     });
+            }
+            break;
+        case 32:
+            {
+                constexpr int sv = 32;
+                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+                                                                       sb3, neqk1_magic, rq3_magic, scale);
+                                     });
+            }
+            break;
+        case 64:
+            {
+                constexpr int sv = 64;
+                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+                                                                       sb3, neqk1_magic, rq3_magic, scale);
+                                     });
+            }
+            break;
+        case 128:
+            {
+                constexpr int sv = 128;
+                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+                                                                       sb3, neqk1_magic, rq3_magic, scale);
+                                     });
+            }
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src_q     = dst->src[0];
+    ggml_tensor * src_k     = dst->src[1];
+    ggml_tensor * src_v     = dst->src[2];
+    ggml_tensor * src_g     = dst->src[3];
+    ggml_tensor * src_beta  = dst->src[4];
+    ggml_tensor * src_state = dst->src[5];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);
+
+    const int64_t S_v      = nev0;
+    const int64_t H        = nev1;
+    const int64_t n_tokens = nev2;
+    const int64_t n_seqs   = nev3;
+
+    const bool kda = (src_g->ne[0] == S_v);
+
+    GGML_ASSERT(neq1 == nek1);
+    const int64_t neqk1 = neq1;
+
+    const int64_t rq3 = nev3 / neq3;
+
+    const float * q_d = (const float *) src_q->data;
+    const float * k_d = (const float *) src_k->data;
+    const float * v_d = (const float *) src_v->data;
+    const float * g_d = (const float *) src_g->data;
+    const float * b_d = (const float *) src_beta->data;
+
+    const float * s_d   = (const float *) src_state->data;
+    float *       dst_d = (float *) dst->data;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
+    GGML_ASSERT(src_g->ne[0] == 1 || kda);
+    GGML_ASSERT(ggml_is_contiguous(src_g));
+    GGML_ASSERT(ggml_is_contiguous(src_beta));
+    GGML_ASSERT(ggml_is_contiguous(src_state));
+
+    // strides in floats (beta strides used for both g and beta offset computation)
+    const int64_t sq1 = nbq1 / sizeof(float);
+    const int64_t sq2 = nbq2 / sizeof(float);
+    const int64_t sq3 = nbq3 / sizeof(float);
+    const int64_t sv1 = nbv1 / sizeof(float);
+    const int64_t sv2 = nbv2 / sizeof(float);
+    const int64_t sv3 = nbv3 / sizeof(float);
+    const int64_t sb1 = nbb1 / sizeof(float);
+    const int64_t sb2 = nbb2 / sizeof(float);
+    const int64_t sb3 = nbb3 / sizeof(float);
+
+    const float scale = 1.0f / sqrtf((float) S_v);
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    if (kda) {
+        launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
+    } else {
+        launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
+    }
+}
+
+void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6);
+    ggml_sycl_op_gated_delta_net(ctx, dst);
+}
diff --git a/ggml/src/ggml-sycl/gated_delta_net.hpp b/ggml/src/ggml-sycl/gated_delta_net.hpp
new file mode 100644
index 00000000..a3308ee8
--- /dev/null
+++ b/ggml/src/ggml-sycl/gated_delta_net.hpp
@@ -0,0 +1,8 @@
+#pragma once
+
+#include 
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "ggml.h"
+
+void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 8f8176b6..12819705 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -35,6 +35,7 @@
 #endif
 #include 
 
+#include "ggml.h"
 #include "ggml-sycl.h"
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"
@@ -43,17 +44,18 @@
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/common.hpp"
 #include "ggml-sycl/element_wise.hpp"
+#include "ggml-sycl/gated_delta_net.hpp"
+#include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/getrows.hpp"
 #include "ggml-sycl/norm.hpp"
 #include "ggml-sycl/presets.hpp"
-#include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/quantize.hpp"
+#include "ggml-sycl/repeat_back.hpp"
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/set.hpp"
-#include "ggml-sycl/sycl_hw.hpp"
-#include "ggml-sycl/getrows.hpp"
-#include "ggml-sycl/repeat_back.hpp"
-#include "ggml-sycl/quantize.hpp"
 #include "ggml-sycl/ssm_conv.hpp"
-#include "ggml.h"
+#include "ggml-sycl/sycl_hw.hpp"
+
 
 static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
@@ -62,6 +64,8 @@ int g_ggml_sycl_disable_graph = 0;
 int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
 int g_ggml_sycl_use_async_mem_op = 0;
+int g_ggml_sycl_enable_flash_attention = 1;
+
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -94,11 +98,14 @@ static ggml_sycl_device_info ggml_sycl_init() {
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
-        info.devices[i].nsm = prop.get_max_compute_units();
+        info.devices[i].nsm = prop.get_max_compute_units() / 16; // assume 16 EUs (vector engines) per Xe core
         info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.devices[i].smpbo = prop.get_local_mem_size();
+        info.devices[i].warp_size = WARP_SIZE;
 
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
+        info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
+
     }
 
     for (int id = 0; id < info.device_count; ++id) {
@@ -211,7 +218,37 @@ static void ggml_check_sycl() try {
         g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
         g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
         g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
+
+#ifdef SYCL_FLASH_ATTN
+        g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
+#else
+        g_ggml_sycl_enable_flash_attention = 0;
+#endif
+
         GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
+
+        GGML_LOG_INFO("Build with Macros:\n");
+#if defined(GGML_SYCL_FORCE_MMQ)
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: no\n");
+#endif
+#if defined(GGML_SYCL_F16)
+        GGML_LOG_INFO("  GGML_SYCL_F16: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_F16: no\n");
+#endif
+#if defined(GGML_SYCL_GRAPH)
+        GGML_LOG_INFO("  GGML_SYCL_GRAPH: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_GRAPH: no\n");
+#endif
+#if defined(GGML_SYCL_DNNL)
+        GGML_LOG_INFO("  GGML_SYCL_DNNL: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DNNL: no\n");
+#endif
+
         GGML_LOG_INFO("Running with Environment Variables:\n");
         GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
@@ -226,16 +263,12 @@ static void ggml_check_sycl() try {
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
 #endif
         GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
-        GGML_LOG_INFO("Build with Macros:\n");
-#if defined(GGML_SYCL_FORCE_MMQ)
-        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
+
+#ifdef SYCL_FLASH_ATTN
+        GGML_LOG_INFO("  GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention);
 #else
-        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: no\n");
-#endif
-#if defined(GGML_SYCL_F16)
-        GGML_LOG_INFO("  GGML_SYCL_F16: yes\n");
-#else
-        GGML_LOG_INFO("  GGML_SYCL_F16: no\n");
+        GGML_LOG_INFO("  GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n",
+            g_ggml_sycl_enable_flash_attention);
 #endif
 
 /* NOT REMOVE, keep it for next optimize for XMX.
@@ -1157,13 +1190,28 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_
     GGML_UNUSED(buft);
 }
 
+inline void * aligned_malloc_host(size_t alignment, size_t size) {
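+    // on Windows use _aligned_malloc; elsewhere aligned_alloc, which expects size to be a multiple of alignment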
+#ifdef _WIN32
+    return _aligned_malloc(size, alignment);
+#else
+    return aligned_alloc(alignment, size);
+#endif
+}
+
+inline void free_aligned_mem_host(void * memblock) {
+#ifdef _WIN32
+    _aligned_free(memblock);
+#else
+    free(memblock);
+#endif
+}
+
 static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_sycl_host_free(buffer->context);
+    free_aligned_mem_host((void *)buffer->context);
 }
 
 static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    void * ptr = ggml_sycl_host_malloc(size);
-
+    void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size);
     if (ptr == nullptr) {
         // fallback to cpu buffer
         return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
@@ -1825,6 +1873,110 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     }
 }
 
+static void top_k_f32_sycl(
+    const float * src,
+    int32_t * dst_indices,
+    const int64_t ncols,
+    const int64_t nrows,
+    const int k,
+    dpct::queue_ptr main_stream
+) {
+    const int block_size = 128;
+
+    const sycl::range<1> block_dims(block_size);
+    const sycl::range<1> grid_dims(nrows);
+
+    main_stream->submit([&](sycl::handler &cgh) {
+        sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh);
+        sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<1>(grid_dims * block_dims, block_dims),
+            [=](sycl::nd_item<1> item_ct1) {
+                const int row = item_ct1.get_group(0);
+                const int tid = item_ct1.get_local_id(0);
+
+                if (row >= nrows) return;
+
+                const float * src_row = src + row * ncols;
+                int32_t * dst_idx_row = dst_indices + row * k;
+
+                float local_vals[32];
+                int local_idx[32];
+
+                for (int i = 0; i < k; i++) {
+                    local_vals[i] = -FLT_MAX;
+                    local_idx[i] = -1;
+                }
+
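+                // each work-item scans a strided slice of the row, keeping its local top-k sorted in descending order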
+                for (int col = tid; col < ncols; col += block_size) {
+                    float val = src_row[col];
+
+                    if (val > local_vals[k-1]) {
+                        int pos = k - 1;
+                        while (pos > 0 && val > local_vals[pos - 1]) {
+                            pos--;
+                        }
+
+                        for (int i = k - 1; i > pos; i--) {
+                            local_vals[i] = local_vals[i - 1];
+                            local_idx[i] = local_idx[i - 1];
+                        }
+                        local_vals[pos] = val;
+                        local_idx[pos] = col;
+                    }
+                }
+
+                for (int i = 0; i < k; i++) {
+                    shared_vals[tid * k + i] = local_vals[i];
+                    shared_idx[tid * k + i] = local_idx[i];
+                }
+                item_ct1.barrier(sycl::access::fence_space::local_space);
+
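+                // work-item 0 merges the per-item candidate lists from local memory into the final top-k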
+                if (tid == 0) {
+                    float final_vals[32];
+                    int final_idx[32];
+
+                    for (int i = 0; i < k; i++) {
+                        final_vals[i] = -FLT_MAX;
+                        final_idx[i] = -1;
+                    }
+
+                    for (int t = 0; t < block_size; t++) {
+                        for (int i = 0; i < k; i++) {
+                            float val = shared_vals[t * k + i];
+                            int idx = shared_idx[t * k + i];
+
+                            if (val > final_vals[k-1]) {
+                                int pos = k - 1;
+                                while (pos > 0 && val > final_vals[pos - 1]) {
+                                    pos--;
+                                }
+
+                                for (int j = k - 1; j > pos; j--) {
+                                    final_vals[j] = final_vals[j - 1];
+                                    final_idx[j] = final_idx[j - 1];
+                                }
+                                final_vals[pos] = val;
+                                final_idx[pos] = idx;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < k; i++) {
+                        dst_idx_row[i] = final_idx[i];
+                    }
+
+                    if (k > 1) {
+                        int32_t temp = dst_idx_row[0];
+                        dst_idx_row[0] = dst_idx_row[1];
+                        dst_idx_row[1] = temp;
+                    }
+                }
+            });
+    });
+}
+
 static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
                                const int nrows, queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
@@ -2048,8 +2200,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
             const sycl::half alpha_f16 = 1.0f;
             const sycl::half beta_f16  = 0.0f;
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-                *stream, oneapi::math::transpose::trans,
-                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                *stream, oneapi::mkl::transpose::trans,
+                oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
                 &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
                 src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
                 dst_f16.get(), dpct::library_data_t::real_half, ldc,
@@ -2092,8 +2244,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
         {
             const float alpha = 1.0f;
             const float beta  = 0.0f;
-            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
-                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
+                *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff,
                 src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
                 dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
         }
@@ -2216,6 +2368,30 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor *
                          main_stream, ctx.device);
 }
 
+static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(src0->data);
+    int32_t * dst_dd = static_cast<int32_t *>(dst->data);
+
+    const int k = dst->ne[0];
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    GGML_ASSERT(k > 0 && k <= 32);
+    GGML_ASSERT(k <= ncols);
+
+    top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream);
+}
+
 inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_I32);
@@ -2248,6 +2424,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
     diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
 }
 
+static void tri_f32_sycl(
+    const float * src,
+    float * dst,
+    const int64_t ne0,
+    const int64_t ne1,
+    const int64_t ne2,
+    const int64_t ne3,
+    const ggml_tri_type ttype,
+    dpct::queue_ptr main_stream
+) {
+    const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
+
+    main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
+        const int64_t idx = (int64_t) tid[0];
+
+        const int64_t i0 = idx % ne0;
+        const int64_t t1 = idx / ne0;
+        const int64_t i1 = t1 % ne1;
+
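+        // keep the element if it lies inside the requested triangle, zero it otherwise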
+        bool keep = false;
+        switch (ttype) {
+            case GGML_TRI_TYPE_LOWER:      keep = (i0 <  i1); break;
+            case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
+            case GGML_TRI_TYPE_UPPER:      keep = (i0 >  i1); break;
+            case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
+            default: keep = false; break;
+        }
+
+        dst[idx] = keep ? src[idx] : 0.0f;
+    });
+}
+
+static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    GGML_ASSERT(src0);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(src0->data);
+    float *       dst_dd  = static_cast<float *>(dst->data);
+
+    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
+
+    tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
+}
+
+
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -2810,7 +3045,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
         }
 #if GGML_SYCL_DNNL
-        // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
+        // oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl
         const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
         src1_f16_alloc.alloc(ne_src1);
         const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
@@ -2819,7 +3054,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 # else
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
-        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);
         GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
 #endif
@@ -2963,8 +3198,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             const int64_t smb = ne12 == 1 ? s13       : s12;
 
             // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
-                                                        oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans,
+                                                        oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
                                                         src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
                                                         src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
                                                         mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
@@ -2988,7 +3223,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             });
 
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-                *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
                 (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
                 (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
                 (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
@@ -3316,18 +3551,17 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
 
 
     // mmvq and mmq need the __dp4a instruction which is available for gen12+
-    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+    // Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
     use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
 #ifdef SYCL_USE_XMX
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX
 
-    // mmvq path is faster in the CUDA backend.
-    if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
-        // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
-        // is enabled takes precedence over DMMV, the current if-else implementation
-        // requires disabling DMMV if both conditions are met
-        || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
+    // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
+    // is enabled takes precedence over DMMV, the current if-else implementation
+    // requires disabling DMMV if both conditions are met
+    if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&
+                                          ggml_sycl_supports_reorder_mmvq(src0->type)))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
     }
 
@@ -3771,6 +4005,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 case GGML_UNARY_OP_EXP:
                     ggml_sycl_exp(ctx, dst);
                     break;
+                case GGML_UNARY_OP_SOFTPLUS:
+                    ggml_sycl_softplus(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_SGN:
                     ggml_sycl_sgn(ctx, dst);
                     break;
@@ -3897,6 +4134,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_TRANSPOSE:
             GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
             break;
+        case GGML_OP_TRI:
+            ggml_sycl_op_tri(ctx, dst);
+            break;
         case GGML_OP_DIAG_MASK_INF:
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
@@ -3909,6 +4149,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_sycl_rope(ctx, dst);
             break;
+        case GGML_OP_ROPE_BACK:
+            ggml_sycl_rope_back(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_sycl_im2col(ctx, dst);
             break;
@@ -3927,6 +4170,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_sycl_argsort(ctx, dst);
             break;
+        case GGML_OP_TOP_K:
+            ggml_sycl_op_top_k(ctx, dst);
+            break;
         case GGML_OP_TIMESTEP_EMBEDDING:
             ggml_sycl_op_timestep_embedding(ctx, dst);
             break;
@@ -3939,6 +4185,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_sycl_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_GATED_DELTA_NET:
+            ggml_sycl_gated_delta_net(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_sycl_ssm_conv(ctx, dst);
             break;
@@ -3948,6 +4197,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ARANGE:
             ggml_sycl_arange(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_sycl_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3978,16 +4230,6 @@ void ggml_backend_sycl_get_device_memory(int device, size_t *free,
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
     ggml_sycl_set_device(device);
 
-    /*
-    DPCT1009:218: SYCL uses exceptions to report errors and does not use the
-    error codes. The original code was commented out and a warning string was
-    inserted. You need to rewrite this code.
-    */
-    /*
-    DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
-    device information which may not be supported by all compilers or runtimes.
-    You may need to adjust the code.
-    */
     SYCL_CHECK(CHECK_TRY_ERROR(
         dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
 }
@@ -4109,6 +4351,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
             continue;
         }
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
@@ -4386,10 +4631,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_SOFTPLUS:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_CEIL:
                     return true;
                 case GGML_UNARY_OP_FLOOR:
-                case GGML_UNARY_OP_CEIL:
                 case GGML_UNARY_OP_ROUND:
                 case GGML_UNARY_OP_TRUNC:
 #if defined (GGML_SYCL_F16)
@@ -4588,18 +4834,23 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type);
 #endif
         case GGML_OP_NORM:
-            return true;
         case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
-            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
-            return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
+            return true;
         case GGML_OP_RMS_NORM_BACK:
-            return ((op->src[0]->ne[0] % WARP_SIZE) == 0);
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SCALE:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
+        case GGML_OP_TRI:
+            {
+                const ggml_tensor * src0 = op->src[0];
+                return src0 &&
+                       op->type == GGML_TYPE_F32 &&
+                       ggml_is_contiguous(src0);
+            }
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
@@ -4610,6 +4861,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return max_bias == 0.0f;
         }
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
         case GGML_OP_IM2COL:
             return true;
         case GGML_OP_UPSCALE:
@@ -4621,9 +4873,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ARGSORT:
             return op->src[0]->ne[0] * sizeof(int) <=
                    ggml_sycl_info().devices[device].smpbo;
+        case GGML_OP_TOP_K: {
+            const ggml_tensor * src0 = op->src[0];
+            const int k = op->ne[0];
+            return src0 &&
+                op->type == GGML_TYPE_I32 &&
+                src0->type == GGML_TYPE_F32 &&
+                ggml_is_contiguous(src0) &&
+                k > 0 && k <= 32;
+        }
         case GGML_OP_POOL_2D:
-        case GGML_OP_ACC:
             return true;
+        case GGML_OP_ACC:
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
         case GGML_OP_PAD:
             // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985
             if (ggml_get_op_params_i32(op, 8) != 0) {
@@ -4635,6 +4897,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_GATED_DELTA_NET:
             return true;
         case GGML_OP_SSM_CONV:
             return op->type == GGML_TYPE_F32 &&
@@ -4644,6 +4907,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return op->type == GGML_TYPE_F32;
         case GGML_OP_ARANGE:
             return op->type == GGML_TYPE_F32;
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_sycl_flash_attn_ext_supported(device, op);
         default:
             return false;
     }
diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp
index 823d3a48..09fce128 100644
--- a/ggml/src/ggml-sycl/norm.cpp
+++ b/ggml/src/ggml-sycl/norm.cpp
@@ -202,47 +202,34 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
     }
 }
 
-static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
-    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
-    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
-        item_ct1.get_local_id(1);
-    const int tid = item_ct1.get_local_id(2);
-    const int nthreads = item_ct1.get_local_range(2);
-    const int nwarps = nthreads / WARP_SIZE;
+template
+static void l2_norm_f32(const float * x, float * dst, const int ncols,
+    const int64_t stride_row, const int64_t stride_channel,
+    const int64_t stride_sample, const float eps,
+    const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
+    const int nrows     = item_ct1.get_group_range(2);
+    const int nchannels = item_ct1.get_group_range(1);
+
+    const int row     = item_ct1.get_group(2);
+    const int channel = item_ct1.get_group(1);
+    const int sample  = item_ct1.get_group(0);
+    const int tid     = item_ct1.get_local_id(2);
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row * ncols + col];
+        const float xi = x[col];
         tmp += xi * xi;
     }
 
-    // sum up partial sums
-    tmp = warp_reduce_sum(tmp, item_ct1);
-    if (block_size > WARP_SIZE) {
-
-        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
-        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        /*
-        DPCT1118:3: SYCL group functions and algorithms must be encountered in
-        converged control flow. You may need to adjust the code.
-        */
-        item_ct1.barrier(sycl::access::fence_space::local_space);
-        size_t nreduce = nwarps / WARP_SIZE;
-        tmp = 0.f;
-        for (size_t i = 0; i < nreduce; i += 1)
-        {
-            tmp += s_sum[lane_id + i * WARP_SIZE];
-        }
-        tmp = warp_reduce_sum(tmp, item_ct1);
-    }
-
-    const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
+    tmp = block_reduce(tmp, s_sum, block_size);
+    const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row * ncols + col] = scale * x[row * ncols + col];
+        dst[col] = scale * x[col];
     }
 }
 
@@ -251,7 +238,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
         const float eps, queue_ptr stream, int device) {
 
     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
         stream->submit([&](sycl::handler& cgh) {
@@ -334,7 +320,6 @@ static void group_norm_f32_sycl(const float* x, float* dst,
 
 static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
     // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
 
     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -371,43 +356,50 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
     }
 }
 
-static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
-    const int nrows, const float eps,
-    queue_ptr stream, int device) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+template <int warp_size>
+static void l2_norm_f32_sycl(const float *   x,
+                             float *         dst,
+                             const int       ncols,
+                             const int       nrows,
+                             const int       nchannels,
+                             const int       nsamples,
+                             const int64_t   stride_row,
+                             const int64_t   stride_channel,
+                             const int64_t   stride_sample,
+                             const float     eps,
+                             queue_ptr       stream,
+                             int             device) {
+    const dpct::dim3 blocks_num(nrows, nchannels, nsamples);
+
     if (ncols < 1024) {
-        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        const dpct::dim3 block_dims(warp_size, 1, 1);
         stream->submit([&](sycl::handler& cgh) {
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                sycl::nd_range<3>(blocks_num * block_dims,
                     block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    l2_norm_f32(x, dst, ncols, eps, item_ct1,
-                        nullptr, WARP_SIZE);
+                [[sycl::reqd_sub_group_size(warp_size)]] {
+                    l2_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+                        nullptr, warp_size);
                 });
             });
     }
     else {
         const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
-        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+        assert(work_group_size % (warp_size * warp_size) == 0);
         const sycl::range<3> block_dims(1, 1, work_group_size);
-        /*
-        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
+        int lsm_size = block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float) : 0;
         stream->submit([&](sycl::handler& cgh) {
-            sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(lsm_size),
                 cgh);
+
             cgh.parallel_for(
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                sycl::nd_range<3>(blocks_num * block_dims,
                     block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                    l2_norm_f32(x, dst, ncols, eps, item_ct1,
-                        get_pointer(s_sum_acc_ct1), work_group_size);
+                [[sycl::reqd_sub_group_size(warp_size)]] {
+                    l2_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample,
+                        eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                 });
             });
     }
@@ -637,21 +629,28 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d
 }
 
 void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    dpct::queue_ptr     stream = ctx.stream();
 
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-
-    const int64_t ne00 = dst->src[0]->ne[0];
-    const int64_t nrows = ggml_nrows(dst->src[0]);
-    const float * src0_dd = static_cast(dst->src[0]->data);
-    float * dst_dd = static_cast(dst->data);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
-    l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
 
+    /* both WARP_SIZE and WARP_32_SIZE sub-group sizes are supported;
+       the size is chosen per hardware for better performance */
+    l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
 }
diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp
index 3a17f3a1..f52b11f0 100644
--- a/ggml/src/ggml-sycl/outprod.cpp
+++ b/ggml/src/ggml-sycl/outprod.cpp
@@ -32,12 +32,12 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     // Handle transposition of src1
     const bool src1_T = ggml_is_transposed(src1);
-    const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
+    const oneapi::mkl::transpose src1_op = src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
     const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
 
     try {
-        // Perform matrix multiplication using oneMath GEMM
-        oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op,
+        // Perform matrix multiplication using oneMKL GEMM
+        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op,
                                                ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
     }
     catch (sycl::exception const& exc) {
diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp
index b6517374..dc4dad1d 100644
--- a/ggml/src/ggml-sycl/presets.hpp
+++ b/ggml/src/ggml-sycl/presets.hpp
@@ -73,4 +73,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
 #define QK_WARP_SIZE 32
+#define WARP_32_SIZE 32
+#define WARP_16_SIZE 16
+
 #endif // GGML_SYCL_PRESETS_HPP
diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp
index d0d5ac9a..14490fea 100644
--- a/ggml/src/ggml-sycl/quants.hpp
+++ b/ggml/src/ggml-sycl/quants.hpp
@@ -29,7 +29,7 @@ namespace ggml_sycl_reordered {
 // [qs0, qs1, qs2, ..., qsN]  [d0, d1, d2, ..., dN]
 //
 // Notes: out-of-bounds qs will run into d values
-// Aligment relies on the allocated size of qs
+// Alignment relies on the allocated size of qs
 
 template  struct block_q_t;
 
diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp
index 69140b19..9d83a1e9 100644
--- a/ggml/src/ggml-sycl/rope.cpp
+++ b/ggml/src/ggml-sycl/rope.cpp
@@ -1,4 +1,5 @@
 #include "rope.hpp"
+#include "convert.hpp"
 #include "ggml-sycl/common.hpp"
 #include "ggml.h"
 
@@ -15,367 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
     return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
 }
 
-// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
-// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
-static void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta) {
-    // Get n-d rotational scaling corrected for extrapolation
+template <bool forward>
+static void rope_yarn(const float theta_extrap, const float freq_scale,
+                      const rope_corr_dims corr_dims, const int64_t i0,
+                      const float ext_factor, float mscale, float &cos_theta,
+                      float &sin_theta) {
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
     if (ext_factor != 0.0f) {
-        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        float ramp_mix =
+            rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 
-        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
     }
-    *cos_theta = sycl::cos(theta) * mscale;
-    *sin_theta = sycl::sin(theta) * mscale;
+    cos_theta = sycl::cos(theta) * mscale;
+    sin_theta = sycl::sin(theta) * mscale;
+    if (!forward) {
+        sin_theta *= -1.0f;
+    }
 }
 
-template 
-static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-                      const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
-                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
-                      const sycl::nd_item<3> & item_ct1) {
-    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
+template <bool forward, bool has_ff, typename T, typename D>
+static void rope_norm(const T *x, D *dst, const int ne00, const int ne01,
+                      const int ne02, const int s01, const int s02,
+                      const int s03, const int s1, const int s2, const int s3,
+                      const int n_dims, const int32_t *pos,
+                      const float freq_scale, const float ext_factor,
+                      const float attn_factor, const rope_corr_dims corr_dims,
+                      const float theta_scale, const float *freq_factors,
+                      const int64_t *row_indices, const int set_rows_stride) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
+    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                        item_ct1.get_local_id(2);
 
-    const int row0     = row % ne1;
-    const int channel0 = row / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    const int i  = row * ne0 + i0;
-    const int i2 = channel0 * s2 + row0 * s1 + i0;
+    int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
+    const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
 
+    if (set_rows_stride != 0) {
+        idst = i1 * s1 + i0;
+        idst += row_indices[i2] * set_rows_stride;
+    }
+
+    const auto &store_coaelsced = [&](float x0, float x1) {
+        if constexpr (std::is_same_v<D, float>) {
+            sycl::float2 v = sycl::float2(x0, x1);
+            ggml_sycl_memcpy_1<8>(dst + idst, &v);
+        } else if constexpr (std::is_same_v<D, sycl::half>) {
+            sycl::half2 v = sycl::half2(x0, x1);
+            ggml_sycl_memcpy_1<4>(dst + idst, &v);
+        }
+    };
     if (i0 >= n_dims) {
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
+        store_coaelsced(x[ix + 0], x[ix + 1]);
         return;
     }
 
-    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
+    const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
+                       ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i2 + 0];
-    const float x1 = x[i2 + 1];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + 1];
 
-    dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
-    dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
+    store_coaelsced(x0 * cos_theta - x1 * sin_theta,
+                    x0 * sin_theta + x1 * cos_theta);
 }
 
-template <typename T, bool has_ff>
-static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-                      const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
-                      const sycl::nd_item<3> & item_ct1) {
-    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
+template <bool forward, bool has_ff, typename T, typename D>
+static void rope_neox(const T *x, D *dst, const int ne00, const int ne01,
+                      const int ne02, const int s01, const int s02,
+                      const int s03, const int s1, const int s2, const int s3,
+                      const int n_dims, const int32_t *pos,
+                      const float freq_scale, const float ext_factor,
+                      const float attn_factor, const rope_corr_dims corr_dims,
+                      const float theta_scale, const float *freq_factors,
+                      const int64_t *row_indices, const int set_rows_stride) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));
 
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
 
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
+    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                        item_ct1.get_local_id(2);
 
-    const int row0     = row % ne1;
-    const int channel0 = row / ne1;
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 
-    const int i  = row * ne0 + i0 / 2;
-    const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
+    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
+
+    if (set_rows_stride != 0) {
+        idst = i1 * s1 + i0 / 2;
+        idst += row_indices[i2] * set_rows_stride;
+    }
 
     if (i0 >= n_dims) {
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
+        dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]);
+        dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]);
+
         return;
     }
 
-    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
+    const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
+                       ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i2 + 0];
-    const float x1 = x[i2 + n_dims / 2];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims / 2];
 
-    dst[i + 0]          = x0 * cos_theta - x1 * sin_theta;
-    dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
+    dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta);
 }
 
-template <typename T, bool has_ff>
-static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
-                        const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
-                        const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-                        const float theta_scale, const float * freq_factors, const mrope_sections sections,
-                        const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
-    // get index pos
-    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
-    if (i0 >= ne0) {
+template <bool forward, bool has_ff, typename T>
+static void rope_multi(const T *x, T *dst, const int ne00, const int ne01,
+                       const int ne02, const int s01, const int s02,
+                       const int s03, const int s1, const int s2, const int s3,
+                       const int n_dims, const int32_t *pos,
+                       const float freq_scale, const float ext_factor,
+                       const float attn_factor, const rope_corr_dims corr_dims,
+                       const float theta_scale, const float *freq_factors,
+                       const mrope_sections sections, const bool is_imrope) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));
+
+    if (i0 >= ne00) {
         return;
     }
-    const int    row_dst   = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
 
-    const int    row_x     = row_dst % ne1;
-    const int    channel_x = row_dst / ne1;
-    const int    idst      = (row_dst * ne0) + (i0 / 2);
-    const size_t ix        = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                        item_ct1.get_local_id(2);
+
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
+
+    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     if (i0 >= n_dims) {
-        *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
+        dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0];
+        dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1];
+
         return;
     }
 
-    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    const int sect_dims =
+        sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
     const int sec_w = sections.v[1] + sections.v[0];
     const int sector = (i0 / 2) % sect_dims;
 
-
     float theta_base = 0.0;
     if (is_imrope) {
-        if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
-            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
-        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
-        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
-            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
+        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+            theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+            theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
+            theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
         } else {
-            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+            theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
         }
     } else {
         if (sector < sections.v[0]) {
-            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sections.v[0] && sector < sec_w) {
-            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
-        }
-        else if (sector >= sec_w + sections.v[2]) {
-            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
+            theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector >= sections.v[0] && sector < sec_w) {
+            theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+            theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f);
+        } else if (sector >= sec_w + sections.v[2]) {
+            theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f);
         }
     }
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
-    float       cos_theta;
-    float       sin_theta;
-    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
-    const float x0 = x[ix + 0];
-    const float x1 = x[ix + n_dims/2];
 
-    // store results in dst
-    dst[idst + 0]      = x0 * cos_theta - x1 * sin_theta;
-    dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
+                       ext_factor, attn_factor, cos_theta, sin_theta);
+
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims / 2];
+
+    dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
+    dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
 }
 
+template <bool forward, bool has_ff, typename T>
+static void rope_vision(const T *x, T *dst, const int ne00, const int ne01,
+                        const int ne02, const int s01, const int s02,
+                        const int s03, const int s1, const int s2, const int s3,
+                        const int n_dims, const int32_t *pos,
+                        const float freq_scale, const float ext_factor,
+                        const float attn_factor, const rope_corr_dims corr_dims,
+                        const float theta_scale, const float *freq_factors,
+                        const mrope_sections sections) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                        item_ct1.get_local_id(1));
 
-
-template <typename T, bool has_ff>
-static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
-                        const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
-                        const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-                        const float theta_scale, const float * freq_factors, const mrope_sections sections,
-                        const sycl::nd_item<3> & item_ct1) {
-    // get index pos
-    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
-    if (i0 >= ne0) {
+    if (i0 >= ne00) {
         return;
     }
-    const int    row_dst   = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
-    const int    row_x     = row_dst % ne1;
-    const int    channel_x = row_dst / ne1;
-    const int    idst      = (row_dst * ne0) + (i0 / 2);
-    const size_t ix        = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
+
+    const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                        item_ct1.get_local_id(2);
+
+    const uint32_t i3 = row_dst / (ne01 * ne02);
+    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
+
+    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
 
     const int sect_dims = sections.v[0] + sections.v[1];
-    const int sector    = (i0 / 2) % sect_dims;
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
 
-    float theta_base = 0.0f;
+    float theta_base = 0.0;
     if (sector < sections.v[0]) {
         const int p = sector;
-        theta_base  = pos[channel_x] * sycl::pow(theta_scale, (float) p);
-    } else {
-        // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
+        theta_base = pos[i2] * dpct::pow(theta_scale, p);
+    } else if (sector >= sections.v[0] && sector < sec_w) {
         const int p = sector - sections.v[0];
-        theta_base  = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
+        theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
-    float       cos_theta;
-    float       sin_theta;
-    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0,
+                       ext_factor, attn_factor, cos_theta, sin_theta);
+
     const float x0 = x[ix + 0];
     const float x1 = x[ix + n_dims];
 
-    // store results in dst
-    dst[idst + 0]      = x0 * cos_theta - x1 * sin_theta;
+    dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
     dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
 }
 
-template <typename T>
-static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
-                           const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
-                           const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
-                           const float * freq_factors, queue_ptr stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int            num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
-    const sycl::range<3> block_nums(1, num_blocks_x, nr);
+template <bool forward, typename T, typename D>
+static void
+rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01,
+               const int ne02, const int s01, const int s02, const int s03,
+               const int s1, const int s2, const int s3, const int n_dims,
+               const int nr, const int32_t *pos, const float freq_scale,
+               const float freq_base, const float ext_factor,
+               const float attn_factor, const rope_corr_dims corr_dims,
+               const float *freq_factors, const int64_t *row_indices,
+               const int set_rows_stride, dpct::queue_ptr stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
+    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x =
+        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const dpct::dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
     if (freq_factors == nullptr) {
-        /*
-        DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                                theta_scale, freq_factors, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_norm<forward, false>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, row_indices, set_rows_stride);
+            });
     } else {
-        /*
-        DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                               theta_scale, freq_factors, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_norm<forward, true>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, row_indices, set_rows_stride);
+            });
     }
 }
 
-template <typename T>
-static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
-                           const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
-                           const float freq_base, const float ext_factor, const float attn_factor,
-                           const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int            num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
-    const sycl::range<3> block_nums(1, num_blocks_x, nr);
+template <bool forward, typename T, typename D>
+static void
+rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01,
+               const int ne02, const int s01, const int s02, const int s03,
+               const int s1, const int s2, const int s3, const int n_dims,
+               const int nr, const int32_t *pos, const float freq_scale,
+               const float freq_base, const float ext_factor,
+               const float attn_factor, const rope_corr_dims corr_dims,
+               const float *freq_factors, const int64_t *row_indices,
+               const int set_rows_stride, dpct::queue_ptr stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
+    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x =
+        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const dpct::dim3 block_nums(nr, n_blocks_x, 1);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-
     if (freq_factors == nullptr) {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                                theta_scale, freq_factors, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_neox<forward, false>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, row_indices, set_rows_stride);
+            });
     } else {
-        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
-                               theta_scale, freq_factors, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_neox<forward, true>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, row_indices, set_rows_stride);
+            });
     }
 }
 
-template <typename T>
-static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
-                             const size_t s2, const int n_dims, const int nr, const int32_t * pos,
-                             const float freq_scale, const float freq_base, const float ext_factor,
-                             const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
-                             const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const sycl::range<3>    block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int               n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
-    const sycl::range<3>    grid_dims(1, n_blocks_y, nr);
-    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
+template <bool forward, typename T>
+static void
+rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01,
+                const int ne02, const int s01, const int s02, const int s03,
+                const int s1, const int s2, const int s3, const int n_dims,
+                const int nr, const int32_t *pos, const float freq_scale,
+                const float freq_base, const float ext_factor,
+                const float attn_factor, const rope_corr_dims corr_dims,
+                const float *freq_factors, const mrope_sections sections,
+                const bool is_imrope, dpct::queue_ptr stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
+    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x =
+        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const dpct::dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
-    // Add FP16 capability check if T could be sycl::half
-    if constexpr (std::is_same_v<T, sycl::half>) {
-        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-    }
-    // launch kernel
     if (freq_factors == nullptr) {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
-            rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                  corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_multi<forward, false>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, sections, is_imrope);
+            });
     } else {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
-            rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                 corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_multi<forward, true>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, sections, is_imrope);
+            });
     }
 }
 
+template <bool forward, typename T>
+static void
+rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01,
+                 const int ne02, const int s01, const int s02, const int s03,
+                 const int s1, const int s2, const int s3, const int n_dims,
+                 const int nr, const int32_t *pos, const float freq_scale,
+                 const float freq_base, const float ext_factor,
+                 const float attn_factor, const rope_corr_dims corr_dims,
+                 const float *freq_factors, const mrope_sections sections,
+                 dpct::queue_ptr stream) {
+    GGML_ASSERT(ne00 % 2 == 0);
+    const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x =
+        (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
+    const dpct::dim3 block_nums(nr, n_blocks_x, 1);
 
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-
-// rope vision
-template <typename T>
-static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
-                             const size_t s2, const int n_dims, const int nr, const int32_t * pos,
-                             const float freq_scale, const float freq_base, const float ext_factor,
-                             const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
-                             const mrope_sections sections, queue_ptr stream) {
-    GGML_ASSERT(ne0 % 2 == 0);
-    const sycl::range<3>    block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int               n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
-    const sycl::range<3>    grid_dims(1, n_blocks_y, nr);
-    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
-
-    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
-    // Add FP16 capability check if T could be sycl::half
-    if constexpr (std::is_same_v<T, sycl::half>) {
-        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-    }
-    // launch kernel
     if (freq_factors == nullptr) {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
-            rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_vision<forward, false>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, sections);
+            });
     } else {
-        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
-            rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
-                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
-        });
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                GGML_UNUSED(item_ct1);
+                rope_vision<forward, true>(
+                    x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims,
+                    pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                    theta_scale, freq_factors, sections);
+            });
     }
 }
 
-inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+template <bool forward>
+void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst,
+                            const ggml_tensor *set_rows = nullptr) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
+    const ggml_tensor *src2 = dst->src[2];
 
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    const int64_t ne00 = dst->src[0]->ne[0]; // head dims
-    const int64_t ne01 = dst->src[0]->ne[1]; // num heads
-    const int64_t ne02 = dst->src[0]->ne[2]; // num heads
-    const int64_t nr = ggml_nrows(dst->src[0]);
+    const float *src0_d = (const float *)src0->data;
+    const float *src1_d = (const float *)src1->data;
 
-    const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
-    const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
+    void *dst_d = dst->data;
+    const int64_t *row_indices = nullptr;
+    ggml_type dst_type = dst->type;
+    int set_rows_stride = 0;
 
+    if (set_rows != nullptr) {
+        GGML_ASSERT(forward);
+        dst_d = set_rows->data;
+        row_indices = (const int64_t *)set_rows->src[1]->data;
+        dst_type = set_rows->type;
+        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+    }
+    dpct::queue_ptr stream = ctx.stream();
 
-    //const int n_past      = ((int32_t *) dst->op_params)[0];
-    const int n_dims      = ((int32_t *) dst->op_params)[1];
-    const int mode        = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig  = ((int32_t *) dst->op_params)[4];
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type ||
+                (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
+
+    const int64_t ne00 = src0->ne[0]; // head dims
+    const int64_t ne01 = src0->ne[1]; // num heads
+    const int64_t ne02 = src0->ne[2]; // num heads
+    const int64_t nr = ggml_nrows(src0);
+
+    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
+    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
+    const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
+
+    const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
+    const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
+    const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
+
+    const int n_dims = ((int32_t *)dst->op_params)[1];
+    const int mode = ((int32_t *)dst->op_params)[2];
+    const int n_ctx_orig = ((int32_t *)dst->op_params)[4];
     mrope_sections sections;
 
-    // RoPE alteration for extended context
     float freq_base;
     float freq_scale;
     float ext_factor;
@@ -383,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     float beta_fast;
     float beta_slow;
 
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);
+    memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float));
+    memcpy(&sections.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4);
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
     const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
@@ -397,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
     if (is_mrope) {
-        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 ||
+                    sections.v[2] > 0);
     }
 
     if (is_vision) {
-        GGML_ASSERT(n_dims == ne00/2);
+        GGML_ASSERT(n_dims == ne00 / 2);
     }
 
-    const int32_t * pos = (const int32_t *) dst->src[1]->data;
+    const int32_t *pos = (const int32_t *)src1_d;
 
-    const float * freq_factors = nullptr;
-    if (dst->src[2] != nullptr) {
-        freq_factors = (const float *) dst->src[2]->data;
+    const float *freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *)src2->data;
     }
 
     rope_corr_dims corr_dims;
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
-
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
+                             beta_slow, corr_dims.v);
 
     // compute
     if (is_neox) {
         GGML_SYCL_DEBUG("%s: neox path\n", __func__);
-        if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
-                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
-        } else if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
-                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
-                           main_stream);
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+            rope_neox_sycl<forward>(
+                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
+                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+            rope_neox_sycl<forward>(
+                (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
+                s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+            rope_neox_sycl<forward>(
+                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
+                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                row_indices, set_rows_stride, stream);
         } else {
-            GGML_ABORT("fatal error");
+            GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     } else if (is_mrope && !is_vision) {
         GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
-        if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
-                s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                freq_factors, sections, is_imrope, main_stream);
-        } else if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
-                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
-                             is_imrope, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d,
+                                     ne00, ne01, ne02, s01, s02, s03, s1, s2,
+                                     s3, n_dims, nr, pos, freq_scale, freq_base,
+                                     ext_factor, attn_factor, corr_dims,
+                                     freq_factors, sections, is_imrope, stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_multi_sycl<forward>(
+                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
+                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                sections, is_imrope, stream);
         } else {
             GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     } else if (is_vision) {
         GGML_SYCL_DEBUG("%s: vision path\n", __func__);
-        if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
-                             s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
-                             freq_factors, sections, main_stream);
-        } else if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
-                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
-                             main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_vision_sycl<forward>(
+                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
+                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                ext_factor, attn_factor, corr_dims, freq_factors, sections,
+                stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_vision_sycl<forward>(
+                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
+                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                sections, stream);
         } else {
             GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     } else {
         GGML_SYCL_DEBUG("%s: norm path\n", __func__);
-        if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
-                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
-        } else if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
-                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
-                           main_stream);
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+            rope_norm_sycl<forward>(
+                (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01,
+                s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+                set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+            rope_norm_sycl<forward>(
+                (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02,
+                s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+            rope_norm_sycl<forward>(
+                (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01,
+                ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                row_indices, set_rows_stride, stream);
         } else {
-            GGML_ABORT("fatal error");
+            GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     }
 }
 
-void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
-    ggml_sycl_op_rope(ctx, dst);
+
+    ggml_sycl_op_rope_impl<true>(ctx, dst);
 }
 
+void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
+    ggml_sycl_op_rope_impl<false>(ctx, dst);
+}
+
+void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope,
+                          ggml_tensor *set_rows) {
+    scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3);
+    ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows);
+}
diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp
index 8c7141aa..b95a5858 100644
--- a/ggml/src/ggml-sycl/rope.hpp
+++ b/ggml/src/ggml-sycl/rope.hpp
@@ -15,6 +15,12 @@
 
 #include "common.hpp"
 
+#define SYCL_ROPE_BLOCK_SIZE 256
+
 void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
 
+void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
+
 #endif // GGML_SYCL_ROPE_HPP
diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp
index b41124ac..fdf9b843 100644
--- a/ggml/src/ggml-sycl/softmax.cpp
+++ b/ggml/src/ggml-sycl/softmax.cpp
@@ -37,7 +37,7 @@ struct soft_max_params {
 };
 
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
-// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
+// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
 #ifdef __clang__
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wpass-failed"
@@ -102,7 +102,7 @@ static void soft_max_f32(const float *         x,
         max_val   = sycl::max(max_val, val);
     }
     // find the max value in the block
-    max_val = warp_reduce_max(max_val);
+    max_val = warp_reduce_max<WARP_SIZE>(max_val);
 
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
@@ -116,7 +116,7 @@ static void soft_max_f32(const float *         x,
         item_ct1.barrier();
 
         max_val = buf_iw[lane_id];
-        max_val = warp_reduce_max(max_val);
+        max_val = warp_reduce_max<WARP_SIZE>(max_val);
     }
     float tmp = 0.0f; // partial sum
 
@@ -133,7 +133,7 @@ static void soft_max_f32(const float *         x,
         vals[col] = val;
     }
     // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
+    tmp = warp_reduce_sum<WARP_SIZE>(tmp);
     if (block_size > WARP_SIZE) {
         item_ct1.barrier();
         if (warp_id == 0) {
@@ -153,7 +153,7 @@ static void soft_max_f32(const float *         x,
         for (size_t i = 1; i < nreduce; i += 1) {
             tmp += buf_iw[lane_id + i * WARP_SIZE];
         }
-        tmp = warp_reduce_sum(tmp);
+        tmp = warp_reduce_sum<WARP_SIZE>(tmp);
     }
     if (sinks) {
         tmp += sycl::native::exp(sinks[i02] - max_val);
@@ -191,7 +191,7 @@ static void soft_max_back_f32(const float *grad, const float *dstf, float *dst,
         dgf_dot += dstf[col]*grad[col];
     }
 
-    dgf_dot = warp_reduce_sum(dgf_dot);
+    dgf_dot = warp_reduce_sum<WARP_SIZE>(dgf_dot);
 
     for (int col = tid; col < ncols; col += WARP_SIZE) {
         dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp
new file mode 100644
index 00000000..5c06d42f
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(112, 112);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp
new file mode 100644
index 00000000..f74e1202
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(128, 128);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp
new file mode 100644
index 00000000..b574fe93
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(256, 256);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp
new file mode 100644
index 00000000..8c8fb692
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(40, 40);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp
new file mode 100644
index 00000000..f218552e
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(576, 512);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp
new file mode 100644
index 00000000..99303a53
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(64, 64);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp
new file mode 100644
index 00000000..50592768
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(72, 72);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp
new file mode 100644
index 00000000..74f1ea5e
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(80, 80);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp
new file mode 100644
index 00000000..cefb46dd
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(96, 96);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp
new file mode 100644
index 00000000..32cf4f28
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp
new file mode 100644
index 00000000..a61a1902
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp
new file mode 100644
index 00000000..63b74fb3
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp
new file mode 100644
index 00000000..46e2d985
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp
new file mode 100644
index 00000000..7aabb6ff
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp
new file mode 100644
index 00000000..148ea217
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp
new file mode 100644
index 00000000..4b169dbc
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp
new file mode 100644
index 00000000..79f530b1
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp
new file mode 100644
index 00000000..2f7db51c
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp
new file mode 100644
index 00000000..9e3bf0b1
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp
new file mode 100644
index 00000000..18081879
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp
new file mode 100644
index 00000000..1c387b0d
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp
new file mode 100644
index 00000000..f005b376
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp
new file mode 100644
index 00000000..3553b1cd
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp
new file mode 100644
index 00000000..687ec567
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp
new file mode 100644
index 00000000..2663bfe7
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp
new file mode 100644
index 00000000..641b7c7a
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp
new file mode 100644
index 00000000..3d3181d4
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp
new file mode 100644
index 00000000..85d5026a
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp
new file mode 100644
index 00000000..1e81401a
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp
new file mode 100644
index 00000000..54251473
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp
new file mode 100644
index 00000000..d418c1fb
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp
new file mode 100644
index 00000000..0f26cfab
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp
new file mode 100644
index 00000000..4fb98723
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp
new file mode 100644
index 00000000..85b79cd1
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp
new file mode 100644
index 00000000..7348323b
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp
new file mode 100644
index 00000000..f19af2aa
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp
new file mode 100644
index 00000000..d7075bac
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp
new file mode 100644
index 00000000..627f9a57
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp
new file mode 100644
index 00000000..23304eec
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp
new file mode 100644
index 00000000..95acb5d4
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp
new file mode 100644
index 00000000..5e88f4ba
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp
new file mode 100644
index 00000000..69f297fe
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp
new file mode 100644
index 00000000..455842a9
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp
new file mode 100644
index 00000000..f7ef7391
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp
new file mode 100644
index 00000000..1c633bdf
--- /dev/null
+++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp
index 43482b36..9a267d85 100644
--- a/ggml/src/ggml-sycl/vecdotq.hpp
+++ b/ggml/src/ggml-sycl/vecdotq.hpp
@@ -650,6 +650,19 @@ static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u,
     return d8_0*d8_1 * sumi;
 }
 
+template <typename T, int vdr>
+static __dpct_inline__ T vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const T & d8_0, const T & d8_1) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = ggml_sycl_dp4a(v[i], u[i], sumi);
+    }
+
+    return d8_0*d8_1 * ((T) sumi);
+}
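+
+// Usage sketch (VDR_Q8_0_Q8_1_MMVQ is assumed to be the same vec-dot ratio constant used by
+// the float overload above): the templated overload lets the caller pick the accumulation type:
+//
+//     sycl::half d = vec_dot_q8_0_q8_1_impl<sycl::half, VDR_Q8_0_Q8_1_MMVQ>(v, u, d8_0, d8_1);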
+
 template <int vdr>
 static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,
                                                     const sycl::half2 &dm8,
diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp
index c10e2f76..b56e0c24 100644
--- a/ggml/src/ggml-sycl/wkv.cpp
+++ b/ggml/src/ggml-sycl/wkv.cpp
@@ -1,7 +1,7 @@
 #include <sycl/sycl.hpp>
 #include "wkv.hpp"
 
-constexpr int WKV_BLOCK_SIZE = 64;  // Matching CUDA_WKV_BLOCK_SIZE
+constexpr int WKV_BLOCK_SIZE = 64;
 
 // Helper function for the main kernel
 template 
diff --git a/ggml/src/ggml-virtgpu/CMakeLists.txt b/ggml/src/ggml-virtgpu/CMakeLists.txt
new file mode 100644
index 00000000..e6b020be
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/CMakeLists.txt
@@ -0,0 +1,70 @@
+cmake_minimum_required(VERSION 3.19)
+cmake_policy(SET CMP0114 NEW)
+
+include(ExternalProject)
+
+message(STATUS "Including the VirtGPU/Virglrenderer API Remoting")
+
+# Download venus_hw.h from virglrenderer repository
+ExternalProject_Add(
+    venus_hw_header
+    URL https://gitlab.freedesktop.org/virgl/virglrenderer/-/raw/virglrenderer-1.2.0/src/venus_hw.h
+    DOWNLOAD_NO_EXTRACT YES
+    DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include
+    DOWNLOAD_NAME venus_hw.h
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+    INSTALL_COMMAND ""
+    LOG_DOWNLOAD ON
+)
+
+if (NOT GGML_VIRTGPU_BACKEND STREQUAL "ONLY")
+    message(STATUS "Enable the VirtGPU/Virglrenderer API Remoting frontend library")
+
+    find_package(PkgConfig REQUIRED)
+    pkg_check_modules(DRM REQUIRED libdrm)
+    if (NOT GGML_BACKEND_DL)
+      # cannot simply use USE_VIRTGPU, as in the 'else()' case the
+      # frontend isn't compiled
+      target_compile_definitions(ggml PUBLIC "GGML_USE_VIRTGPU_FRONTEND")
+    endif()
+
+    ggml_add_backend_library(ggml-virtgpu
+                             ggml-backend-buffer.cpp
+                             ggml-backend.cpp
+                             ggml-backend-device.cpp
+                             ggml-backend-reg.cpp
+                             ggml-backend-buffer-type.cpp
+                             virtgpu-apir.h
+                             virtgpu-forward.gen.h
+                             virtgpu.cpp
+                             virtgpu-shm.cpp
+                             virtgpu-utils.cpp
+                             virtgpu-forward-device.cpp
+                             virtgpu-forward-buffer-type.cpp
+                             virtgpu-forward-buffer.cpp
+                             virtgpu-forward-backend.cpp
+                             virtgpu-forward-impl.h
+                             apir_cs_ggml-rpc-front.cpp
+                             ../../include/ggml-virtgpu.h)
+
+    target_include_directories(ggml-virtgpu PUBLIC /usr/include/libdrm/)
+
+    target_link_libraries(ggml-virtgpu PUBLIC ${DRM_LIBRARIES})
+    target_include_directories(ggml-virtgpu PUBLIC ${DRM_INCLUDE_DIRS})
+    target_compile_options(ggml-virtgpu PUBLIC ${DRM_CFLAGS_OTHER})
+
+    target_include_directories(ggml-virtgpu PUBLIC ./include)
+    target_include_directories(ggml-virtgpu PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
+
+    # Ensure venus_hw.h is downloaded before building ggml-virtgpu
+    add_dependencies(ggml-virtgpu venus_hw_header)
+
+    target_compile_options(ggml-virtgpu PRIVATE -std=c++20)
+else()
+    message(STATUS "Not building the VirtGPU/Virglrenderer API Remoting frontend library")
+endif()
+
+if (NOT GGML_VIRTGPU_BACKEND STREQUAL "OFF")
+    add_subdirectory("backend")
+endif()
diff --git a/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp
new file mode 100644
index 00000000..d2e87330
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp
@@ -0,0 +1,87 @@
+#include "backend/shared/apir_cs_rpc.h"
+#include "ggml-backend-impl.h"
+#include "ggml-impl.h"
+#include "ggml-remoting.h"
+
+#include <cstdio>
+#include <cstring>
+#include <unordered_set>
+#include <vector>
+
+apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor) {
+    apir_rpc_tensor result;
+    result.id   = reinterpret_cast<uint64_t>(tensor);
+    result.type = tensor->type;
+    if (tensor->buffer) {
+        ggml_backend_buffer_t buffer = tensor->buffer;
+
+        result.buffer = BUFFER_TO_HOST_HANDLE(buffer);
+    } else {
+        result.buffer = 0;
+    }
+    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+        result.ne[i] = tensor->ne[i];
+        result.nb[i] = tensor->nb[i];
+    }
+    result.op = tensor->op;
+    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+        result.op_params[i] = tensor->op_params[i];
+    }
+    result.flags = tensor->flags;
+    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
+        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
+    }
+    result.view_src  = reinterpret_cast<uint64_t>(tensor->view_src);
+    result.view_offs = tensor->view_offs;
+    result.data      = reinterpret_cast<uint64_t>(tensor->data);
+    if (tensor->data) {
+        if (!tensor->buffer) {
+            GGML_ABORT("%s: tensor has data but no buffer", __func__);
+        }
+        // tensor->data is serialized as an offset to the buffer base address
+        result.data -= reinterpret_cast<uint64_t>(BUFFER_TO_GGML_CONTEXT(tensor->buffer)->base);
+    }
+    snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
+    return result;
+}
+
+void apir_add_tensor(ggml_tensor *                       tensor,
+                     std::vector<apir_rpc_tensor> &      tensors,
+                     std::unordered_set<ggml_tensor *> & visited) {
+    if (tensor == nullptr) {
+        return;
+    }
+    if (visited.find(tensor) != visited.end()) {
+        return;
+    }
+    visited.insert(tensor);
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        apir_add_tensor(tensor->src[i], tensors, visited);
+    }
+    apir_add_tensor(tensor->view_src, tensors, visited);
+    tensors.push_back(apir_serialize_tensor(tensor));
+}
+
+void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
+    uint32_t                          n_nodes = cgraph->n_nodes;
+    std::vector<apir_rpc_tensor>      tensors;
+    std::unordered_set<ggml_tensor *> visited;
+    for (uint32_t i = 0; i < n_nodes; i++) {
+        apir_add_tensor(cgraph->nodes[i], tensors, visited);
+    }
+    // serialization format:
+    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(apir_rpc_tensor)) |
+    uint32_t n_tensors = tensors.size();
+    int      output_size =
+        sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(apir_rpc_tensor);
+    output.resize(output_size, 0);
+    memcpy(output.data(), &n_nodes, sizeof(n_nodes));
+    for (uint32_t i = 0; i < n_nodes; i++) {
+        memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
+    }
+    uint32_t * out_ntensors = (uint32_t *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
+    *out_ntensors           = n_tensors;
+    apir_rpc_tensor * out_tensors =
+        (apir_rpc_tensor *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
+    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(apir_rpc_tensor));
+}
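+
+// Illustrative sketch of the resulting wire layout (sizes assumed for the example): a graph
+// with n_nodes = 2 that reaches n_tensors = 5 distinct tensors serializes as
+//
+//     | n_nodes = 2 | node ids (2 * 8 bytes) | n_tensors = 5 | apir_rpc_tensor[5]          |
+//        4 bytes                                 4 bytes        5 * sizeof(apir_rpc_tensor)
+//
+// i.e. output.size() == sizeof(uint32_t) + 2 * sizeof(uint64_t) + sizeof(uint32_t) + 5 * sizeof(apir_rpc_tensor).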
diff --git a/ggml/src/ggml-virtgpu/backend/CMakeLists.txt b/ggml/src/ggml-virtgpu/backend/CMakeLists.txt
new file mode 100644
index 00000000..0b49c403
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/CMakeLists.txt
@@ -0,0 +1,21 @@
+cmake_minimum_required(VERSION 3.19)
+cmake_policy(SET CMP0114 NEW)
+
+message(STATUS "Enable the VirtGPU/Virglrenderer backend library")
+
+ggml_add_backend_library(ggml-virtgpu-backend
+                         backend.cpp
+                         backend-dispatched.cpp
+                         backend-dispatched-backend.cpp
+                         backend-dispatched-device.cpp
+                         backend-dispatched-buffer.cpp
+                         backend-dispatched-buffer-type.cpp
+                         shared/api_remoting.h
+                         shared/apir_backend.h
+                         shared/apir_cs.h
+                         apir_cs_ggml-rpc-back.cpp)
+
+target_compile_options(ggml-virtgpu-backend PRIVATE -std=c++20)
+
+# Add include directory for ggml-backend-impl.h and other core headers
+target_include_directories(ggml-virtgpu-backend PRIVATE ../..)
diff --git a/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp b/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp
new file mode 100644
index 00000000..60a8a93b
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp
@@ -0,0 +1,115 @@
+#include "ggml-backend-impl.h"
+#include "ggml-impl.h"
+#include "shared/apir_cs_rpc.h"
+
+#include <cstdio>
+#include <cstring>
+#include <unordered_map>
+#include <unordered_set>
+
+std::unordered_set<ggml_backend_buffer_t> backend_buffers;
+
+void apir_track_backend_buffer(ggml_backend_buffer_t buffer) {
+    backend_buffers.insert(buffer);
+}
+
+bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) {
+    auto it = backend_buffers.find(buffer);
+    if (it == backend_buffers.end()) {
+        return false;
+    }
+
+    backend_buffers.erase(it);
+    return true;
+}
+
+std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers() {
+    return backend_buffers;
+}
+
+ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) {
+    ggml_tensor * result =
+        ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = tensor->nb[i];
+    }
+    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
+    if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) {
+        printf("WARNING: HOST BUFFER NOT FOUND | %p\n", (void *) result->buffer);
+        result->buffer = nullptr;
+    }
+
+    uint64_t tensor_data = tensor->data;
+    if (result->buffer) {
+        // require that the tensor data does not go beyond the buffer end
+        uint64_t tensor_size  = (uint64_t) ggml_nbytes(result);
+        uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
+        uint64_t buffer_size  = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
+
+        // tensor->data is serialized as an offset to the buffer base address
+        tensor_data += buffer_start;
+
+        GGML_ASSERT(tensor_data + tensor_size >= tensor_data);  // check for overflow
+        GGML_ASSERT(tensor_data >= buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size);
+    }
+
+    result->op = (ggml_op) tensor->op;
+    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+        result->op_params[i] = tensor->op_params[i];
+    }
+    result->flags = tensor->flags;
+    result->data  = reinterpret_cast<void *>(tensor_data);
+    ggml_set_name(result, tensor->name);
+    return result;
+}
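+
+// Relocation sketch: tensor->data travels as an offset and is rebased here, mirroring
+// apir_serialize_tensor() on the frontend side (the buffer-base names are illustrative):
+//
+//     frontend:  rpc.data    = (uint64_t) tensor->data - (uint64_t) guest_buffer_base;
+//     backend:   tensor_data = rpc.data + (uint64_t) ggml_backend_buffer_get_base(result->buffer);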
+
+ggml_tensor * apir_create_node(uint64_t                                                      id,
+                               ggml_context *                                                ctx,
+                               const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,
+                               std::unordered_map<uint64_t, ggml_tensor *> &                 tensor_map) {
+    if (id == 0) {
+        return nullptr;
+    }
+    if (tensor_map.find(id) != tensor_map.end()) {
+        return tensor_map[id];
+    }
+    const apir_rpc_tensor * tensor = tensor_ptrs.at(id);
+    ggml_tensor *           result = apir_deserialize_tensor(ctx, tensor);
+    if (result == nullptr) {
+        return nullptr;
+    }
+    tensor_map[id] = result;
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
+    }
+    result->view_src  = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
+    result->view_offs = tensor->view_offs;
+    return result;
+}
+
+ggml_cgraph * apir_deserialize_graph(uint32_t                n_nodes,
+                                     uint32_t                n_tensors,
+                                     const apir_rpc_tensor * tensors,
+                                     const uint64_t *        nodes) {
+    size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
+    ggml_init_params params = {
+        /*.mem_size   =*/buf_size,
+        /*.mem_buffer =*/NULL,
+        /*.no_alloc   =*/true,
+    };
+    ggml_context * ctx   = ggml_init(params);
+    ggml_cgraph *  graph = ggml_new_graph_custom(ctx, n_nodes, false);
+    graph->n_nodes       = n_nodes;
+    std::unordered_map<uint64_t, const apir_rpc_tensor *> tensor_ptrs;
+    for (uint32_t i = 0; i < n_tensors; i++) {
+        tensor_ptrs[tensors[i].id] = &tensors[i];
+    }
+    std::unordered_map<uint64_t, ggml_tensor *> tensor_map;
+    for (uint32_t i = 0; i < n_nodes; i++) {
+        int64_t id;
+        memcpy(&id, &nodes[i], sizeof(id));
+        graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map);
+    }
+
+    return graph;
+}
diff --git a/ggml/src/ggml-virtgpu/backend/backend-convert.h b/ggml/src/ggml-virtgpu/backend/backend-convert.h
new file mode 100644
index 00000000..1978d21f
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-convert.h
@@ -0,0 +1,13 @@
+#include "shared/apir_backend.h"
+
+#define BUFFER_TO_HOST_HANDLE(name) ggml_buffer_to_apir_handle(name)
+
+static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) {
+    // in the backend, the buffer handle is the buffer pointer
+    return (apir_buffer_host_handle_t) buffer;
+}
+
+static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) {
+    // in the backend, the buffer type handle is the buffer type pointer
+    return (apir_buffer_type_host_handle_t) buft;
+}
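+
+// Round-trip sketch: the handle is only interpreted on the backend side, so converting it back
+// is a plain cast (the guest treats the value as opaque), e.g.
+//
+//     apir_buffer_host_handle_t handle = ggml_buffer_to_apir_handle(buffer);
+//     ggml_backend_buffer_t     same   = (ggml_backend_buffer_t) handle;   // same == buffer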
diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp
new file mode 100644
index 00000000..03a037f1
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp
@@ -0,0 +1,102 @@
+#include "backend-dispatched.h"
+#include "backend-virgl-apir.h"
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "shared/apir_backend.h"
+
+#include 
+
+static uint32_t validate_graph_operation(size_t cgraph_size, uint32_t shmem_res_id, const char * operation) {
+    if (cgraph_size == 0) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Zero-size computation graph\n", operation);
+        return 1;
+    }
+
+    // placeholder: validate that cgraph_size fits inside the shmem region behind shmem_res_id;
+    // this needs another method in the Virgl->APIR callback interface
+    GGML_UNUSED(shmem_res_id);
+
+    return 0;  // Valid
+}
+
+uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+
+    static bool async_backend_initialized = false;
+    static bool async_backend;
+
+    if (!async_backend_initialized) {
+        ggml_backend_dev_props props;
+
+        dev->iface.get_props(dev, &props);
+        async_backend             = props.caps.async;
+        async_backend_initialized = true;
+    }
+
+    uint32_t shmem_res_id;
+    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
+
+    const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
+    if (!shmem_data) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
+        apir_decoder_set_fatal(dec);
+        return 1;
+    }
+    size_t cgraph_size;
+    apir_decode_size_t(dec, &cgraph_size);
+
+    if (validate_graph_operation(cgraph_size, shmem_res_id, __func__) != 0) {
+        apir_decoder_set_fatal(dec);
+        return 1;
+    }
+
+    apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size);
+
+    ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size);
+
+    if (!cgraph || apir_decoder_get_fatal(&secondary_dec)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to deserialize computation graph\n", __func__);
+        return 1;
+    }
+
+    if (cgraph->n_nodes < 0 || cgraph->n_leafs < 0) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid negative node/leaf count: nodes=%d leafs=%d\n", __func__,
+                       cgraph->n_nodes, cgraph->n_leafs);
+        return 1;
+    }
+
+    ggml_status status;
+#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1
+    for (int idx = 0; idx < cgraph->n_nodes; idx++) {
+        ggml_tensor * op = ggml_graph_node(cgraph, idx);
+        if (dev->iface.supports_op(dev, op)) {
+            continue;
+        }
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", __func__, idx,
+                       ggml_op_desc(op));
+
+        status = GGML_STATUS_ABORTED;
+        apir_encode_ggml_status(enc, &status);
+
+        return 0;
+    }
+#endif
+
+    // Check if backend is properly initialized
+    if (!bck) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Backend not initialized (bck is null)\n", __func__);
+
+        return 1;
+    }
+
+    status = bck->iface.graph_compute(bck, cgraph);
+
+    if (async_backend && bck->iface.synchronize) {
+        bck->iface.synchronize(bck);
+    }
+
+    apir_encode_ggml_status(enc, &status);
+
+    return 0;
+}
diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp
new file mode 100644
index 00000000..c66dbaa9
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp
@@ -0,0 +1,105 @@
+#include "backend-dispatched.h"
+#include "backend-virgl-apir.h"
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+
+#include <cstring>
+
+uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    ggml_backend_buffer_type_t buft;
+    buft = apir_decode_ggml_buffer_type(dec);
+
+    const char * string = buft->iface.get_name(buft);
+
+    const size_t string_size = strlen(string) + 1;
+    apir_encode_array_size(enc, string_size);
+    apir_encode_char_array(enc, string, string_size);
+
+    return 0;
+}
+
+uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    ggml_backend_buffer_type_t buft;
+    buft = apir_decode_ggml_buffer_type(dec);
+
+    size_t value = buft->iface.get_alignment(buft);
+    apir_encode_size_t(enc, &value);
+
+    return 0;
+}
+
+uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    ggml_backend_buffer_type_t buft;
+    buft = apir_decode_ggml_buffer_type(dec);
+
+    size_t value = SIZE_MAX;
+    if (buft->iface.get_max_size) {
+        value = buft->iface.get_max_size(buft);
+    }
+
+    apir_encode_size_t(enc, &value);
+
+    return 0;
+}
+
+/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. */
+uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+    const bool is_host = false;
+
+    apir_encode_bool_t(enc, &is_host);
+
+    return 0;
+}
+
+uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    ggml_backend_buffer_type_t buft;
+    buft = apir_decode_ggml_buffer_type(dec);
+
+    size_t size;
+    apir_decode_size_t(dec, &size);
+
+    ggml_backend_buffer_t buffer;
+
+    buffer = buft->iface.alloc_buffer(buft, size);
+
+    apir_encode_ggml_buffer(enc, buffer);
+
+    if (buffer) {
+        apir_track_backend_buffer(buffer);
+    }
+
+    return 0;
+}
+
+uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    ggml_backend_buffer_type_t buft;
+    buft = apir_decode_ggml_buffer_type(dec);
+
+    const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
+
+    // Check for decode error
+    if (op == nullptr) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to decode tensor\n", __func__);
+        apir_decoder_set_fatal(dec);
+        return 1;
+    }
+
+    size_t value;
+    if (buft->iface.get_alloc_size) {
+        value = buft->iface.get_alloc_size(buft, op);
+    } else {
+        value = ggml_nbytes(op);  // Default fallback
+    }
+
+    apir_encode_size_t(enc, &value);
+
+    return 0;
+}
diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp
new file mode 100644
index 00000000..3ade8d99
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp
@@ -0,0 +1,179 @@
+#include "backend-dispatched.h"
+#include "backend-virgl-apir.h"
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+
+#include <cstdint>
+
+static uint32_t validate_buffer_operation(size_t offset, size_t size, const char * operation) {
+    // Only check for critical integer overflow - no arbitrary size limits
+    if (offset > SIZE_MAX - size) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Integer overflow in offset+size: %zu + %zu\n", operation, offset, size);
+        return 1;
+    }
+
+    return 0;  // Valid
+}
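+
+// Worked example of the guard above: with size = 16 and offset = SIZE_MAX - 8,
+// offset > SIZE_MAX - size holds, so offset + size would wrap around and the request is
+// rejected before any pointer arithmetic uses it.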
+
+uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    ggml_backend_buffer_t buffer;
+    buffer = apir_decode_ggml_buffer(dec);
+
+    if (!buffer || apir_decoder_get_fatal(dec)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
+        return 1;
+    }
+
+    uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer);
+    apir_encode_uintptr_t(enc, &base);
+
+    return 0;
+}
+
+uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(enc);
+
+    ggml_backend_buffer_t buffer;
+    buffer = apir_decode_ggml_buffer(dec);
+
+    if (!buffer || apir_decoder_get_fatal(dec)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
+        return 1;
+    }
+
+    ggml_tensor * tensor;
+    // safe to remove the const qualifier here
+    tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
+
+    uint32_t shmem_res_id;
+    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
+
+    size_t offset;
+    apir_decode_size_t(dec, &offset);
+
+    size_t size;
+    apir_decode_size_t(dec, &size);
+
+    if (validate_buffer_operation(offset, size, __func__) != 0) {
+        return 1;
+    }
+
+    void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
+
+    if (!shmem_data) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
+        return 1;
+    }
+
+    buffer->iface.set_tensor(buffer, tensor, shmem_data, offset, size);
+
+    return 0;
+}
+
+uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(enc);
+
+    ggml_backend_buffer_t buffer;
+    buffer = apir_decode_ggml_buffer(dec);
+
+    if (!buffer || apir_decoder_get_fatal(dec)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
+        return 1;
+    }
+
+    const ggml_tensor * tensor;
+    // safe to remove the const qualifier here
+    tensor = apir_decode_ggml_tensor(dec);
+
+    uint32_t shmem_res_id;
+    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
+
+    size_t offset;
+    apir_decode_size_t(dec, &offset);
+
+    size_t size;
+    apir_decode_size_t(dec, &size);
+
+    if (validate_buffer_operation(offset, size, __func__) != 0) {
+        return 1;
+    }
+
+    void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
+    if (!shmem_data) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
+        return 1;
+    }
+
+    buffer->iface.get_tensor(buffer, tensor, shmem_data, offset, size);
+
+    return 0;
+}
+
+uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+
+    ggml_backend_buffer_t buffer;
+    buffer = apir_decode_ggml_buffer(dec);
+
+    if (!buffer || apir_decoder_get_fatal(dec)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
+        return 1;
+    }
+
+    const ggml_tensor * src;
+    // safe to remove the const qualifier here
+    src               = apir_decode_ggml_tensor(dec);
+    ggml_tensor * dst = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec);
+
+    bool ret = buffer->iface.cpy_tensor(buffer, src, (ggml_tensor *) dst);
+
+    apir_encode_bool_t(enc, &ret);
+
+    return 0;
+}
+
+uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(enc);
+
+    ggml_backend_buffer_t buffer;
+    buffer = apir_decode_ggml_buffer(dec);
+
+    if (!buffer || apir_decoder_get_fatal(dec)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
+        return 1;
+    }
+
+    uint8_t value;
+    apir_decode_uint8_t(dec, &value);
+
+    buffer->iface.clear(buffer, value);
+
+    return 0;
+}
+
+uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(enc);
+
+    ggml_backend_buffer_t buffer;
+    buffer = apir_decode_ggml_buffer(dec);
+
+    if (!buffer || apir_decoder_get_fatal(dec)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__);
+        return 1;
+    }
+
+    if (!apir_untrack_backend_buffer(buffer)) {
+        GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer);
+        return 1;
+    }
+
+    buffer->iface.free_buffer(buffer);
+
+    return 0;
+}
diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp
new file mode 100644
index 00000000..c7acb8b5
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp
@@ -0,0 +1,148 @@
+#include "backend-dispatched.h"
+#include "backend-virgl-apir.h"
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+
+#include <cstring>
+
+uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    int32_t dev_count = reg->iface.get_device_count(reg);
+    apir_encode_int32_t(enc, &dev_count);
+
+    return 0;
+}
+
+uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    int32_t dev_count = reg->iface.get_device_count(reg);
+    apir_encode_int32_t(enc, &dev_count);
+
+    return 0;
+}
+
+uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    const char * string = dev->iface.get_name(dev);
+
+    const size_t string_size = strlen(string) + 1;
+    apir_encode_array_size(enc, string_size);
+    apir_encode_char_array(enc, string, string_size);
+
+    return 0;
+}
+
+uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    const char * string = dev->iface.get_description(dev);
+
+    const size_t string_size = strlen(string) + 1;
+    apir_encode_array_size(enc, string_size);
+    apir_encode_char_array(enc, string, string_size);
+
+    return 0;
+}
+
+uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    uint32_t type = dev->iface.get_type(dev);
+    apir_encode_uint32_t(enc, &type);
+
+    return 0;
+}
+
+uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    size_t free, total;
+    dev->iface.get_memory(dev, &free, &total);
+
+    apir_encode_size_t(enc, &free);
+    apir_encode_size_t(enc, &total);
+
+    return 0;
+}
+
+uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+
+    const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec);
+
+    bool supports_op = dev->iface.supports_op(dev, op);
+
+    apir_encode_bool_t(enc, &supports_op);
+
+    return 0;
+}
+
+uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    ggml_backend_buffer_type_t bufft = dev->iface.get_buffer_type(dev);
+
+    apir_encode_ggml_buffer_type(enc, bufft);
+
+    return 0;
+}
+
+uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    ggml_backend_dev_props props;
+    dev->iface.get_props(dev, &props);
+
+    apir_encode_bool_t(enc, &props.caps.async);
+    apir_encode_bool_t(enc, &props.caps.host_buffer);
+    apir_encode_bool_t(enc, &props.caps.buffer_from_host_ptr);
+    apir_encode_bool_t(enc, &props.caps.events);
+
+    return 0;
+}
+
+uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dec);
+
+    uint32_t shmem_res_id;
+    apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id);
+
+    void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id);
+    if (!shmem_ptr) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__);
+        apir_decoder_set_fatal(dec);
+        return 1;
+    }
+
+    size_t size;
+    apir_decode_size_t(dec, &size);
+    size_t max_tensor_size;
+    apir_decode_size_t(dec, &max_tensor_size);
+
+    ggml_backend_buffer_t buffer;
+    buffer = dev->iface.buffer_from_host_ptr(dev, shmem_ptr, size, max_tensor_size);
+
+    apir_encode_ggml_buffer(enc, buffer);
+    apir_encode_ggml_buffer_type(enc, buffer->buft);
+
+    if (buffer) {
+        apir_track_backend_buffer(buffer);
+    }
+
+    return 0;
+}
diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp
new file mode 100644
index 00000000..c80e4aab
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp
@@ -0,0 +1,51 @@
+#include "backend-dispatched.h"
+
+#include "backend-virgl-apir.h"
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+
+#include 
+
+ggml_backend_reg_t reg = NULL;
+ggml_backend_dev_t dev = NULL;
+ggml_backend_t     bck = NULL;
+
+uint64_t timer_start = 0;
+uint64_t timer_total = 0;
+uint64_t timer_count = 0;
+
+uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) {
+    if (reg != NULL) {
+        GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: already initialized\n", __func__);
+        return APIR_BACKEND_INITIALIZE_ALREADY_INITED;
+    }
+    ggml_backend_reg_t (*ggml_backend_reg_fct)(void) = (ggml_backend_reg_t (*)()) ggml_backend_reg_fct_p;
+
+    reg = ggml_backend_reg_fct();
+    if (reg == NULL) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend registration failed\n", __func__);
+        return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED;
+    }
+
+    size_t device_count = reg->iface.get_device_count(reg);
+    if (!device_count) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: no device found\n", __func__);
+        return APIR_BACKEND_INITIALIZE_NO_DEVICE;
+    }
+
+    dev = reg->iface.get_device(reg, 0);
+
+    if (!dev) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: failed to get device\n", __func__);
+        return APIR_BACKEND_INITIALIZE_NO_DEVICE;
+    }
+
+    bck = dev->iface.init_backend(dev, NULL);
+    if (!bck) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed\n", __func__);
+        return APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED;
+    }
+
+    return APIR_BACKEND_INITIALIZE_SUCCESS;
+}
diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h
new file mode 100644
index 00000000..3dc334e4
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h
@@ -0,0 +1,73 @@
+#pragma once
+
+/* device */
+uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+
+/* buffer-type */
+uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. */
+uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+
+/* buffer */
+uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+
+/* backend */
+uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+
+extern "C" {
+static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {
+
+    /* device */
+
+    /* APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT  = */ backend_device_get_device_count,
+    /* APIR_COMMAND_TYPE_DEVICE_GET_COUNT  = */ backend_device_get_count,
+    /* APIR_COMMAND_TYPE_DEVICE_GET_NAME  = */ backend_device_get_name,
+    /* APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION  = */ backend_device_get_description,
+    /* APIR_COMMAND_TYPE_DEVICE_GET_TYPE  = */ backend_device_get_type,
+    /* APIR_COMMAND_TYPE_DEVICE_GET_MEMORY  = */ backend_device_get_memory,
+    /* APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP  = */ backend_device_supports_op,
+    /* APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE  = */ backend_device_get_buffer_type,
+    /* APIR_COMMAND_TYPE_DEVICE_GET_PROPS  = */ backend_device_get_props,
+    /* APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR  = */ backend_device_buffer_from_ptr,
+
+    /* buffer-type */
+
+    /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME  = */ backend_buffer_type_get_name,
+    /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT  = */ backend_buffer_type_get_alignment,
+    /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE  = */ backend_buffer_type_get_max_size,
+    /* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST  = */ backend_buffer_type_is_host /* DEPRECATED */,
+    /* APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER  = */ backend_buffer_type_alloc_buffer,
+    /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE  = */ backend_buffer_type_get_alloc_size,
+
+    /* buffer */
+
+    /* APIR_COMMAND_TYPE_BUFFER_GET_BASE  = */ backend_buffer_get_base,
+    /* APIR_COMMAND_TYPE_BUFFER_SET_TENSOR  = */ backend_buffer_set_tensor,
+    /* APIR_COMMAND_TYPE_BUFFER_GET_TENSOR  = */ backend_buffer_get_tensor,
+    /* APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR  = */ backend_buffer_cpy_tensor,
+    /* APIR_COMMAND_TYPE_BUFFER_CLEAR  = */ backend_buffer_clear,
+    /* APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER  = */ backend_buffer_free_buffer,
+
+    /* backend */
+
+    /* APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE  = */ backend_backend_graph_compute,
+};
+}
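+
+// Dispatch sketch: the table is indexed directly by the command id, so entry order must match
+// the APIR_COMMAND_TYPE_* numbering (assumed to start at 0 and be dense), roughly:
+//
+//     backend_dispatch_t handler = apir_backend_dispatch_table[cmd_type];
+//     uint32_t           ret     = handler(&enc, &dec, &ctx);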
diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h
new file mode 100644
index 00000000..740ee9e3
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h
@@ -0,0 +1,27 @@
+#pragma once
+
+// clang-format off
+#include 
+#include 
+
+#include 
+
+#include "backend-convert.h"
+#include "backend-virgl-apir.h"
+#include "shared/apir_backend.h"
+#include "shared/apir_cs.h"
+#include "shared/apir_cs_ggml.h"
+// clang-format on
+
+#define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: "
+
+struct virgl_apir_context {
+    uint32_t               ctx_id;
+    virgl_apir_callbacks * iface;
+};
+
+typedef uint32_t (*backend_dispatch_t)(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx);
+
+#include "backend-dispatched.gen.h"
+
+uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p);
diff --git a/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h
new file mode 100644
index 00000000..c65a01cd
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "shared/api_remoting.h"
+
+#include 
+#include 
+#include 
+
+extern ggml_backend_reg_t reg;
+extern ggml_backend_dev_t dev;
+extern ggml_backend_t     bck;
+
+struct virgl_apir_callbacks {
+    const char * (*get_config)(uint32_t virgl_ctx_id, const char * key);
+    void * (*get_shmem_ptr)(uint32_t virgl_ctx_id, uint32_t res_id);
+};
+
+extern "C" {
+ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs);
+void                      apir_backend_deinit(uint32_t virgl_ctx_id);
+uint32_t                  apir_backend_dispatcher(uint32_t               virgl_ctx_id,
+                                                  virgl_apir_callbacks * virgl_cbs,
+                                                  uint32_t               cmd_type,
+                                                  char *                 dec_cur,
+                                                  const char *           dec_end,
+                                                  char *                 enc_cur,
+                                                  const char *           enc_end,
+                                                  char **                enc_cur_after);
+}
diff --git a/ggml/src/ggml-virtgpu/backend/backend.cpp b/ggml/src/ggml-virtgpu/backend/backend.cpp
new file mode 100644
index 00000000..535a05f3
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/backend.cpp
@@ -0,0 +1,144 @@
+#include "backend-dispatched.h"
+#include "backend-virgl-apir.h"
+#include "shared/api_remoting.h"
+#include "shared/apir_backend.h"
+#include "shared/apir_cs.h"
+
+#include <cstdio>
+#include <cstdlib>
+
+#include <dlfcn.h>
+
+#define APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_PATH"
+#define APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV  "APIR_LLAMA_CPP_GGML_LIBRARY_REG"
+#define APIR_LLAMA_CPP_LOG_TO_FILE_ENV       "APIR_LLAMA_CPP_LOG_TO_FILE"
+
+#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init"
+
+static void * backend_library_handle = NULL;
+static FILE * apir_logfile           = NULL;
+
+static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) {
+    FILE * logfile = (FILE *) user_data;
+    fprintf(logfile, "[%d] %s", level, text);
+    fflush(logfile);
+}
+
+extern "C" {
+void apir_backend_deinit(uint32_t virgl_ctx_id) {
+    GGML_UNUSED(virgl_ctx_id);
+
+    auto buffers = apir_get_track_backend_buffers();
+    for (const auto & buffer : buffers) {
+        apir_untrack_backend_buffer(buffer);
+        buffer->iface.free_buffer(buffer);
+    }
+
+    if (backend_library_handle) {
+        GGML_LOG_INFO(GGML_VIRTGPU_BCK "The GGML backend library was loaded. Unloading it.\n");
+        dlclose(backend_library_handle);
+        backend_library_handle = NULL;
+    }
+
+    if (apir_logfile) {
+        fclose(apir_logfile);
+        apir_logfile = NULL;
+    }
+}
+
+#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path"
+#define APIR_GGML_LIBRARY_REG_KEY  "ggml.library.reg"
+
+ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs) {
+    const char * dlsym_error;
+
+    const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV);
+    if (apir_log_to_file) {
+        apir_logfile = fopen(apir_log_to_file, "w");
+        if (apir_logfile) {
+            ggml_log_set(log_to_file_callback, apir_logfile);
+        } else {
+            GGML_LOG_INFO(GGML_VIRTGPU_BCK "Could not open the log file at '%s'\n", apir_log_to_file);
+        }
+    }
+
+    const char * library_name      = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY);
+    const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY);
+    const char * library_reg       = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG;
+
+    if (!library_name) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: env var '%s' not defined\n", __func__,
+                       APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV);
+
+        return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
+    }
+
+    backend_library_handle = dlopen(library_name, RTLD_LAZY);
+
+    if (!backend_library_handle) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: %s\n", __func__, dlerror());
+
+        return APIR_LOAD_LIBRARY_CANNOT_OPEN;
+    }
+
+    if (!library_reg) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot register the GGML library: env var '%s' not defined\n", __func__,
+                       APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV);
+
+        return APIR_LOAD_LIBRARY_ENV_VAR_MISSING;
+    }
+
+    void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg);
+    dlsym_error                 = dlerror();
+    if (dlsym_error) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n",
+                       __func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error);
+
+        return APIR_LOAD_LIBRARY_SYMBOL_MISSING;
+    }
+
+    uint32_t ret = backend_dispatch_initialize(ggml_backend_reg_fct);
+
+    return (ApirLoadLibraryReturnCode) (APIR_LOAD_LIBRARY_INIT_BASE_INDEX + ret);
+}
+
+uint32_t apir_backend_dispatcher(uint32_t               virgl_ctx_id,
+                                 virgl_apir_callbacks * virgl_cbs,
+                                 uint32_t               cmd_type,
+                                 char *                 dec_cur,
+                                 const char *           dec_end,
+                                 char *                 enc_cur,
+                                 const char *           enc_end,
+                                 char **                enc_cur_after) {
+    apir_encoder enc = {
+        .cur   = enc_cur,
+        .start = enc_cur,
+        .end   = enc_end,
+        .fatal = false,
+    };
+
+    apir_decoder dec = {
+        .cur   = dec_cur,
+        .end   = dec_end,
+        .fatal = false,
+    };
+
+    virgl_apir_context ctx = {
+        .ctx_id = virgl_ctx_id,
+        .iface  = virgl_cbs,
+    };
+
+    if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) {
+        GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Received an invalid dispatch index (%d >= %d)\n", __func__, cmd_type,
+                       APIR_BACKEND_DISPATCH_TABLE_COUNT);
+        return APIR_BACKEND_FORWARD_INDEX_INVALID;
+    }
+
+    backend_dispatch_t forward_fct = apir_backend_dispatch_table[cmd_type];
+    uint32_t           ret         = forward_fct(&enc, &dec, &ctx);
+
+    *enc_cur_after = enc.cur;
+
+    return ret;
+}
+}
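
For orientation, a rough sketch of how the host side is expected to drive the entry points above. Everything here (the `ctx_id`, the callback struct and the command/reply buffers) is assumed to come from the virglrenderer APIR component, which is outside this patch; `host_run_one_command` is a made-up name.

```cpp
// Hypothetical host-side wrapper around the dispatcher (illustration only).
#include "backend-virgl-apir.h"
#include "shared/apir_backend.h"

#include <cstddef>
#include <cstdint>

static uint32_t host_run_one_command(uint32_t ctx_id, virgl_apir_callbacks * cbs,
                                     uint32_t cmd_type,
                                     char * dec_buf, size_t dec_size,
                                     char * enc_buf, size_t enc_size) {
    char * enc_after = nullptr;

    uint32_t rc = apir_backend_dispatcher(ctx_id, cbs, cmd_type,
                                          dec_buf, dec_buf + dec_size,  // guest -> host arguments
                                          enc_buf, enc_buf + enc_size,  // host -> guest reply area
                                          &enc_after);

    // (enc_after - enc_buf) bytes of reply are handed back to the guest.
    (void) enc_after;
    return rc;
}
```

`apir_backend_initialize()` is expected to run once per context before any command is forwarded, and `apir_backend_deinit()` once the context is torn down.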
diff --git a/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h
new file mode 100644
index 00000000..6bf97e8a
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h
@@ -0,0 +1,95 @@
+#pragma once
+
+/* the rest of this file must match virglrenderer/src/apir-protocol.h */
+
+#include 
+
+#include 
+
+#define APIR_PROTOCOL_MAJOR 0
+#define APIR_PROTOCOL_MINOR 1
+
+#define APIR_HANDSHAKE_MAGIC 0xab1e
+
+enum ApirCommandType {
+    APIR_COMMAND_TYPE_HANDSHAKE   = 0,
+    APIR_COMMAND_TYPE_LOADLIBRARY = 1,
+    APIR_COMMAND_TYPE_FORWARD     = 2,
+
+    APIR_COMMAND_TYPE_LENGTH = 3,
+};
+
+typedef uint64_t ApirCommandFlags;
+
+enum ApirLoadLibraryReturnCode {
+    APIR_LOAD_LIBRARY_SUCCESS                        = 0,
+    // these error codes are returned by the Virglrenderer APIR component
+    APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1,
+    APIR_LOAD_LIBRARY_ALREADY_LOADED                 = 2,
+    APIR_LOAD_LIBRARY_ENV_VAR_MISSING                = 3,
+    APIR_LOAD_LIBRARY_CANNOT_OPEN                    = 4,
+    APIR_LOAD_LIBRARY_SYMBOL_MISSING                 = 5,
+    // any value greater than this is an APIR *backend library* initialization return code
+    APIR_LOAD_LIBRARY_INIT_BASE_INDEX                = 6,
+};
+
+enum ApirForwardReturnCode {
+    APIR_FORWARD_SUCCESS                = 0,
+    // these error codes are returned by the Virglrenderer APIR component
+    APIR_FORWARD_NO_DISPATCH_FCT        = 1,
+    APIR_FORWARD_TIMEOUT                = 2,
+    APIR_FORWARD_FAILED_TO_SYNC_STREAMS = 3,
+    // any value greater than this is an APIR *backend library* forward return code
+    APIR_FORWARD_BASE_INDEX             = 4,
+};
+
+__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) {
+    switch (type) {
+        case APIR_COMMAND_TYPE_HANDSHAKE:
+            return "HandShake";
+        case APIR_COMMAND_TYPE_LOADLIBRARY:
+            return "LoadLibrary";
+        case APIR_COMMAND_TYPE_FORWARD:
+            return "Forward";
+        default:
+            return "unknown";
+    }
+}
+
+__attribute__((unused)) static const char * apir_load_library_error(ApirLoadLibraryReturnCode code) {
+#define APIR_LOAD_LIBRARY_ERROR(code_name) \
+    do {                                   \
+        if (code == code_name)             \
+            return #code_name;             \
+    } while (0)
+
+    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SUCCESS);
+    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR);
+    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ALREADY_LOADED);
+    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ENV_VAR_MISSING);
+    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_CANNOT_OPEN);
+    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SYMBOL_MISSING);
+    APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_INIT_BASE_INDEX);
+
+    return "Unknown APIR_COMMAND_TYPE_LoadLibrary error";
+
+#undef APIR_LOAD_LIBRARY_ERROR
+}
+
+__attribute__((unused)) static const char * apir_forward_error(ApirForwardReturnCode code) {
+#define APIR_FORWARD_ERROR(code_name) \
+    do {                              \
+        if (code == code_name)        \
+            return #code_name;        \
+    } while (0)
+
+    APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS);
+    APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT);
+    APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT);
+    APIR_FORWARD_ERROR(APIR_FORWARD_FAILED_TO_SYNC_STREAMS);
+    APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX);
+
+    return "Unknown APIR_COMMAND_TYPE_FORWARD error";
+
+#undef APIR_FORWARD_ERROR
+}
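
Both return-code enums above follow the same convention: values below the `*_BASE_INDEX` entry come from the virglrenderer APIR component, and anything at or above it is an offset code produced by the backend library itself. A small sketch of splitting a LoadLibrary reply under that convention (the helper name is made up):

```cpp
#include "shared/api_remoting.h"

#include <stdio.h>

// Illustration only: report where a LoadLibrary reply originated.
static inline void apir_report_load_library(ApirLoadLibraryReturnCode code) {
    if (code < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) {
        // reported by the virglrenderer APIR component itself
        printf("virglrenderer: %s\n", apir_load_library_error(code));
    } else {
        // the remainder is the backend library's own initialization code
        printf("backend library init code: %d\n", (int) (code - APIR_LOAD_LIBRARY_INIT_BASE_INDEX));
    }
}
```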
diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h
new file mode 100644
index 00000000..520ac9c7
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h
@@ -0,0 +1,94 @@
+typedef enum ApirBackendCommandType {
+
+    /* device */
+    APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = 0,
+    APIR_COMMAND_TYPE_DEVICE_GET_COUNT        = 1,
+    APIR_COMMAND_TYPE_DEVICE_GET_NAME         = 2,
+    APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION  = 3,
+    APIR_COMMAND_TYPE_DEVICE_GET_TYPE         = 4,
+    APIR_COMMAND_TYPE_DEVICE_GET_MEMORY       = 5,
+    APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP      = 6,
+    APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE  = 7,
+    APIR_COMMAND_TYPE_DEVICE_GET_PROPS        = 8,
+    APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR  = 9,
+
+    /* buffer-type */
+    APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME       = 10,
+    APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT  = 11,
+    APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE   = 12,
+    APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST        = 13,
+    APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER   = 14,
+    APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = 15,
+
+    /* buffer */
+    APIR_COMMAND_TYPE_BUFFER_GET_BASE    = 16,
+    APIR_COMMAND_TYPE_BUFFER_SET_TENSOR  = 17,
+    APIR_COMMAND_TYPE_BUFFER_GET_TENSOR  = 18,
+    APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR  = 19,
+    APIR_COMMAND_TYPE_BUFFER_CLEAR       = 20,
+    APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = 21,
+
+    /* backend */
+    APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = 22,
+
+    // last command_type index + 1
+    APIR_BACKEND_DISPATCH_TABLE_COUNT = 23,
+} ApirBackendCommandType;
+
+static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {
+    switch (type) {
+        /* device */
+        case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT:
+            return "device_get_device_count";
+        case APIR_COMMAND_TYPE_DEVICE_GET_COUNT:
+            return "device_get_count";
+        case APIR_COMMAND_TYPE_DEVICE_GET_NAME:
+            return "device_get_name";
+        case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION:
+            return "device_get_description";
+        case APIR_COMMAND_TYPE_DEVICE_GET_TYPE:
+            return "device_get_type";
+        case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY:
+            return "device_get_memory";
+        case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP:
+            return "device_supports_op";
+        case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE:
+            return "device_get_buffer_type";
+        case APIR_COMMAND_TYPE_DEVICE_GET_PROPS:
+            return "device_get_props";
+        case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR:
+            return "device_buffer_from_ptr";
+        /* buffer-type */
+        case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME:
+            return "buffer_type_get_name";
+        case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT:
+            return "buffer_type_get_alignment";
+        case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE:
+            return "buffer_type_get_max_size";
+        case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST:
+            return "buffer_type_is_host";
+        case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER:
+            return "buffer_type_alloc_buffer";
+        case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE:
+            return "buffer_type_get_alloc_size";
+        /* buffer */
+        case APIR_COMMAND_TYPE_BUFFER_GET_BASE:
+            return "buffer_get_base";
+        case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR:
+            return "buffer_set_tensor";
+        case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR:
+            return "buffer_get_tensor";
+        case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR:
+            return "buffer_cpy_tensor";
+        case APIR_COMMAND_TYPE_BUFFER_CLEAR:
+            return "buffer_clear";
+        case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER:
+            return "buffer_free_buffer";
+        /* backend */
+        case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE:
+            return "backend_graph_compute";
+
+        default:
+            return "unknown";
+    }
+}
diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h
new file mode 100644
index 00000000..da1e21b5
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include "apir_backend.gen.h"
+
+#include <stdint.h>  // for uintptr_t
+#include <time.h>    // for timespec, clock_gettime
+
+#define APIR_BACKEND_INITIALIZE_SUCCESS                     0
+#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY 1
+#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY    2
+#define APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS     3
+#define APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS        4
+#define APIR_BACKEND_INITIALIZE_BACKEND_FAILED              5
+#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED          6
+#define APIR_BACKEND_INITIALIZE_ALREADY_INITED              7
+#define APIR_BACKEND_INITIALIZE_NO_DEVICE                   8
+#define APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED         9
+
+// new entries here need to be added to the apir_backend_initialize_error function below
+
+#define APIR_BACKEND_FORWARD_INDEX_INVALID 6
+
+// 0 is fast, 1 keeps the backend from crashing if an unsupported tensor is received
+#define APIR_BACKEND_CHECK_SUPPORTS_OP 0
+
+typedef uintptr_t apir_buffer_type_host_handle_t;
+typedef uintptr_t apir_buffer_host_handle_t;
+
+static const char * apir_backend_initialize_error(int code) {
+#define APIR_BACKEND_INITIALIZE_ERROR(code_name) \
+    do {                                         \
+        if (code == code_name)                   \
+            return #code_name;                   \
+    } while (0)
+
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_SUCCESS);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_ALREADY_INITED);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_NO_DEVICE);
+    APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED);
+
+    return "Unknown APIR_BACKEND_INITIALIZE error:/";
+
+#undef APIR_BACKEND_INITIALIZE_ERROR
+}
diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h
new file mode 100644
index 00000000..64bf2ec9
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h
@@ -0,0 +1,378 @@
+#pragma once
+
+#include "ggml-impl.h"
+
+#include 
+#include 
+
+#define likely(x)   __builtin_expect(!!(x), 1)
+#define unlikely(x) __builtin_expect(!!(x), 0)
+
+struct apir_encoder {
+    char *       cur;
+    const char * start;
+    const char * end;
+    bool         fatal;
+};
+
+struct apir_decoder {
+    const char * cur;
+    const char * end;
+    bool         fatal;
+};
+
+/*
+ * new encoder and decoder
+ */
+
+static apir_decoder apir_new_decoder(const char * ptr, size_t size) {
+    apir_decoder dec = {
+        .cur   = ptr,
+        .end   = ptr + size,
+        .fatal = false,
+    };
+
+    return dec;
+}
+
+static apir_encoder apir_new_encoder(char * ptr, size_t size) {
+    apir_encoder enc = {
+        .cur   = ptr,
+        .start = ptr,
+        .end   = ptr + size,
+        .fatal = false,
+    };
+
+    return enc;
+}
+
+/*
+ * fatal flag handling
+ */
+
+static inline void apir_encoder_reset_fatal(apir_encoder * enc) {
+    enc->fatal = false;
+}
+
+static inline void apir_encoder_set_fatal(apir_encoder * enc) {
+    enc->fatal = true;
+}
+
+static inline bool apir_encoder_get_fatal(const apir_encoder * enc) {
+    return enc->fatal;
+}
+
+static inline void apir_decoder_reset_fatal(apir_decoder * dec) {
+    dec->fatal = false;
+}
+
+static inline void apir_decoder_set_fatal(apir_decoder * dec) {
+    dec->fatal = true;
+}
+
+static inline bool apir_decoder_get_fatal(const apir_decoder * dec) {
+    return dec->fatal;
+}
+
+/*
+ * decode peek
+ */
+
+static inline bool apir_decoder_peek_internal(apir_decoder * dec, size_t size, void * val, size_t val_size) {
+    assert(val_size <= size);
+
+    if (unlikely(size > (size_t) (dec->end - dec->cur))) {
+        GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__);
+        apir_decoder_set_fatal(dec);
+        memset(val, 0, val_size);
+        return false;
+    }
+
+    /* we should not rely on the compiler to optimize away memcpy... */
+    memcpy(val, dec->cur, val_size);
+    return true;
+}
+
+static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) {
+    apir_decoder_peek_internal(dec, size, val, val_size);
+}
+
+static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) {
+    if (unlikely(size > (size_t) (dec->end - dec->cur))) {
+        GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__);
+        apir_decoder_set_fatal(dec);
+        return NULL;
+    }
+    const void * addr = dec->cur;
+    dec->cur += size;
+
+    return addr;
+}
+
+/*
+ * read/write
+ */
+
+static inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) {
+    if (apir_decoder_peek_internal(dec, size, val, val_size)) {
+        dec->cur += size;
+    }
+}
+
+static inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) {
+    assert(val_size <= size);
+    assert(size <= ((size_t) (enc->end - enc->cur)));
+
+    char * write_addr = enc->cur;
+    /* we should not rely on the compiler to optimize away memcpy... */
+    memcpy(write_addr, val, val_size);
+    enc->cur += size;
+
+    return write_addr;
+}
+
+/*
+ * encode/decode
+ */
+
+static inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) {
+    assert(size % 4 == 0);
+    apir_decoder_read(dec, size, data, data_size);
+}
+
+static inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) {
+    assert(size % 4 == 0);
+    apir_encoder_write(enc, size, data, data_size);
+}
+
+/*
+ * typed encode/decode
+ */
+
+/* uint8_t */
+
+static inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) {
+    apir_encode(enc, sizeof(int), val, sizeof(*val));
+}
+
+static inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) {
+    apir_decode(dec, sizeof(int), val, sizeof(*val));
+}
+
+/* uint64_t */
+
+static inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) {
+    apir_encode(enc, 8, val, sizeof(*val));
+}
+
+static inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) {
+    apir_decode(dec, 8, val, sizeof(*val));
+}
+
+static inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) {
+    const size_t size = sizeof(*val) * count;
+    assert(size >= count);
+    apir_encode(enc, size, val, size);
+}
+
+static inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) {
+    const size_t size = sizeof(*val) * count;
+    assert(size >= count);
+    apir_decode(dec, size, val, size);
+}
+
+static inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) {
+    return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t));
+}
+
+/* int32_t */
+
+static inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) {
+    apir_encode(enc, 4, val, sizeof(*val));
+}
+
+static inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) {
+    apir_decode(dec, 4, val, sizeof(*val));
+}
+
+static inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) {
+    const size_t size = sizeof(*val) * count;
+    assert(size >= count);
+    apir_encode(enc, size, val, size);
+}
+
+static inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) {
+    const size_t size = sizeof(*val) * count;
+    assert(size >= count);
+    apir_decode(dec, size, val, size);
+}
+
+/* array size (uint64_t) */
+
+static inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) {
+    apir_encode_uint64_t(enc, &size);
+}
+
+static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) {
+    uint64_t size;
+    apir_decode_uint64_t(dec, &size);
+    if (size != expected_size) {
+        GGML_LOG_ERROR("%s: Couldn't decode array from the decoder\n", __func__);
+        apir_decoder_set_fatal(dec);
+        size = 0;
+    }
+    return size;
+}
+
+static inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) {
+    uint64_t size;
+    apir_decode_uint64_t(dec, &size);
+    return size;
+}
+
+/* non-array pointer */
+
+static inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) {
+    apir_encode_array_size(enc, val ? 1 : 0);
+    return val;
+}
+
+static inline bool apir_decode_simple_pointer(apir_decoder * dec) {
+    return apir_decode_array_size_unchecked(dec);
+}
+
+/* uint32_t */
+
+static inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) {
+    apir_encode(enc, 4, val, sizeof(*val));
+}
+
+static inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) {
+    apir_decode(dec, 4, val, sizeof(*val));
+}
+
+static inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) {
+    const size_t size = sizeof(*val) * count;
+    assert(size >= count);
+    apir_encode(enc, size, val, size);
+}
+
+static inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) {
+    const size_t size = sizeof(*val) * count;
+    assert(size >= count);
+    apir_decode(dec, size, val, size);
+}
+
+/* size_t */
+
+static inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) {
+    const uint64_t tmp = *val;
+    apir_encode_uint64_t(enc, &tmp);
+}
+
+static inline void apir_decode_size_t(apir_decoder * dec, size_t * val) {
+    uint64_t tmp;
+    apir_decode_uint64_t(dec, &tmp);
+    *val = tmp;
+}
+
+static inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) {
+    if (sizeof(size_t) == sizeof(uint64_t)) {
+        apir_encode_uint64_t_array(enc, (const uint64_t *) val, count);
+    } else {
+        for (uint32_t i = 0; i < count; i++) {
+            apir_encode_size_t(enc, &val[i]);
+        }
+    }
+}
+
+static inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) {
+    if (sizeof(size_t) == sizeof(uint64_t)) {
+        apir_decode_uint64_t_array(dec, (uint64_t *) val, count);
+    } else {
+        for (uint32_t i = 0; i < count; i++) {
+            apir_decode_size_t(dec, &val[i]);
+        }
+    }
+}
+
+/* opaque blob */
+
+static inline void apir_encode_blob_array(apir_encoder * enc, const void * val, size_t size) {
+    apir_encode(enc, (size + 3) & ~3, val, size);
+}
+
+static inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) {
+    apir_decode(dec, (size + 3) & ~3, val, size);
+}
+
+/* string */
+
+static inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) {
+    assert(size && strlen(val) < size);
+    apir_encode_blob_array(enc, val, size);
+}
+
+static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) {
+    apir_decode_blob_array(dec, val, size);
+    if (size) {
+        val[size - 1] = '\0';
+    } else {
+        GGML_LOG_ERROR("%s: Couldn't decode the blog array\n", __func__);
+        apir_decoder_set_fatal(dec);
+    }
+}
+
+/* (temp) buffer allocation */
+
+static inline void * apir_decoder_alloc_array(size_t size, size_t count) {
+    size_t alloc_size;
+    if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) {
+        GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", __func__, size, count);
+        return NULL;
+    }
+
+    return malloc(alloc_size);
+}
+
+/* bool */
+
+static inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) {
+    apir_encode(enc, sizeof(int), val, sizeof(bool));
+}
+
+static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) {
+    apir_decode(dec, sizeof(int), val, sizeof(bool));
+}
+
+/* apir_buffer_type_host_handle_t */
+
+static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder *                         enc,
+                                                              const apir_buffer_type_host_handle_t * val) {
+    apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
+}
+
+static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder *                   dec,
+                                                              apir_buffer_type_host_handle_t * val) {
+    apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t));
+}
+
+/* apir_buffer_host_handle_t */
+
+static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, const apir_buffer_host_handle_t * val) {
+    apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
+}
+
+static inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) {
+    apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t));
+}
+
+/* uintptr_t */
+
+static inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) {
+    apir_encode(enc, sizeof(*val), val, sizeof(*val));
+}
+
+static inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) {
+    apir_decode(dec, sizeof(*val), val, sizeof(*val));
+}
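
As a quick illustration of how the encoder/decoder pair above is meant to be used (the values and the local buffer are made up; in the real transport the bytes live in guest/host shared memory):

```cpp
#include "shared/apir_cs.h"

#include <cstdio>

// Round-trip a couple of values through a local buffer (illustration only).
static void apir_cs_roundtrip_example(void) {
    char buf[64];

    apir_encoder enc = apir_new_encoder(buf, sizeof(buf));
    uint32_t dev_count = 2;
    size_t   max_size  = 1024;
    apir_encode_uint32_t(&enc, &dev_count);
    apir_encode_size_t(&enc, &max_size);

    apir_decoder dec = apir_new_decoder(buf, (size_t) (enc.cur - buf));
    uint32_t dev_count_out = 0;
    size_t   max_size_out  = 0;
    apir_decode_uint32_t(&dec, &dev_count_out);
    apir_decode_size_t(&dec, &max_size_out);

    if (!apir_decoder_get_fatal(&dec)) {
        printf("count=%u, max=%zu\n", dev_count_out, max_size_out);
    }
}
```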
diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h
new file mode 100644
index 00000000..fabe3e40
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h
@@ -0,0 +1,232 @@
+#include "apir_cs.h"
+#include "apir_cs_rpc.h"
+#include "ggml-impl.h"
+
+// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer);
+
+static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle);
+
+static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec);
+
+/* apir_rpc_tensor */
+
+static inline void apir_encode_rcp_tensor(apir_encoder * enc, const apir_rpc_tensor * apir_rpc_tensor) {
+    size_t apir_rpc_tensor_size = sizeof(*apir_rpc_tensor);
+    apir_encode(enc, apir_rpc_tensor_size, apir_rpc_tensor, apir_rpc_tensor_size);
+}
+
+static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder * dec) {
+    size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor);
+
+    return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
+}
+
+static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, uint32_t n_tensors) {
+    size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors;
+
+    return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size);
+}
+
+/* ggml_tensor */
+
+static inline void apir_encode_ggml_tensor(apir_encoder * enc, const ggml_tensor * tensor) {
+    apir_rpc_tensor serialized = apir_serialize_tensor(tensor);
+
+    apir_encode_rcp_tensor(enc, &serialized);
+}
+
+static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) {
+    const apir_rpc_tensor * apir_rpc_tensor = apir_decode_apir_rpc_tensor_inplace(dec);
+
+    if (!apir_rpc_tensor) {
+        return NULL;
+    }
+
+    ggml_init_params params{
+        /*.mem_size   =*/ggml_tensor_overhead(),
+        /*.mem_buffer =*/NULL,
+        /*.no_alloc   =*/true,
+    };
+
+    ggml_context * ctx = ggml_init(params);
+
+    const ggml_tensor * tensor = apir_deserialize_tensor(ctx, apir_rpc_tensor);
+
+    return tensor;
+}
+
+/* *** ggml_backend_buffer_type_t *** */
+
+// ggml_backend_buffer_type_t is a POINTER (to a struct).
+// Only the host pointer is shared between the host and guest.
+// The guest stores it in `buft->context`.
+// The host simply writes the pointer address in the buffer variable.
+
+static inline void apir_encode_ggml_buffer_type(apir_encoder * enc, ggml_backend_buffer_type_t buft) {
+    apir_buffer_type_host_handle_t handle = ggml_buffer_type_to_apir_handle(buft);
+    apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));
+}
+
+static inline ggml_backend_buffer_type_t apir_decode_ggml_buffer_type(apir_decoder * dec) {
+    apir_buffer_type_host_handle_t handle;
+
+    apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle));
+
+    return (ggml_backend_buffer_type_t) handle;
+}
+
+static inline void apir_encode_apir_buffer_type_host_handle(apir_encoder * enc, apir_buffer_type_host_handle_t handle) {
+    apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));
+}
+
+static inline apir_buffer_type_host_handle_t apir_decode_apir_buffer_type_host_handle(apir_decoder * dec) {
+    apir_buffer_type_host_handle_t handle;
+
+    apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle));
+
+    return handle;
+}
+
+/* *** ggml_backend_buffer_t *** */
+
+// ggml_backend_buffer_t is a POINTER.
+// same logic as for ggml_backend_buffer_type_t
+
+static inline void apir_encode_ggml_buffer(apir_encoder * enc, const ggml_backend_buffer_t buffer) {
+    apir_buffer_host_handle_t handle = BUFFER_TO_HOST_HANDLE(buffer);
+    apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle));
+}
+
+static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) {
+    ggml_backend_buffer_t buffer;
+    size_t                buffer_ptr_size = sizeof(buffer);
+
+    apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size);
+
+    // SECURITY: Validate buffer handle against tracked buffers to prevent
+    // guest VM from providing arbitrary host memory addresses
+    if (buffer) {
+        extern std::unordered_set<ggml_backend_buffer_t> backend_buffers;
+        if (backend_buffers.find(buffer) == backend_buffers.end()) {
+            GGML_LOG_WARN("ggml-virtgpu-backend: %s: Invalid buffer handle from guest: %p\n", __func__,
+                          (void *) buffer);
+            // Set fatal flag to prevent further processing with invalid handle
+            apir_decoder_set_fatal(dec);
+            return NULL;
+        }
+    }
+
+    return buffer;
+}
+
+/* enum ggml_status */
+
+static inline void apir_encode_ggml_status(apir_encoder * enc, const ggml_status * status) {
+    apir_encoder_write(enc, sizeof(*status), status, sizeof(*status));
+}
+
+static inline void apir_decode_ggml_status(apir_decoder * dec, ggml_status * status) {
+    apir_decoder_read(dec, sizeof(*status), status, sizeof(*status));
+}
+
+/* virtgpu_shmem */
+
+static inline void apir_encode_virtgpu_shmem_res_id(apir_encoder * enc, uint32_t shmem_res_id) {
+    apir_encode_uint32_t(enc, &shmem_res_id);
+}
+
+static inline void apir_decode_virtgpu_shmem_res_id(apir_decoder * dec, uint32_t * shmem_res_id) {
+    apir_decode_uint32_t(dec, shmem_res_id);
+}
+
+/* ggml_cgraph */
+
+static inline size_t apir_serialize_ggml_cgraph(ggml_cgraph * cgraph, std::vector<uint8_t> & cgraph_data) {
+    apir_serialize_graph(cgraph, cgraph_data);
+
+    return cgraph_data.size();
+}
+
+static inline void apir_encode_cgraph_data(apir_encoder * enc, std::vector<uint8_t> & cgraph_data) {
+    size_t cgraph_size = cgraph_data.size();
+
+    apir_encode(enc, cgraph_size, cgraph_data.data(), cgraph_size);
+}
+
+static inline ggml_cgraph * apir_decode_ggml_cgraph(apir_decoder * dec, size_t cgraph_size) {
+    GGML_UNUSED(cgraph_size);
+
+    uint32_t n_nodes;
+    apir_decode_uint32_t(dec, &n_nodes);
+    const uint64_t * nodes = apir_decode_uint64_t_array_inplace(dec, n_nodes);
+
+    uint32_t n_tensors;
+    apir_decode_uint32_t(dec, &n_tensors);
+    const apir_rpc_tensor * tensors = apir_decode_apir_rpc_tensor_array_inplace(dec, n_tensors);
+
+    return apir_deserialize_graph(n_nodes, n_tensors, tensors, nodes);
+}
+
+static inline void apir_encode_ggml_buffer_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle) {
+    apir_encoder_write(enc, sizeof(*handle), handle, sizeof(*handle));
+}
+
+static inline void apir_encode_ggml_tensor_inline(apir_encoder * enc, const ggml_tensor * tensor) {
+    size_t tensor_size = sizeof(*tensor);
+
+    if (tensor->extra) {
+        GGML_ABORT("%s: Cannot pass tensors with extra", __func__);
+    }
+
+    if (tensor->src[0] && tensor->buffer) {
+        static int first = 1;
+        if (first) {
+            GGML_LOG_WARN("%s: Cannot pass tensors with src and buffer\n", __func__);
+            first = 0;
+        }
+    }
+
+    apir_encoder_write(enc, tensor_size, tensor, tensor_size);
+
+    // tensor->data is a pointer inside the device buffer. No need to touch it
+    // tensor->buffer is a pointer to a buffer. Encoding the buffer handle in sequence.
+    // (could also make a copy of the tensor, and update locally.)
+
+    if (tensor->buffer) {
+        apir_buffer_host_handle_t buffer_handle = ggml_buffer_to_apir_handle(tensor->buffer);
+        apir_encode_ggml_buffer_handle(enc, &buffer_handle);
+    }
+
+    if (tensor->view_src) {
+        apir_encoder_write(enc, tensor_size, tensor->view_src, tensor_size);
+    }
+
+    for (int i = 0; tensor->src[i]; i++) {
+        const ggml_tensor * tensor_src = tensor->src[i];
+        apir_encoder_write(enc, tensor_size, tensor_src, tensor_size);
+    }
+}
+
+static inline const ggml_tensor * apir_decode_ggml_tensor_inplace(apir_decoder * dec) {
+    // it is safe to remove the `const` qualifier here: we *do* want to
+    // modify the shared memory data to fix the `src` pointers.
+    ggml_tensor * tensor = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
+
+    // tensor->data is a pointer inside the device buffer. No need to touch it
+    // tensor->buffer is a pointer to a buffer. Decode the buffer handle encoded in sequence.
+    if (tensor->buffer) {
+        tensor->buffer = apir_decode_ggml_buffer(dec);
+    }
+
+    if (tensor->view_src) {
+        ggml_tensor * tensor_view_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
+        tensor->view_src              = tensor_view_src;
+    }
+
+    for (int i = 0; tensor->src[i]; i++) {
+        ggml_tensor * tensor_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor));
+        tensor->src[i] = tensor_src;  // overwrite op->src[i] pointer with the actual location of the src tensor
+    }
+
+    return tensor;
+}
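
The comment block before `apir_encode_ggml_buffer_type` above describes the handle convention: buffer types and buffers cross the guest/host boundary as opaque host pointers. A two-line sketch of what that looks like on each side (the `enc`/`dec` streams and the `buft` value are assumed to already exist; `ggml_buffer_type_to_apir_handle` comes from the frontend part of this patch):

```cpp
// Guest side: forward the opaque host handle it was given earlier.
apir_buffer_type_host_handle_t handle = ggml_buffer_type_to_apir_handle(buft);
apir_encode_apir_buffer_type_host_handle(&enc, handle);

// Host side: the same bytes decode straight back into the pointer it handed out.
ggml_backend_buffer_type_t host_buft = apir_decode_ggml_buffer_type(&dec);
```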
diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h
new file mode 100644
index 00000000..4cb2f047
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h
@@ -0,0 +1,58 @@
+#pragma once
+
+// clang-format off
+#include "ggml.h"
+#include "ggml-backend-impl.h"
+
+#include <cstdint>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+// clang-format on
+
+// ggml_tensor is serialized into apir_rpc_tensor
+struct apir_rpc_tensor {
+    uint64_t id;
+    uint32_t type;
+    uint64_t buffer;
+    uint32_t ne[GGML_MAX_DIMS];
+    uint32_t nb[GGML_MAX_DIMS];
+    uint32_t op;
+    int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+    int32_t  flags;
+    uint64_t src[GGML_MAX_SRC];
+    uint64_t view_src;
+    uint64_t view_offs;
+    uint64_t data;
+    char     name[GGML_MAX_NAME];
+
+    char padding[4];
+};
+
+/* frontend */
+
+apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor);
+
+void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output);
+
+/* backend */
+
+void                                      apir_track_backend_buffer(ggml_backend_buffer_t buffer);
+bool                                      apir_untrack_backend_buffer(ggml_backend_buffer_t buffer);
+std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers();
+
+void apir_add_tensor(ggml_tensor *                        tensor,
+                     std::vector<apir_rpc_tensor> &       tensors,
+                     std::unordered_set<ggml_tensor *> &  visited);
+
+ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor);
+
+ggml_tensor * apir_create_node(uint64_t                                                      id,
+                               ggml_context *                                                ctx,
+                               const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs,
+                               std::unordered_map<uint64_t, ggml_tensor *> &                 tensor_map);
+
+ggml_cgraph * apir_deserialize_graph(uint32_t                n_nodes,
+                                     uint32_t                n_tensors,
+                                     const apir_rpc_tensor * tensors,
+                                     const uint64_t *        nodes);
diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp
new file mode 100644
index 00000000..8fa20ff4
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp
@@ -0,0 +1,81 @@
+#include "ggml-remoting.h"
+
+static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+                                                                            size_t                     size) {
+    virtgpu * gpu = BUFT_TO_GPU(buft);
+
+    ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context));
+    if (!context) {
+        GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__);
+    }
+
+    context->gpu = gpu;
+
+    bool async__unused, host_buffer__unused, events__unused;
+    bool buffer_from_host_ptr;
+    apir_device_get_props(gpu, &async__unused, &host_buffer__unused, &buffer_from_host_ptr, &events__unused);
+
+    if (buffer_from_host_ptr) {
+        context->apir_context = apir_device_buffer_from_ptr(gpu, size, size);
+        context->base         = context->apir_context.shmem.mmap_ptr;
+        context->is_from_ptr  = true;
+    } else {
+        context->apir_context = apir_buffer_type_alloc_buffer(gpu, gpu->cached_buffer_type.host_handle, size);
+        context->is_from_ptr  = false;
+        context->base         = NULL;
+    }
+
+    ggml_backend_buffer_t buffer =
+        ggml_backend_buffer_init(buft, ggml_backend_remoting_buffer_interface, (void *) context, size);
+
+    return buffer;
+}
+
+static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    virtgpu * gpu = BUFT_TO_GPU(buft);
+
+    // Return the prefixed name that was built once during initialization
+    return gpu->cached_buffer_type.name;
+}
+
+static size_t ggml_backend_remoting_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    virtgpu * gpu = BUFT_TO_GPU(buft);
+
+    return gpu->cached_buffer_type.alignment;
+}
+
+static size_t ggml_backend_remoting_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    virtgpu * gpu = BUFT_TO_GPU(buft);
+
+    return gpu->cached_buffer_type.max_size;
+}
+
+static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
+                                                               const ggml_tensor *        tensor) {
+    virtgpu * gpu = BUFT_TO_GPU(buft);
+
+    if (tensor->buffer == NULL || !tensor->buffer->context ||
+        !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) {
+        return ggml_nbytes(tensor);
+    }
+
+    return apir_buffer_type_get_alloc_size(gpu, gpu->cached_buffer_type.host_handle, tensor);
+}
+
+const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_remoting_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_remoting_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_remoting_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_remoting_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ ggml_backend_remoting_buffer_type_get_alloc_size,
+    /* .is_host          = */ NULL,
+};
+
+const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface = {
+    /* .get_name         = */ ggml_backend_remoting_buffer_type_get_name,
+    /* .alloc_buffer     = */ NULL,
+    /* .get_alignment    = */ ggml_backend_remoting_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_remoting_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ ggml_backend_remoting_buffer_type_get_alloc_size,
+    /* .is_host          = */ NULL,
+};
diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp
new file mode 100644
index 00000000..6b95362d
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp
@@ -0,0 +1,119 @@
+#include "ggml-remoting.h"
+
+#define BUFFER_TO_GPU(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->gpu
+
+static void * ggml_backend_remoting_buffer_get_base(ggml_backend_buffer_t buffer) {
+    ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) buffer->context;
+    if (context->base) {
+        return context->base;
+    }
+
+    context->base = apir_buffer_get_base(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer));
+
+    return context->base;
+}
+
+static void ggml_backend_remoting_buffer_set_tensor(ggml_backend_buffer_t buffer,
+                                                    ggml_tensor *         tensor,
+                                                    const void *          data,
+                                                    size_t                offset,
+                                                    size_t                size) {
+    virtgpu * gpu = BUFFER_TO_GPU(buffer);
+
+    ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer);
+    if (context->is_from_ptr) {
+        memcpy((char *) tensor->data + offset, data, size);
+    } else {
+        apir_buffer_set_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size);
+    }
+
+    return;
+}
+
+static void ggml_backend_remoting_buffer_get_tensor(ggml_backend_buffer_t buffer,
+                                                    const ggml_tensor *   tensor,
+                                                    void *                data,
+                                                    size_t                offset,
+                                                    size_t                size) {
+    virtgpu *                              gpu     = BUFFER_TO_GPU(buffer);
+    ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer);
+    if (context->is_from_ptr) {
+        memcpy(data, (const char *) tensor->data + offset, size);
+    } else {
+        apir_buffer_get_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size);
+    }
+}
+
+static void ggml_backend_remoting_buffer_set_tensor_from_ptr(ggml_backend_buffer_t buffer,
+                                                             ggml_tensor *         tensor,
+                                                             const void *          data,
+                                                             size_t                offset,
+                                                             size_t                size) {
+    UNUSED(buffer);
+
+    memcpy((char *) tensor->data + offset, data, size);
+
+    return;
+}
+
+static void ggml_backend_remoting_buffer_get_tensor_from_ptr(ggml_backend_buffer_t buffer,
+                                                             const ggml_tensor *   tensor,
+                                                             void *                data,
+                                                             size_t                offset,
+                                                             size_t                size) {
+    UNUSED(buffer);
+
+    memcpy(data, (const char *) tensor->data + offset, size);
+}
+
+static bool ggml_backend_remoting_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
+                                                    const ggml_tensor *   src,
+                                                    ggml_tensor *         dst) {
+    virtgpu * gpu = BUFFER_TO_GPU(buffer);
+
+    bool ret = apir_buffer_cpy_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), src, dst);
+
+    return ret;
+}
+
+static void ggml_backend_remoting_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    virtgpu * gpu = BUFFER_TO_GPU(buffer);
+
+    apir_buffer_clear(gpu, BUFFER_TO_APIR_CONTEXT(buffer), value);
+
+    return;
+}
+
+static void ggml_backend_remoting_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    virtgpu * gpu = BUFFER_TO_GPU(buffer);
+
+    apir_buffer_free_buffer(gpu, BUFFER_TO_APIR_CONTEXT(buffer));
+
+    ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer);
+    free(context);
+    buffer->context = NULL;
+}
+
+const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_remoting_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_remoting_buffer_get_base,
+    /* .init_tensor     = */ NULL,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ ggml_backend_remoting_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_remoting_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_remoting_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_remoting_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = {
+    /* .free_buffer     = */ ggml_backend_remoting_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_remoting_buffer_get_base,
+    /* .init_tensor     = */ NULL,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ ggml_backend_remoting_buffer_set_tensor_from_ptr,
+    /* .get_tensor      = */ ggml_backend_remoting_buffer_get_tensor_from_ptr,
+    /* .cpy_tensor      = */ ggml_backend_remoting_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_remoting_buffer_clear,
+    /* .reset           = */ NULL,
+};
diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp
new file mode 100644
index 00000000..ec8156bb
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp
@@ -0,0 +1,158 @@
+#include "ggml-remoting.h"
+
+static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) {
+    virtgpu * gpu = DEV_TO_GPU(dev);
+
+    // Return the prefixed name that was built once during initialization
+    return gpu->cached_device_info.name;
+}
+
+static const char * ggml_backend_remoting_device_get_description(ggml_backend_dev_t dev) {
+    virtgpu * gpu = DEV_TO_GPU(dev);
+
+    // Return the pre-cached description from the virtgpu structure
+    return gpu->cached_device_info.description;
+}
+
+static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_backend_dev_t dev) {
+    virtgpu * gpu = DEV_TO_GPU(dev);
+
+    return (enum ggml_backend_dev_type) gpu->cached_device_info.type;
+}
+
+static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    virtgpu * gpu = DEV_TO_GPU(dev);
+
+    *free  = gpu->cached_device_info.memory_free;
+    *total = gpu->cached_device_info.memory_total;
+}
+
+static bool ggml_backend_remoting_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+#if USE_ALWAYS_TRUE_SUPPORTS_OP == 1
+    /* ggml-rpc cheats it like this */
+    /* with the current implementation of serialize_tensor, the src/view aren't properly passed */
+    UNUSED(dev);
+    UNUSED(op);
+
+    return true;
+#else
+    virtgpu * gpu = DEV_TO_GPU(dev);
+
+    return apir_device_supports_op(gpu, op);
+#endif
+}
+
+static bool ggml_backend_remoting_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    bool supported = buft->device == dev;
+
+    return supported;
+}
+
+static bool ggml_backend_remoting_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    UNUSED(dev);
+    UNUSED(op);
+
+    return false;
+}
+
+static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_remoting_device_get_name(dev);
+    props->description = ggml_backend_remoting_device_get_description(dev);
+    props->type        = ggml_backend_remoting_device_get_type(dev);
+    ggml_backend_remoting_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+    virtgpu * gpu = DEV_TO_GPU(dev);
+    apir_device_get_props(gpu, &props->caps.async, &props->caps.host_buffer, &props->caps.buffer_from_host_ptr,
+                          &props->caps.events);
+
+    props->caps.buffer_from_host_ptr = false;
+    props->caps.async                = false;
+    props->caps.events               = false;
+}
+
+ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) {
+    virtgpu * gpu = DEV_TO_GPU(dev);
+
+    static std::atomic<bool>        initialized = false;
+    static ggml_backend_buffer_type buft;
+
+    if (!initialized) {
+        static std::mutex           mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+
+        if (!initialized) {
+            buft = {
+                /* .iface    = */ ggml_backend_remoting_buffer_type_interface,
+                /* .device   = */ dev,
+                /* .context  = */ (void *) gpu->cached_buffer_type.host_handle,
+            };
+            initialized = true;
+        }
+    }
+
+    return &buft;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) {
+    virtgpu * gpu = DEV_TO_GPU(dev);
+
+    static std::atomic<bool>        initialized = false;
+    static ggml_backend_buffer_type buft;
+
+    if (!initialized) {
+        static std::mutex           mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+
+        if (!initialized) {
+            buft = {
+                /* .iface    = */ ggml_backend_remoting_buffer_from_ptr_type_interface,
+                /* .device   = */ dev,
+                /* .context  = */ (void *) gpu->cached_buffer_type.host_handle,
+            };
+            initialized = true;
+        }
+    }
+
+    return &buft;
+}
+
+static ggml_backend_buffer_t ggml_backend_remoting_device_buffer_from_ptr(ggml_backend_dev_t dev,
+                                                                          void *             ptr,
+                                                                          size_t             size,
+                                                                          size_t             max_tensor_size) {
+    virtgpu * gpu = DEV_TO_GPU(dev);
+
+    ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context));
+    if (!context) {
+        GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__);
+    }
+
+    context->gpu          = gpu;
+    context->apir_context = apir_device_buffer_from_ptr(gpu, size, max_tensor_size);
+    context->base         = ptr;
+    context->is_from_ptr  = true;
+
+    ggml_backend_buffer_t buffer =
+        ggml_backend_buffer_init(ggml_backend_remoting_device_get_buffer_from_ptr_type(dev),
+                                 ggml_backend_remoting_buffer_from_ptr_interface, (void *) context, size);
+
+    return buffer;
+}
+
+const ggml_backend_device_i ggml_backend_remoting_device_interface = {
+    /* .get_name             = */ ggml_backend_remoting_device_get_name,
+    /* .get_description      = */ ggml_backend_remoting_device_get_description,
+    /* .get_memory           = */ ggml_backend_remoting_device_get_memory,
+    /* .get_type             = */ ggml_backend_remoting_device_get_type,
+    /* .get_props            = */ ggml_backend_remoting_device_get_props,
+    /* .init_backend         = */ ggml_backend_remoting_device_init,
+    /* .get_buffer_type      = */ ggml_backend_remoting_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_remoting_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_remoting_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_remoting_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_remoting_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
diff --git a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp
new file mode 100644
index 00000000..a4df5956
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp
@@ -0,0 +1,213 @@
+#include "ggml-remoting.h"
+#include "ggml-virtgpu.h"
+
+#include 
+#include 
+
+void ggml_virtgpu_cleanup(virtgpu * gpu);
+
+static virtgpu * apir_initialize() {
+    static virtgpu *         gpu         = NULL;
+    static std::atomic<bool> initialized = false;
+
+    if (initialized) {
+        // fast track
+        return gpu;
+    }
+
+    {
+        static std::mutex           mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+
+        if (initialized) {
+            // thread safe
+            return gpu;
+        }
+
+        gpu = create_virtgpu();
+        if (!gpu) {
+            initialized = true;
+            return NULL;
+        }
+
+        // Pre-fetch and cache all device information; it will not change
+        gpu->cached_device_info.description = apir_device_get_description(gpu);
+        if (!gpu->cached_device_info.description) {
+            GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__);
+        }
+        gpu->cached_device_info.device_count = apir_device_get_count(gpu);
+        gpu->cached_device_info.type         = apir_device_get_type(gpu);
+
+        {
+            // Get the remote name and create prefixed version
+            char * rmt_device_name = apir_device_get_name(gpu);
+            if (!rmt_device_name) {
+                GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu device name", __func__);
+            }
+
+            size_t device_name_len       = strlen(rmt_device_name) + 11;  // "[virtgpu] " + null terminator
+            gpu->cached_device_info.name = (char *) malloc(device_name_len);
+            if (!gpu->cached_device_info.name) {
+                free(rmt_device_name);
+                GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed device name", __func__);
+            }
+            snprintf(gpu->cached_device_info.name, device_name_len, "[virtgpu] %s", rmt_device_name);
+            free(rmt_device_name);
+        }
+
+        apir_device_get_memory(gpu, &gpu->cached_device_info.memory_free, &gpu->cached_device_info.memory_total);
+
+        apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu);
+        gpu->cached_buffer_type.host_handle             = buft_host_handle;
+        {
+            // Get the remote name and create prefixed version
+            char * rmt_name = apir_buffer_type_get_name(gpu, buft_host_handle);
+            if (!rmt_name) {
+                GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu buffer type name", __func__);
+            }
+
+            size_t prefixed_len          = strlen(rmt_name) + 11;  // "[virtgpu] " + null terminator
+            gpu->cached_buffer_type.name = (char *) malloc(prefixed_len);
+            if (!gpu->cached_buffer_type.name) {
+                free(rmt_name);
+                GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed buffer type name", __func__);
+            }
+            snprintf(gpu->cached_buffer_type.name, prefixed_len, "[virtgpu] %s", rmt_name);
+            free(rmt_name);
+        }
+
+        gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle);
+        gpu->cached_buffer_type.max_size  = apir_buffer_type_get_max_size(gpu, buft_host_handle);
+
+        initialized = true;
+    }
+
+    return gpu;
+}
+
+static int ggml_backend_remoting_get_device_count() {
+    virtgpu * gpu = apir_initialize();
+    if (!gpu) {
+        return 0;
+    }
+
+    return gpu->cached_device_info.device_count;
+}
+
+static size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) {
+    UNUSED(reg);
+
+    return ggml_backend_remoting_get_device_count();
+}
+
+static std::vector<ggml_backend_dev_t> devices;
+
+ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) {
+    GGML_ASSERT(device < devices.size());
+    return devices[device];
+}
+
+static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) {
+    if (devices.size() > 0) {
+        GGML_LOG_INFO(GGML_VIRTGPU "%s: already initialized\n", __func__);
+        return;
+    }
+
+    virtgpu * gpu = apir_initialize();
+    if (!gpu) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: apir_initialize failed\n", __func__);
+        return;
+    }
+
+    static std::atomic<bool> initialized = false;
+
+    if (initialized) {
+        return;  // fast track
+    }
+
+    {
+        static std::mutex           mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) {
+                ggml_backend_remoting_device_context * ctx       = new ggml_backend_remoting_device_context;
+                char                                   desc[256] = "ggml-virtgpu API Remoting device";
+
+                ctx->device      = i;
+                ctx->name        = GGML_VIRTGPU_NAME + std::to_string(i);
+                ctx->description = desc;
+                ctx->gpu         = gpu;
+
+                ggml_backend_dev_t dev = new ggml_backend_device{
+                    /* .iface   = */ ggml_backend_remoting_device_interface,
+                    /* .reg     = */ reg,
+                    /* .context = */ ctx,
+                };
+                devices.push_back(dev);
+            }
+            initialized = true;
+        }
+    }
+}
+
+static ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+    UNUSED(reg);
+
+    return ggml_backend_remoting_get_device(device);
+}
+
+static const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) {
+    UNUSED(reg);
+
+    return GGML_VIRTGPU_NAME;
+}
+
+static const ggml_backend_reg_i ggml_backend_remoting_reg_i = {
+    /* .get_name         = */ ggml_backend_remoting_reg_get_name,
+    /* .get_device_count = */ ggml_backend_remoting_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_remoting_reg_get_device,
+    /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_virtgpu_reg() {
+    virtgpu * gpu = apir_initialize();
+    if (!gpu) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: apir_initialize failed\n", __func__);
+    }
+
+    static ggml_backend_reg reg = {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_remoting_reg_i,
+        /* .context     = */ gpu,
+    };
+
+    static bool initialized = false;
+    if (initialized) {
+        return &reg;
+    }
+    initialized = true;
+
+    ggml_backend_remoting_reg_init_devices(&reg);
+
+    return &reg;
+}
+
+// public function, not exposed in the GGML interface at the moment
+void ggml_virtgpu_cleanup(virtgpu * gpu) {
+    if (gpu->cached_device_info.name) {
+        free(gpu->cached_device_info.name);
+        gpu->cached_device_info.name = NULL;
+    }
+    if (gpu->cached_device_info.description) {
+        free(gpu->cached_device_info.description);
+        gpu->cached_device_info.description = NULL;
+    }
+    if (gpu->cached_buffer_type.name) {
+        free(gpu->cached_buffer_type.name);
+        gpu->cached_buffer_type.name = NULL;
+    }
+
+    mtx_destroy(&gpu->data_shmem_mutex);
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg)
diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp
new file mode 100644
index 00000000..a63ee2b9
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp
@@ -0,0 +1,69 @@
+#include "../../include/ggml-virtgpu.h"
+#include "ggml-remoting.h"
+
+static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) {
+    UNUSED(backend);
+
+    return "API Remoting backend";
+}
+
+static void ggml_backend_remoting_free(ggml_backend_t backend) {
+    delete backend;
+}
+
+static ggml_status ggml_backend_remoting_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    virtgpu * gpu = DEV_TO_GPU(backend->device);
+
+    return apir_backend_graph_compute(gpu, cgraph);
+}
+
+static void ggml_backend_remoting_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    virtgpu * gpu = DEV_TO_GPU(backend->device);
+#if true
+    UNUSED(gpu);
+    UNUSED(cgraph);
+#else
+    // not working yet
+
+    apir_backend_graph_optimize(gpu, cgraph);
+#endif
+}
+
+static ggml_backend_i ggml_backend_remoting_interface = {
+    /* .get_name                = */ ggml_backend_remoting_get_name,
+    /* .free                    = */ ggml_backend_remoting_free,
+    /* .set_tensor_async        = */ NULL,  // ggml_backend_remoting_set_tensor_async,
+    /* .get_tensor_async        = */ NULL,  // ggml_backend_remoting_get_tensor_async,
+    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_remoting_cpy_tensor_async,
+    /* .synchronize             = */ NULL,  // ggml_backend_remoting_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_remoting_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+    /* .graph_optimize          = */ ggml_backend_remoting_graph_optimize,
+};
+
+static ggml_guid_t ggml_backend_remoting_guid() {
+    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x14, 0x03, 0x86, 0x02,
+                              0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
+
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params) {
+    UNUSED(params);
+
+    ggml_backend_remoting_device_context * ctx = (ggml_backend_remoting_device_context *) dev->context;
+
+    ggml_backend_t remoting_backend = new ggml_backend{
+        /* .guid      = */ ggml_backend_remoting_guid(),
+        /* .interface = */ ggml_backend_remoting_interface,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_virtgpu_reg(), ctx->device),
+        /* .context   = */ ctx,
+    };
+
+    return remoting_backend;
+}
diff --git a/ggml/src/ggml-virtgpu/ggml-remoting.h b/ggml/src/ggml-virtgpu/ggml-remoting.h
new file mode 100644
index 00000000..4f70326b
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/ggml-remoting.h
@@ -0,0 +1,71 @@
+#pragma once
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "virtgpu.h"
+
+#include <string>
+#include <vector>
+
+#define GGML_VIRTGPU_NAME "ggml-virtgpu"
+#define GGML_VIRTGPU      "ggml-virtgpu: "
+
+// USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoids micro-benchmark crashes
+
+#define USE_ALWAYS_TRUE_SUPPORTS_OP 1
+#define USE_METAL_GUEST_SUPPORTS_OP 0
+
+#define DEV_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->context)->gpu
+
+#define BUFFER_TO_GGML_CONTEXT(name) ((ggml_backend_remoting_buffer_context *) (name)->context)
+
+#define BUFFER_TO_APIR_CONTEXT(name) &((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context
+
+#define BUFFER_TO_HOST_HANDLE(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context.host_handle
+
+#define GET_DEVICE_CONTEXT() ((ggml_backend_remoting_device_context *) ggml_backend_remoting_get_device(0)->context)
+
+#define BUFT_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->device->context)->gpu
+
+struct ggml_backend_remoting_device_context {
+    size_t      device;
+    std::string name;
+    std::string description;
+
+    std::vector<std::pair<void *, size_t>> shared_memory;  // element type is an assumption; the original was lost in the patch text
+
+    virtgpu * gpu;
+};
+
+struct ggml_backend_remoting_buffer_context {
+    apir_buffer_context_t apir_context;
+
+    virtgpu * gpu;
+
+    void * base;
+
+    bool is_from_ptr;
+};
+
+extern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface;
+extern const ggml_backend_device_i      ggml_backend_remoting_device_interface;
+extern const ggml_backend_buffer_i      ggml_backend_remoting_buffer_interface;
+extern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface;
+extern const ggml_backend_buffer_i      ggml_backend_remoting_buffer_from_ptr_interface;
+
+ggml_backend_dev_t         ggml_backend_remoting_get_device(size_t device);
+ggml_backend_t             ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params);
+ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev);
+
+static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) {
+    // in the backend, the buffer handle is the buffer pointer
+    return (apir_buffer_type_host_handle_t) buft->context;
+}
+
+static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) {
+    if (!buffer->context) {
+        GGML_ABORT(GGML_VIRTGPU "%s: no context available :/", __func__);
+    }
+    return BUFFER_TO_HOST_HANDLE(buffer);
+}
diff --git a/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml
new file mode 100644
index 00000000..14ef2433
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml
@@ -0,0 +1,166 @@
+# YAML schema for GGML remoting API functions
+# This defines the structure for generating the remoting layer code
+
+# Configuration for the generated files
+config:
+  # Base path for the generated files
+  base_path: "ggml/src"
+
+  # Header files to update
+  files:
+    apir_backend_header: "ggml-virtgpu-apir/backend/shared/apir_backend.gen.h"
+    backend_dispatched_header: "ggml-virtgpu-apir/backend/backend-dispatched.gen.h"
+    virtgpu_forward_header: "ggml-virtgpu-apir/virtgpu-forward.gen.h"
+
+# Simplified function definitions with grouping and metadata combined
+functions:
+  device:
+    group_description: "device"
+    functions:
+      get_device_count:
+        # No specific metadata - uses default void return and base params
+
+      get_count:
+        frontend_return: "int"
+
+      get_name:
+        frontend_return: "char *"
+
+      get_description:
+        frontend_return: "char *"
+
+      get_type:
+        frontend_return: "uint32_t"
+
+      get_memory:
+        frontend_return: "void"
+        frontend_extra_params:
+        - "size_t *free"
+        - "size_t *total"
+
+      supports_op:
+        frontend_return: "bool"
+        frontend_extra_params:
+        - "const ggml_tensor *op"
+
+      get_buffer_type:
+        frontend_return: "apir_buffer_type_host_handle_t"
+
+      get_props:
+        frontend_return: "void"
+        frontend_extra_params:
+        - "bool *async"
+        - "bool *host_buffer"
+        - "bool *buffer_from_host_ptr"
+        - "bool *events"
+
+      buffer_from_ptr:
+        frontend_return: "apir_buffer_context_t"
+        frontend_extra_params:
+        - "size_t size"
+        - "size_t max_tensor_size"
+
+  buffer_type:
+    group_description: "buffer-type"
+    functions:
+      get_name:
+        frontend_return: "char *"
+        frontend_extra_params:
+        - "apir_buffer_type_host_handle_t host_handle"
+
+      get_alignment:
+        frontend_return: "size_t"
+        frontend_extra_params:
+        - "apir_buffer_type_host_handle_t host_handle"
+
+      get_max_size:
+        frontend_return: "size_t"
+        frontend_extra_params:
+        - "apir_buffer_type_host_handle_t host_handle"
+
+      is_host:
+        deprecated: true
+
+      alloc_buffer:
+        frontend_return: "apir_buffer_context_t"
+        frontend_extra_params:
+        - "apir_buffer_type_host_handle_t host_handle"
+        - "size_t size"
+
+      get_alloc_size:
+        frontend_return: "size_t"
+        frontend_extra_params:
+        - "apir_buffer_type_host_handle_t host_handle"
+        - "const ggml_tensor *op"
+
+  buffer:
+    group_description: "buffer"
+    functions:
+      get_base:
+        frontend_return: "void *"
+        frontend_extra_params:
+        - "apir_buffer_context_t *buffer_context"
+
+      set_tensor:
+        frontend_return: "void"
+        frontend_extra_params:
+        - "apir_buffer_context_t *buffer_context"
+        - "ggml_tensor *tensor"
+        - "const void *data"
+        - "size_t offset"
+        - "size_t size"
+
+      get_tensor:
+        frontend_return: "void"
+        frontend_extra_params:
+        - "apir_buffer_context_t *buffer_context"
+        - "const ggml_tensor *tensor"
+        - "void *data"
+        - "size_t offset"
+        - "size_t size"
+
+      cpy_tensor:
+        frontend_return: "bool"
+        frontend_extra_params:
+        - "apir_buffer_context_t *buffer_context"
+        - "const ggml_tensor *src"
+        - "const ggml_tensor *dst"
+
+      clear:
+        frontend_return: "void"
+        frontend_extra_params:
+        - "apir_buffer_context_t *buffer_context"
+        - "uint8_t value"
+
+      free_buffer:
+        frontend_return: "void"
+        frontend_extra_params:
+        - "apir_buffer_context_t *buffer_context"
+
+  backend:
+    group_description: "backend"
+    functions:
+      graph_compute:
+        frontend_return: "ggml_status"
+        frontend_extra_params:
+        - "ggml_cgraph *cgraph"
+
+      graph_optimize:
+        frontend_return: "ggml_cgraph *"
+        frontend_extra_params:
+        - "ggml_cgraph *cgraph"
+        enabled: false
+
+# Naming patterns used for code generation
+naming_patterns:
+  # How to generate enum names
+  enum_prefix: "APIR_COMMAND_TYPE_"
+
+  # How to generate backend function names
+  backend_function_prefix: "backend_"
+
+  # How to generate frontend function names
+  frontend_function_prefix: "apir_"
+
+  # Standard frontend first parameter
+  frontend_base_param: "struct virtgpu *gpu"
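+
+# Example of how these patterns expand (illustrative only): the entry
+# `functions.device.get_count` with `frontend_return: "int"` produces
+#   - enum value:        APIR_COMMAND_TYPE_DEVICE_GET_COUNT
+#   - backend handler:   backend_device_get_count(...)
+#   - frontend function: int apir_device_get_count(struct virtgpu *gpu);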
diff --git a/ggml/src/ggml-virtgpu/include/apir_hw.h b/ggml/src/ggml-virtgpu/include/apir_hw.h
new file mode 100644
index 00000000..7d6ea226
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/include/apir_hw.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <stdint.h>
+
+struct virgl_renderer_capset_apir {
+    uint32_t apir_version;
+    uint32_t supports_blob_resources;
+    uint32_t reserved[4];  // For future expansion
+};
diff --git a/ggml/src/ggml-virtgpu/regenerate_remoting.py b/ggml/src/ggml-virtgpu/regenerate_remoting.py
new file mode 100755
index 00000000..dae75fd1
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/regenerate_remoting.py
@@ -0,0 +1,333 @@
+#!/usr/bin/env python3
+"""
+# Generated by Claude AI
+
+Script to completely regenerate the GGML remoting codebase from YAML configuration.
+
+This script reads ggmlremoting_functions.yaml and regenerates all the header files and
+implementation templates for the GGML remoting layer.
+
+Usage:
+  python regenerate_remoting.py
+
+The script will:
+1. Read ggmlremoting_functions.yaml configuration
+2. Generate updated header files
+3. Generate implementation templates in dedicated files
+4. Show a summary of what was generated
+"""
+
+import yaml
+from typing import Dict, List, Any
+from pathlib import Path
+import os
+import subprocess
+import shutil
+import logging
+
+NL = '\n' # can't have f"{'\n'}" in f-strings
+
+
+class RemotingCodebaseGenerator:
+    def __init__(self, yaml_path: str = "ggmlremoting_functions.yaml"):
+        """Initialize the generator with the YAML configuration."""
+        self.yaml_path = yaml_path
+
+        if not Path(yaml_path).exists():
+            raise FileNotFoundError(f"Configuration file {yaml_path} not found")
+
+        with open(yaml_path, 'r') as f:
+            self.config = yaml.safe_load(f)
+
+        self.functions = self.config['functions']
+        self.naming_patterns = self.config['naming_patterns']
+        self.config_data = self.config['config']
+
+        # Check if clang-format is available
+        self.clang_format_available = self._check_clang_format_available()
+
+    def _check_clang_format_available(self) -> bool:
+        """Check if clang-format is available in the system PATH."""
+        return shutil.which("clang-format") is not None
+
+    def _format_file_with_clang_format(self, file_path: Path) -> bool:
+        """Format a file with clang-format -i. Returns True if successful, False otherwise."""
+        if not self.clang_format_available:
+            return False
+
+        try:
+            subprocess.run(
+                ["clang-format", "-i", str(file_path)],
+                check=True,
+                capture_output=True,
+                text=True
+            )
+            return True
+        except subprocess.CalledProcessError:
+            logging.exception(f"   ⚠️  clang-format failed for {file_path}")
+            return False
+        except Exception as e:
+            logging.exception(f"   ⚠️  Unexpected error formatting {file_path}: {e}")
+            return False
+
+    def generate_enum_name(self, group_name: str, function_name: str) -> str:
+        """Generate the APIR_COMMAND_TYPE enum name for a function."""
+        prefix = self.naming_patterns['enum_prefix']
+        return f"{prefix}{group_name.upper()}_{function_name.upper()}"
+
+    def generate_backend_function_name(self, group_name: str, function_name: str) -> str:
+        """Generate the backend function name."""
+        function_key = f"{group_name}_{function_name}"
+        overrides = self.naming_patterns.get('backend_function_overrides', {})
+
+        if function_key in overrides:
+            return overrides[function_key]
+
+        prefix = self.naming_patterns['backend_function_prefix']
+        return f"{prefix}{group_name}_{function_name}"
+
+    def generate_frontend_function_name(self, group_name: str, function_name: str) -> str:
+        """Generate the frontend function name."""
+        prefix = self.naming_patterns['frontend_function_prefix']
+        return f"{prefix}{group_name}_{function_name}"
+
+    def get_enabled_functions(self) -> List[Dict[str, Any]]:
+        """Get all enabled functions with their metadata."""
+        functions = []
+        enum_value = 0
+
+        for group_name, group_data in self.functions.items():
+            group_description = group_data['group_description']
+
+            for function_name, func_metadata in group_data['functions'].items():
+                # Handle case where func_metadata is None or empty (functions with only comments)
+                if func_metadata is None:
+                    func_metadata = {}
+
+                # Functions are enabled by default unless explicitly disabled
+                if func_metadata.get('enabled', True):
+                    functions.append({
+                        'group_name': group_name,
+                        'function_name': function_name,
+                        'enum_name': self.generate_enum_name(group_name, function_name),
+                        'enum_value': enum_value,
+                        'backend_function': self.generate_backend_function_name(group_name, function_name),
+                        'frontend_function': self.generate_frontend_function_name(group_name, function_name),
+                        'frontend_return': func_metadata.get('frontend_return', 'void'),
+                        'frontend_extra_params': func_metadata.get('frontend_extra_params', []),
+                        'group_description': group_description,
+                        'deprecated': func_metadata.get('deprecated', False),
+                    })
+                    enum_value += 1
+
+        return functions
+
+    def generate_apir_backend_header(self) -> str:
+        """Generate the complete apir_backend.h file."""
+        functions = self.get_enabled_functions()
+
+        # Generate the enum section
+        enum_lines = ["typedef enum ApirBackendCommandType {"]
+        current_group = None
+
+        for func in functions:
+            # Add comment for new group
+            if func['group_name'] != current_group:
+                enum_lines.append("")
+                enum_lines.append(f"  /* {func['group_description']} */")
+                current_group = func['group_name']
+
+            enum_lines.append(f"  {func['enum_name']} = {func['enum_value']},")
+
+        # Add the count
+        total_count = len(functions)
+        enum_lines.append("\n  // last command_type index + 1")
+        enum_lines.append(f"  APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},")
+        enum_lines.append("} ApirBackendCommandType;")
+
+        # Generate function name mapping
+        func_lines = []
+        func_lines.append("static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {")
+        func_lines.append("    switch (type) {")
+
+        current_group = None
+        for func in functions:
+            # Add comment for new group
+            if func['group_name'] != current_group:
+                func_lines.append(f"        /* {func['group_description']} */")
+                current_group = func['group_name']
+
+            # Generate clean function name without backend_ prefix
+            clean_name = f"{func['group_name']}_{func['function_name']}"
+            func_lines.append(f"        case {func['enum_name']}:")
+            func_lines.append(f"            return \"{clean_name}\";")
+
+        func_lines.append("")
+        func_lines.append("        default:")
+        func_lines.append("            return \"unknown\";")
+        func_lines.append("    }")
+        func_lines.append("}")
+
+        # Full header template
+        header_content = NL.join(enum_lines) + "\n\n" + NL.join(func_lines) + "\n"
+
+        return header_content
+
+    def generate_backend_dispatched_header(self) -> str:
+        """Generate the complete backend-dispatched.h file."""
+        functions = self.get_enabled_functions()
+
+        # Function declarations
+        decl_lines = []
+        current_group = None
+
+        for func in functions:
+            if func['group_name'] != current_group:
+                decl_lines.append(f"\n/* {func['group_description']} */")
+                current_group = func['group_name']
+
+            signature = "uint32_t"
+            params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx"
+            if func['deprecated']:
+                decl_lines.append(f"/* {func['enum_name']} is deprecated. Keeping the handler for backward compatibility. */")
+
+            decl_lines.append(f"{signature} {func['backend_function']}({params});")
+
+        # Dispatch table
+        table_lines = []
+        current_group = None
+
+        for func in functions:
+            if func['group_name'] != current_group:
+                table_lines.append(f"\n  /* {func['group_description']} */")
+                table_lines.append("")
+                current_group = func['group_name']
+
+            deprecated = " /* DEPRECATED */" if func['deprecated'] else ""
+            table_lines.append(f"  /* {func['enum_name']}  = */ {func['backend_function']}{deprecated},")
+
+        header_content = f'''\
+#pragma once
+
+{NL.join(decl_lines)}
+
+extern "C" {{
+static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{
+  {NL.join(table_lines)}
+}};
+}}
+'''
+        return header_content
+
+    def generate_virtgpu_forward_header(self) -> str:
+        """Generate the complete virtgpu-forward.gen.h file."""
+        functions = self.get_enabled_functions()
+
+        decl_lines = []
+        current_group = None
+
+        for func in functions:
+            if func['group_name'] != current_group:
+                decl_lines.append("")
+                decl_lines.append(f"/* {func['group_description']} */")
+                current_group = func['group_name']
+
+            if func['deprecated']:
+                decl_lines.append(f"/* {func['frontend_function']} is deprecated. */")
+                continue
+
+            # Build parameter list
+            params = [self.naming_patterns['frontend_base_param']]
+            params.extend(func['frontend_extra_params'])
+            param_str = ', '.join(params)
+
+            decl_lines.append(f"{func['frontend_return']} {func['frontend_function']}({param_str});")
+
+        header_content = f'''\
+#pragma once
+{NL.join(decl_lines)}
+'''
+        return header_content
+
+    def regenerate_codebase(self) -> None:
+        """Regenerate the entire remoting codebase."""
+        logging.info("🔄 Regenerating GGML Remoting Codebase...")
+        logging.info("=" * 50)
+
+        # Detect if we're running from frontend directory
+        current_dir = os.getcwd()
+        is_frontend_dir = current_dir.endswith('ggml-virtgpu')
+
+        if is_frontend_dir:
+            # Running from ggml/src/ggml-virtgpu
+            logging.info("📍 Detected frontend directory execution")
+            frontend_base = Path(".")
+        else:
+            # Running from project root (fallback to original behavior)
+            logging.info("📍 Detected project root execution")
+            base_path = self.config_data.get('base_path', 'ggml/src')
+            frontend_base = Path(base_path) / "ggml-virtgpu"
+
+        # Compute final file paths
+        backend_base = frontend_base / "backend"
+        apir_backend_path = backend_base / "shared" / "apir_backend.gen.h"
+        backend_dispatched_path = backend_base / "backend-dispatched.gen.h"
+        virtgpu_forward_path = frontend_base / "virtgpu-forward.gen.h"
+
+        # Create output directories for each file
+        apir_backend_path.parent.mkdir(parents=True, exist_ok=True)
+        backend_dispatched_path.parent.mkdir(parents=True, exist_ok=True)
+        virtgpu_forward_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Generate header files
+        logging.info("📁 Generating header files...")
+
+        apir_backend_content = self.generate_apir_backend_header()
+        apir_backend_path.write_text(apir_backend_content)
+        logging.info(f"   ✅ {apir_backend_path.resolve()}")
+
+        backend_dispatched_content = self.generate_backend_dispatched_header()
+        backend_dispatched_path.write_text(backend_dispatched_content)
+        logging.info(f"   ✅ {backend_dispatched_path.resolve()}")
+
+        virtgpu_forward_content = self.generate_virtgpu_forward_header()
+        virtgpu_forward_path.write_text(virtgpu_forward_content)
+        logging.info(f"   ✅ {virtgpu_forward_path.resolve()}")
+
+        # Format generated files with clang-format
+        generated_files = [apir_backend_path, backend_dispatched_path, virtgpu_forward_path]
+
+        if not self.clang_format_available:
+            logging.warning("\n⚠️  clang-format not found in PATH. Generated files will not be formatted.\n"
+                            "   Install clang-format to enable automatic code formatting.")
+        else:
+            logging.info("\n🎨 Formatting files with clang-format...")
+            for file_path in generated_files:
+                if self._format_file_with_clang_format(file_path):
+                    logging.info(f"   ✅ Formatted {file_path.name}")
+                else:
+                    logging.warning(f"   ❌ Failed to format {file_path.name}")
+
+        # Generate summary
+        functions = self.get_enabled_functions()
+        total_functions = len(functions)
+
+        logging.info("\n📊 Generation Summary:")
+        logging.info("=" * 50)
+        logging.info(f"   Total functions: {total_functions}")
+        logging.info(f"   Function groups: {len(self.functions)}")
+        logging.info("   Header files: 3")
+        logging.info(f"   Working directory: {current_dir}")
+
+
+def main():
+    logging.basicConfig(level=logging.INFO, format="%(message)s")
+
+    try:
+        generator = RemotingCodebaseGenerator()
+        generator.regenerate_codebase()
+    except Exception as e:
+        logging.exception(f"❌ Error: {e}")
+        exit(1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/ggml/src/ggml-virtgpu/virtgpu-apir.h b/ggml/src/ggml-virtgpu/virtgpu-apir.h
new file mode 100644
index 00000000..238f960a
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-apir.h
@@ -0,0 +1,15 @@
+#include "backend/shared/apir_backend.h"
+#include "ggml-alloc.h"
+#include "ggml-impl.h"
+#include "ggml.h"
+#include "virtgpu-shm.h"
+#include "virtgpu-utils.h"
+
+struct apir_buffer_context_t {
+    apir_buffer_host_handle_t host_handle;
+
+    struct virtgpu_shmem           shmem;
+    apir_buffer_type_host_handle_t buft_host_handle;
+};
+
+#include "virtgpu-forward.gen.h"
diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp
new file mode 100644
index 00000000..4593690c
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp
@@ -0,0 +1,58 @@
+#include "virtgpu-forward-impl.h"
+
+static long long current_time_ns() {
+    timespec ts;
+    clock_gettime(CLOCK_REALTIME, &ts);  // Use CLOCK_MONOTONIC for elapsed time
+    return (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec;
+}
+
+ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE);
+
+    std::vector<uint8_t> cgraph_data;  // byte buffer for the serialized cgraph (element type is an assumption)
+    size_t               cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data);
+
+    virtgpu_shmem   temp_shmem;  // Local storage for large buffers
+    virtgpu_shmem * shmem              = &temp_shmem;
+    bool            using_shared_shmem = false;
+
+    if (cgraph_size <= gpu->data_shmem.mmap_size) {
+        // Lock mutex before using shared data_shmem buffer
+        if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {
+            GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
+        }
+        using_shared_shmem = true;
+        shmem              = &gpu->data_shmem;
+    } else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) {
+        GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
+    }
+
+    apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id);
+
+    apir_encode_size_t(encoder, &cgraph_size);
+
+    char *       shmem_data    = (char *) shmem->mmap_ptr;
+    apir_encoder secondary_enc = apir_new_encoder(shmem_data, cgraph_size);
+
+    apir_encode_cgraph_data(&secondary_enc, cgraph_data);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    ggml_status status = GGML_STATUS_ABORTED;
+    apir_decode_ggml_status(decoder, &status);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    // Unlock mutex before cleanup
+    if (using_shared_shmem) {
+        mtx_unlock(&gpu->data_shmem_mutex);
+    } else {
+        virtgpu_shmem_destroy(gpu, shmem);
+    }
+
+    return status;
+}
diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp
new file mode 100644
index 00000000..38f8ec94
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp
@@ -0,0 +1,110 @@
+#include "virtgpu-forward-impl.h"
+
+char * apir_buffer_type_get_name(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME);
+
+    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    const size_t string_size = apir_decode_array_size_unchecked(decoder);
+    char *       string      = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
+    if (!string) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__);
+        apir_decoder_set_fatal(decoder);
+    }
+    apir_decode_char_array(decoder, string, string_size);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return string;
+}
+
+size_t apir_buffer_type_get_alignment(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT);
+
+    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    size_t alignment;
+    apir_decode_size_t(decoder, &alignment);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return alignment;
+}
+
+size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE);
+
+    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    size_t max_size;
+    apir_decode_size_t(decoder, &max_size);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return max_size;
+}
+
+apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu *                      gpu,
+                                                    apir_buffer_type_host_handle_t host_handle,
+                                                    size_t                         size) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    apir_buffer_context_t buffer_context;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER);
+
+    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);
+
+    apir_encode_size_t(encoder, &size);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return buffer_context;
+}
+
+size_t apir_buffer_type_get_alloc_size(virtgpu *                      gpu,
+                                       apir_buffer_type_host_handle_t host_handle,
+                                       const ggml_tensor *            op) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE);
+
+    apir_encode_apir_buffer_type_host_handle(encoder, host_handle);
+
+    apir_encode_ggml_tensor_inline(encoder, op);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    size_t alloc_size;
+    apir_decode_size_t(decoder, &alloc_size);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return alloc_size;
+}
diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp
new file mode 100644
index 00000000..228284f4
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp
@@ -0,0 +1,173 @@
+#include "virtgpu-forward-impl.h"
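+
+// The tensor set/get paths below stage their payload through guest-host
+// shared memory: transfers that fit reuse the pre-allocated gpu->data_shmem
+// region (serialized with data_shmem_mutex), while larger transfers allocate a
+// temporary shmem that is destroyed once the remote call returns.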
+
+void * apir_buffer_get_base(virtgpu * gpu, apir_buffer_context_t * buffer_context) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_BASE);
+
+    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    uintptr_t base;
+    apir_decode_uintptr_t(decoder, &base);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return (void *) base;
+}
+
+void apir_buffer_set_tensor(virtgpu *               gpu,
+                            apir_buffer_context_t * buffer_context,
+                            ggml_tensor *           tensor,
+                            const void *            data,
+                            size_t                  offset,
+                            size_t                  size) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_SET_TENSOR);
+
+    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);
+    apir_encode_ggml_tensor(encoder, tensor);
+
+    virtgpu_shmem   temp_shmem;  // Local storage for large buffers
+    virtgpu_shmem * shmem              = &temp_shmem;
+    bool            using_shared_shmem = false;
+
+    if (size <= gpu->data_shmem.mmap_size) {
+        // Lock mutex before using shared data_shmem buffer
+        if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {
+            GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
+        }
+        using_shared_shmem = true;
+        shmem              = &gpu->data_shmem;
+
+    } else if (virtgpu_shmem_create(gpu, size, shmem)) {
+        GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
+    }
+
+    memcpy(shmem->mmap_ptr, data, size);
+    apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id);
+
+    apir_encode_size_t(encoder, &offset);
+    apir_encode_size_t(encoder, &size);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    // Unlock mutex before cleanup
+    if (using_shared_shmem) {
+        mtx_unlock(&gpu->data_shmem_mutex);
+    } else {
+        virtgpu_shmem_destroy(gpu, shmem);
+    }
+
+    return;
+}
+
+void apir_buffer_get_tensor(virtgpu *               gpu,
+                            apir_buffer_context_t * buffer_context,
+                            const ggml_tensor *     tensor,
+                            void *                  data,
+                            size_t                  offset,
+                            size_t                  size) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_TENSOR);
+
+    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);
+    apir_encode_ggml_tensor(encoder, tensor);
+
+    virtgpu_shmem   temp_shmem;  // Local storage for large buffers
+    virtgpu_shmem * shmem              = &temp_shmem;
+    bool            using_shared_shmem = false;
+
+    if (size <= gpu->data_shmem.mmap_size) {
+        // Lock mutex before using shared data_shmem buffer
+        if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) {
+            GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__);
+        }
+        using_shared_shmem = true;
+        shmem              = &gpu->data_shmem;
+
+    } else if (virtgpu_shmem_create(gpu, size, shmem)) {
+        GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__);
+    }
+
+    apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id);
+    apir_encode_size_t(encoder, &offset);
+    apir_encode_size_t(encoder, &size);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    memcpy(data, shmem->mmap_ptr, size);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    // Unlock mutex before cleanup
+    if (using_shared_shmem) {
+        mtx_unlock(&gpu->data_shmem_mutex);
+    } else {
+        virtgpu_shmem_destroy(gpu, shmem);
+    }
+}
+
+bool apir_buffer_cpy_tensor(virtgpu *               gpu,
+                            apir_buffer_context_t * buffer_context,
+                            const ggml_tensor *     src,
+                            const ggml_tensor *     dst) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR);
+
+    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);
+    apir_encode_ggml_tensor(encoder, src);
+    apir_encode_ggml_tensor(encoder, dst);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    bool ret_val;
+    apir_decode_bool_t(decoder, &ret_val);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return ret_val;
+}
+
+void apir_buffer_clear(virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CLEAR);
+
+    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);
+    apir_encode_uint8_t(encoder, &value);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    remote_call_finish(gpu, encoder, decoder);
+}
+
+void apir_buffer_free_buffer(virtgpu * gpu, apir_buffer_context_t * buffer_context) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER);
+
+    apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    remote_call_finish(gpu, encoder, decoder);
+}
diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp
new file mode 100644
index 00000000..9f513c13
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp
@@ -0,0 +1,192 @@
+#include "virtgpu-forward-impl.h"
+#include "virtgpu-shm.h"
+
+int apir_device_get_count(virtgpu * gpu) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_COUNT);
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    int32_t dev_count = -1;
+    apir_decode_int32_t(decoder, &dev_count);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return dev_count;
+}
+
+char * apir_device_get_name(virtgpu * gpu) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_NAME);
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    const size_t string_size = apir_decode_array_size_unchecked(decoder);
+    char *       string      = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
+    if (!string) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__);
+        return NULL;
+    }
+    apir_decode_char_array(decoder, string, string_size);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return string;
+}
+
+char * apir_device_get_description(virtgpu * gpu) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    const size_t string_size = apir_decode_array_size_unchecked(decoder);
+    char *       string      = (char *) apir_decoder_alloc_array(sizeof(char), string_size);
+    if (!string) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device description buffer\n", __func__);
+
+        return NULL;
+    }
+    apir_decode_char_array(decoder, string, string_size);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return string;
+}
+
+uint32_t apir_device_get_type(virtgpu * gpu) {
+    static uint32_t dev_type = 255;
+    if (dev_type != 255) {
+        return dev_type;
+    }
+
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_TYPE);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    apir_decode_uint32_t(decoder, &dev_type);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return dev_type;
+}
+
+void apir_device_get_memory(virtgpu * gpu, size_t * free, size_t * total) {
+    static size_t         dev_free  = 0;
+    static size_t         dev_total = 0;
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_MEMORY);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    apir_decode_size_t(decoder, &dev_free);
+    apir_decode_size_t(decoder, &dev_total);
+
+    *free  = dev_free;
+    *total = dev_total;
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return;
+}
+
+bool apir_device_supports_op(virtgpu * gpu, const ggml_tensor * op) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP);
+
+    apir_encode_ggml_tensor_inline(encoder, op);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    bool supports_op;
+    apir_decode_bool_t(decoder, &supports_op);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return supports_op;
+}
+
+apir_buffer_type_host_handle_t apir_device_get_buffer_type(virtgpu * gpu) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    apir_buffer_type_host_handle_t buft_handle;
+    apir_decode_apir_buffer_type_host_handle_t(decoder, &buft_handle);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return buft_handle;
+}
+
+void apir_device_get_props(virtgpu * gpu,
+                           bool *    async,
+                           bool *    host_buffer,
+                           bool *    buffer_from_host_ptr,
+                           bool *    events) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_PROPS);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    apir_decode_bool_t(decoder, async);
+    apir_decode_bool_t(decoder, host_buffer);
+    apir_decode_bool_t(decoder, buffer_from_host_ptr);
+    apir_decode_bool_t(decoder, events);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return;
+}
+
+apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, size_t max_tensor_size) {
+    apir_encoder *        encoder;
+    apir_decoder *        decoder;
+    ApirForwardReturnCode ret;
+
+    apir_buffer_context_t buffer_context;
+
+    REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR);
+
+    if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) {
+        GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate %zu bytes of guest-host shared buffer", __func__, size);
+    }
+
+    apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id);
+
+    apir_encode_size_t(encoder, &size);
+    apir_encode_size_t(encoder, &max_tensor_size);
+
+    REMOTE_CALL(gpu, encoder, decoder, ret);
+
+    apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle);
+    buffer_context.buft_host_handle = apir_decode_apir_buffer_type_host_handle(decoder);
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    return buffer_context;
+}
diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h
new file mode 100644
index 00000000..4d0b6e05
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h
@@ -0,0 +1,36 @@
+#pragma once
+
+// clang-format off
+#include "virtgpu.h"
+#include "ggml-remoting.h"
+#include "backend/shared/apir_backend.h"
+#include "backend/shared/apir_cs_ggml.h"
+#include "ggml-backend-impl.h"
+// clang-format on
+
+#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__)                                           \
+    int32_t      REMOTE_CALL_PREPARE_forward_flag = (int32_t) apir_command_type__;                                     \
+    const char * REMOTE_CALL_PREPARE_command_name = apir_dispatch_command_name(apir_command_type__);                   \
+    do {                                                                                                               \
+        encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, REMOTE_CALL_PREPARE_forward_flag); \
+        if (!encoder_name) {                                                                                           \
+            GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__);                        \
+        }                                                                                                              \
+    } while (0)
+
+#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name)                                     \
+    do {                                                                                                    \
+        ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \
+        if (!decoder_name) {                                                                                \
+            GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__);                        \
+        }                                                                                                   \
+        if (ret_name < APIR_FORWARD_BASE_INDEX) {                                                           \
+            GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__,            \
+                       apir_forward_error(ret_name), ret_name);                                             \
+        }                                                                                                   \
+        ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX);                            \
+        if (ret_name != 0) {                                                                                \
+            GGML_ABORT(GGML_VIRTGPU "backend function '%s' failed (return code: %d)",                       \
+                       REMOTE_CALL_PREPARE_command_name, ret_name);                                         \
+        }                                                                                                   \
+    } while (0)
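+
+// Typical call sequence built on these macros (sketch, mirroring the
+// virtgpu-forward-*.cpp implementations):
+//
+//   apir_encoder *        encoder;
+//   apir_decoder *        decoder;
+//   ApirForwardReturnCode ret;
+//
+//   REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_COUNT);
+//   /* encode the call arguments into `encoder` */
+//   REMOTE_CALL(gpu, encoder, decoder, ret);
+//   /* decode the results from `decoder` */
+//   remote_call_finish(gpu, encoder, decoder);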
diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h
new file mode 100644
index 00000000..44b0ad1f
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h
@@ -0,0 +1,53 @@
+#pragma once
+
+/* device */
+void                           apir_device_get_device_count(struct virtgpu * gpu);
+int                            apir_device_get_count(struct virtgpu * gpu);
+char *                         apir_device_get_name(struct virtgpu * gpu);
+char *                         apir_device_get_description(struct virtgpu * gpu);
+uint32_t                       apir_device_get_type(struct virtgpu * gpu);
+void                           apir_device_get_memory(struct virtgpu * gpu, size_t * free, size_t * total);
+bool                           apir_device_supports_op(struct virtgpu * gpu, const ggml_tensor * op);
+apir_buffer_type_host_handle_t apir_device_get_buffer_type(struct virtgpu * gpu);
+void                           apir_device_get_props(struct virtgpu * gpu,
+                                                     bool *           async,
+                                                     bool *           host_buffer,
+                                                     bool *           buffer_from_host_ptr,
+                                                     bool *           events);
+apir_buffer_context_t          apir_device_buffer_from_ptr(struct virtgpu * gpu, size_t size, size_t max_tensor_size);
+
+/* buffer-type */
+char *                apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
+size_t                apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
+size_t                apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle);
+/* apir_buffer_type_is_host is deprecated. */
+apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu *               gpu,
+                                                    apir_buffer_type_host_handle_t host_handle,
+                                                    size_t                         size);
+size_t                apir_buffer_type_get_alloc_size(struct virtgpu *               gpu,
+                                                      apir_buffer_type_host_handle_t host_handle,
+                                                      const ggml_tensor *            op);
+
+/* buffer */
+void * apir_buffer_get_base(struct virtgpu * gpu, apir_buffer_context_t * buffer_context);
+void   apir_buffer_set_tensor(struct virtgpu *        gpu,
+                              apir_buffer_context_t * buffer_context,
+                              ggml_tensor *           tensor,
+                              const void *            data,
+                              size_t                  offset,
+                              size_t                  size);
+void   apir_buffer_get_tensor(struct virtgpu *        gpu,
+                              apir_buffer_context_t * buffer_context,
+                              const ggml_tensor *     tensor,
+                              void *                  data,
+                              size_t                  offset,
+                              size_t                  size);
+bool   apir_buffer_cpy_tensor(struct virtgpu *        gpu,
+                              apir_buffer_context_t * buffer_context,
+                              const ggml_tensor *     src,
+                              const ggml_tensor *     dst);
+void   apir_buffer_clear(struct virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value);
+void   apir_buffer_free_buffer(struct virtgpu * gpu, apir_buffer_context_t * buffer_context);
+
+/* backend */
+ggml_status apir_backend_graph_compute(struct virtgpu * gpu, ggml_cgraph * cgraph);
diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp
new file mode 100644
index 00000000..ce6b3b3e
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp
@@ -0,0 +1,98 @@
+#include "virtgpu-shm.h"
+
+#include "virtgpu.h"
+
+#include <sys/mman.h>
+
+static uint32_t virtgpu_ioctl_resource_create_blob(virtgpu *  gpu,
+                                                   uint32_t   blob_mem,
+                                                   uint32_t   blob_flags,
+                                                   size_t     blob_size,
+                                                   uint64_t   blob_id,
+                                                   uint32_t * res_id) {
+#ifdef SIMULATE_BO_SIZE_FIX
+    blob_size = align64(blob_size, 4096);
+#endif
+
+    drm_virtgpu_resource_create_blob args = {
+        .blob_mem   = blob_mem,
+        .blob_flags = blob_flags,
+        .bo_handle  = 0,
+        .res_handle = 0,
+        .size       = blob_size,
+        .pad        = 0,
+        .cmd_size   = 0,
+        .cmd        = 0,
+        .blob_id    = blob_id,
+    };
+
+    if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_RESOURCE_CREATE_BLOB, &args)) {
+        return 0;
+    }
+
+    *res_id = args.res_handle;
+    return args.bo_handle;
+}
+
+static void virtgpu_ioctl_gem_close(virtgpu * gpu, uint32_t gem_handle) {
+    drm_gem_close args = {
+        .handle = gem_handle,
+        .pad    = 0,
+    };
+
+    const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_GEM_CLOSE, &args);
+    assert(!ret);
+#ifdef NDEBUG
+    UNUSED(ret);
+#endif
+}
+
+static void * virtgpu_ioctl_map(virtgpu * gpu, uint32_t gem_handle, size_t size) {
+    drm_virtgpu_map args = {
+        .offset = 0,
+        .handle = gem_handle,
+        .pad    = 0,
+    };
+
+    if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_MAP, &args)) {
+        return NULL;
+    }
+
+    void * ptr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, gpu->fd, args.offset);
+    if (ptr == MAP_FAILED) {
+        return NULL;
+    }
+
+    return ptr;
+}
+
+void virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem) {
+    munmap(shmem->mmap_ptr, shmem->mmap_size);
+    virtgpu_ioctl_gem_close(gpu, shmem->gem_handle);
+}
+
+int virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem) {
+    size = align64(size, 16384);
+
+    uint32_t res_id;
+    uint32_t gem_handle = virtgpu_ioctl_resource_create_blob(gpu, VIRTGPU_BLOB_MEM_HOST3D,
+                                                             VIRTGPU_BLOB_FLAG_USE_MAPPABLE, size, 0, &res_id);
+
+    if (!gem_handle) {
+        return 1;
+    }
+
+    void * ptr = virtgpu_ioctl_map(gpu, gem_handle, size);
+    if (!ptr) {
+        virtgpu_ioctl_gem_close(gpu, gem_handle);
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_ioctl_map failed\n", __func__);
+        return 1;
+    }
+
+    shmem->res_id     = res_id;
+    shmem->mmap_size  = size;
+    shmem->mmap_ptr   = ptr;
+    shmem->gem_handle = gem_handle;
+
+    return 0;
+}
diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.h b/ggml/src/ggml-virtgpu/virtgpu-shm.h
new file mode 100644
index 00000000..606860a0
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-shm.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "virtgpu-utils.h"
+
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+struct virtgpu;
+
+struct virtgpu_shmem {
+    uint32_t res_id;
+    size_t   mmap_size;
+    void *   mmap_ptr;
+
+    uint32_t gem_handle;
+};
+
+int  virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem);
+void virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem);
diff --git a/ggml/src/ggml-virtgpu/virtgpu-utils.cpp b/ggml/src/ggml-virtgpu/virtgpu-utils.cpp
new file mode 100644
index 00000000..8a2805e9
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-utils.cpp
@@ -0,0 +1,179 @@
+#include "virtgpu-utils.h"
+
+#include 
+#include 
+
+#include 
+
+#define NODE_ALLOC_ALIGN 64
+#define NODE_PTR_MASK    (~((uintptr_t) NODE_ALLOC_ALIGN - 1))
+#define NODE_LEVEL_MASK  ((uintptr_t) NODE_ALLOC_ALIGN - 1)
+#define NULL_NODE        0
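+
+// A node handle packs a 64-byte-aligned data pointer and the node's level into one
+// uintptr_t: the pointer lives in the high bits (NODE_PTR_MASK) and the level in the
+// low alignment bits (NODE_LEVEL_MASK).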
+
+#define os_malloc_aligned(_size, _align) _aligned_malloc(_size, _align)
+#define os_free_aligned(_ptr)            free(_ptr)
+#define p_atomic_cmpxchg(v, old, _new)   __sync_val_compare_and_swap((v), (old), (_new))
+
+static inline uint64_t util_logbase2_64(uint64_t n) {
+#if defined(HAVE___BUILTIN_CLZLL)
+    return ((sizeof(uint64_t) * 8 - 1) - __builtin_clzll(n | 1));
+#else
+    uint64_t pos = 0ull;
+    if (n >= 1ull << 32) {
+        n >>= 32;
+        pos += 32;
+    }
+    if (n >= 1ull << 16) {
+        n >>= 16;
+        pos += 16;
+    }
+    if (n >= 1ull << 8) {
+        n >>= 8;
+        pos += 8;
+    }
+    if (n >= 1ull << 4) {
+        n >>= 4;
+        pos += 4;
+    }
+    if (n >= 1ull << 2) {
+        n >>= 2;
+        pos += 2;
+    }
+    if (n >= 1ull << 1) {
+        pos += 1;
+    }
+    return pos;
+#endif
+}
+
+void util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size) {
+    memset(arr, 0, sizeof(*arr));
+    arr->elem_size      = elem_size;
+    arr->node_size_log2 = util_logbase2_64(node_size);
+    assert(node_size >= 2 && node_size == (1ull << arr->node_size_log2));
+}
+
+static inline void * os_malloc_aligned(size_t size, size_t alignment) {
+    void * ptr;
+    alignment = (alignment + sizeof(void *) - 1) & ~(sizeof(void *) - 1);
+    if (posix_memalign(&ptr, alignment, size) != 0) {
+        return NULL;
+    }
+    return ptr;
+}
+
+static inline void * _util_sparse_array_node_data(uintptr_t handle) {
+    return (void *) (handle & NODE_PTR_MASK);
+}
+
+static inline unsigned _util_sparse_array_node_level(uintptr_t handle) {
+    return handle & NODE_LEVEL_MASK;
+}
+
+static inline void _util_sparse_array_node_finish(util_sparse_array * arr, uintptr_t node) {
+    if (_util_sparse_array_node_level(node) > 0) {
+        uintptr_t * children  = (uintptr_t *) _util_sparse_array_node_data(node);
+        size_t      node_size = 1ull << arr->node_size_log2;
+        for (size_t i = 0; i < node_size; i++) {
+            if (children[i]) {
+                _util_sparse_array_node_finish(arr, children[i]);
+            }
+        }
+    }
+
+    os_free_aligned(_util_sparse_array_node_data(node));
+}
+
+static inline uintptr_t _util_sparse_array_node(void * data, unsigned level) {
+    assert(data != NULL);
+    assert(((uintptr_t) data & NODE_LEVEL_MASK) == 0);
+    assert((level & NODE_PTR_MASK) == 0);
+    return (uintptr_t) data | level;
+}
+
+inline uintptr_t _util_sparse_array_node_alloc(util_sparse_array * arr, unsigned level) {
+    size_t size;
+    if (level == 0) {
+        size = arr->elem_size << arr->node_size_log2;
+    } else {
+        size = sizeof(uintptr_t) << arr->node_size_log2;
+    }
+
+    void * data = os_malloc_aligned(size, NODE_ALLOC_ALIGN);
+    memset(data, 0, size);
+
+    return _util_sparse_array_node(data, level);
+}
+
+static inline uintptr_t _util_sparse_array_set_or_free_node(uintptr_t * node_ptr, uintptr_t cmp_node, uintptr_t node) {
+    uintptr_t prev_node = p_atomic_cmpxchg(node_ptr, cmp_node, node);
+
+    if (prev_node != cmp_node) {
+        /* We lost the race.  Free this one and return the one that was already
+         * allocated.
+         */
+        os_free_aligned(_util_sparse_array_node_data(node));
+        return prev_node;
+    } else {
+        return node;
+    }
+}
+
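+// Look up (and lazily allocate) the element at idx: the root is deepened until idx
+// fits, then interior nodes are allocated on demand, with compare-and-swap used so
+// that racing callers agree on a single tree.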
+void * util_sparse_array_get(util_sparse_array * arr, uint64_t idx) {
+    const unsigned node_size_log2 = arr->node_size_log2;
+    uintptr_t      root           = p_atomic_read(&arr->root);
+    if (unlikely(!root)) {
+        unsigned root_level = 0;
+        uint64_t idx_iter   = idx >> node_size_log2;
+        while (idx_iter) {
+            idx_iter >>= node_size_log2;
+            root_level++;
+        }
+        uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level);
+        root               = _util_sparse_array_set_or_free_node(&arr->root, NULL_NODE, new_root);
+    }
+
+    while (1) {
+        unsigned root_level = _util_sparse_array_node_level(root);
+        uint64_t root_idx   = idx >> (root_level * node_size_log2);
+        if (likely(root_idx < (1ull << node_size_log2))) {
+            break;
+        }
+
+        /* In this case, we have a root but its level is low enough that the
+         * requested index is out-of-bounds.
+         */
+        uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level + 1);
+
+        uintptr_t * new_root_children = (uintptr_t *) _util_sparse_array_node_data(new_root);
+        new_root_children[0]          = root;
+
+        /* We only add one at a time instead of the whole tree because it's
+         * easier to ensure correctness of both the tree building and the
+         * clean-up path.  Because we're only adding one node we never have to
+         * worry about trying to free multiple things without freeing the old
+         * things.
+         */
+        root = _util_sparse_array_set_or_free_node(&arr->root, root, new_root);
+    }
+
+    void *   node_data  = _util_sparse_array_node_data(root);
+    unsigned node_level = _util_sparse_array_node_level(root);
+    while (node_level > 0) {
+        uint64_t child_idx = (idx >> (node_level * node_size_log2)) & ((1ull << node_size_log2) - 1);
+
+        uintptr_t * children = (uintptr_t *) node_data;
+        uintptr_t   child    = p_atomic_read(&children[child_idx]);
+
+        if (unlikely(!child)) {
+            child = _util_sparse_array_node_alloc(arr, node_level - 1);
+            child = _util_sparse_array_set_or_free_node(&children[child_idx], NULL_NODE, child);
+        }
+
+        node_data  = _util_sparse_array_node_data(child);
+        node_level = _util_sparse_array_node_level(child);
+    }
+
+    uint64_t elem_idx = idx & ((1ull << node_size_log2) - 1);
+    return (void *) ((char *) node_data + (elem_idx * arr->elem_size));
+}
diff --git a/ggml/src/ggml-virtgpu/virtgpu-utils.h b/ggml/src/ggml-virtgpu/virtgpu-utils.h
new file mode 100644
index 00000000..a0036b4e
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu-utils.h
@@ -0,0 +1,86 @@
+#pragma once
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define unlikely(x) __builtin_expect(!!(x), 0)
+#define likely(x)   __builtin_expect(!!(x), 1)
+
+#ifndef UNUSED
+#    define UNUSED(x) (void) (x)
+#endif
+
+/** Checks if a value is a power of two. Does not handle zero. */
+#define IS_POT(v) (((v) & ((v) - 1)) == 0)
+
+/** Checks if a value is a power of two. Zero is handled. */
+#define IS_POT_NONZERO(v) ((v) != 0 && IS_POT(v))
+
+/** Align a value to a power of two */
+#define ALIGN_POT(x, pot_align) (((x) + (pot_align) - 1) & ~((pot_align) - 1))
+
+#define p_atomic_read(_v) __atomic_load_n((_v), __ATOMIC_ACQUIRE)
+
+static inline bool util_is_power_of_two_nonzero64(uint64_t v) {
+    return IS_POT_NONZERO(v);
+}
+
+static inline uint64_t align64(uint64_t value, uint64_t alignment) {
+    assert(util_is_power_of_two_nonzero64(alignment));
+    return ALIGN_POT(value, alignment);
+}
+
+struct list_head {
+    list_head * prev;
+    list_head * next;
+};
+
+struct util_sparse_array {
+    size_t   elem_size;
+    unsigned node_size_log2;
+
+    uintptr_t root;
+};
+
+void * util_sparse_array_get(util_sparse_array * arr, uint64_t idx);
+void   util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size);
+
+inline void os_time_sleep(int64_t usecs) {
+    timespec time;
+    time.tv_sec  = usecs / 1000000;
+    time.tv_nsec = (usecs % 1000000) * 1000;
+    while (clock_nanosleep(CLOCK_MONOTONIC, 0, &time, &time) == EINTR)
+        ;
+}
+
+struct timer_data {
+    long long start;
+    long long total;
+    long long count;
+};
+
+static inline void start_timer(timer_data * timer) {
+    timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    timer->start = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec;
+}
+
+// returns the duration in ns
+static inline long long stop_timer(timer_data * timer) {
+    timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    long long timer_end = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec;
+
+    long long duration = (timer_end - timer->start);
+    timer->total += duration;
+    timer->count += 1;
+
+    return duration;
+}
diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp
new file mode 100644
index 00000000..a84a7739
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu.cpp
@@ -0,0 +1,544 @@
+#include "virtgpu.h"
+
+#include 
+#include 
+
+#include 
+#include 
+#include 
+
+static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev);
+static virt_gpu_result_t virtgpu_open(virtgpu * gpu);
+
+static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu);
+static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu);
+
+static int      virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id);
+static int      virtgpu_ioctl_get_caps(virtgpu *             gpu,
+                                       virgl_renderer_capset id,
+                                       uint32_t              version,
+                                       void *                capset,
+                                       size_t                capset_size);
+static uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param);
+static void     virtgpu_init_renderer_info(virtgpu * gpu);
+
+static void log_call_duration(long long call_duration_ns, const char * name);
+
+const uint64_t APIR_HANDSHAKE_MAX_WAIT_MS   = 2 * 1000;   // 2s
+const uint64_t APIR_LOADLIBRARY_MAX_WAIT_MS = 60 * 1000;  // 60s
+
+static int virtgpu_handshake(virtgpu * gpu) {
+    apir_encoder * encoder;
+    apir_decoder * decoder;
+
+    encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_HANDSHAKE, 0);
+    if (!encoder) {
+        GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__);
+        return 1;
+    }
+
+    /* write handshake props */
+
+    uint32_t guest_major = APIR_PROTOCOL_MAJOR;
+    uint32_t guest_minor = APIR_PROTOCOL_MINOR;
+    apir_encode_uint32_t(encoder, &guest_major);
+    apir_encode_uint32_t(encoder, &guest_minor);
+
+    /* *** */
+
+    uint32_t  ret_magic;
+    long long call_duration_ns;
+    ret_magic = remote_call(gpu, encoder, &decoder, APIR_HANDSHAKE_MAX_WAIT_MS, &call_duration_ns);
+    log_call_duration(call_duration_ns, "API Remoting handshake");
+
+    if (!decoder) {
+        GGML_ABORT(GGML_VIRTGPU
+                   "%s: failed to initiate the communication with the virglrenderer library. "
+                   "Most likely, the wrong virglrenderer library was loaded in the hypervisor.",
+                   __func__);
+        return 1;
+    }
+
+    /* read handshake return values */
+
+    uint32_t host_major;
+    uint32_t host_minor;
+
+    if (ret_magic != APIR_HANDSHAKE_MAGIC) {
+        GGML_ABORT(GGML_VIRTGPU "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic,
+                   apir_backend_initialize_error(ret_magic));
+    } else {
+        apir_decode_uint32_t(decoder, &host_major);
+        apir_decode_uint32_t(decoder, &host_minor);
+    }
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    if (ret_magic != APIR_HANDSHAKE_MAGIC) {
+        return 1;
+    }
+
+    GGML_LOG_INFO(GGML_VIRTGPU "%s: Guest is running with %u.%u\n", __func__, guest_major, guest_minor);
+    GGML_LOG_INFO(GGML_VIRTGPU "%s: Host is running with %u.%u\n", __func__, host_major, host_minor);
+
+    if (guest_major != host_major) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "Host major (%d) and guest major (%d) version differ\n", host_major, guest_major);
+    } else if (guest_minor != host_minor) {
+        GGML_LOG_WARN(GGML_VIRTGPU "Host minor (%d) and guest minor (%d) version differ\n", host_minor, guest_minor);
+    }
+
+    return 0;
+}
+
+static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) {
+    apir_encoder *            encoder;
+    apir_decoder *            decoder;
+    ApirLoadLibraryReturnCode ret;
+
+    encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_LOADLIBRARY, 0);
+    if (!encoder) {
+        GGML_ABORT(GGML_VIRTGPU "%s: hypercall error: failed to prepare the API Remoting command encoder", __func__);
+        return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR;
+    }
+
+    long long call_duration_ns;
+
+    ret = (ApirLoadLibraryReturnCode) remote_call(gpu, encoder, &decoder, APIR_LOADLIBRARY_MAX_WAIT_MS,
+                                                  &call_duration_ns);
+    log_call_duration(call_duration_ns, "API Remoting LoadLibrary");
+
+    if (!decoder) {
+        GGML_ABORT(GGML_VIRTGPU "%s: hypercall error: failed to trigger the API Remoting hypercall.\n", __func__);
+        return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR;
+    }
+
+    remote_call_finish(gpu, encoder, decoder);
+
+    if (ret == APIR_LOAD_LIBRARY_SUCCESS) {
+        GGML_LOG_INFO(GGML_VIRTGPU "The API Remoting backend was successfully loaded and initialized\n");
+
+        return ret;
+    }
+
+    // something wrong happened, find out what.
+    if (ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) {
+        if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) {
+            GGML_ABORT(GGML_VIRTGPU
+                       "%s: virglrenderer could not open the API Remoting backend library, "
+                       "some environment variables are missing. "
+                       "Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
+                       __func__, apir_load_library_error(ret));
+        } else if (ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) {
+            GGML_ABORT(GGML_VIRTGPU
+                       "%s: virglrenderer could not open the API Remoting backend library. "
+                       "Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
+                       __func__, apir_load_library_error(ret));
+        } else if (ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) {
+            GGML_ABORT(GGML_VIRTGPU
+                       "%s: could not load the backend library, some symbols are missing. "
+                       "Make sure virglrenderer is correctly configured by the hypervisor. (%s) ",
+                       __func__, apir_load_library_error(ret));
+        } else {
+            GGML_ABORT(GGML_VIRTGPU "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)",
+                       __func__, apir_load_library_error(ret), ret);
+        }
+        return ret;
+    }
+
+    GGML_LOG_INFO(GGML_VIRTGPU "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__);
+
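+    // Codes below APIR_LOAD_LIBRARY_INIT_BASE_INDEX were produced by virglrenderer itself
+    // (handled above); codes at or above it are offset by that base and report how the
+    // API Remoting backend fared while loading its GGML backend library.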
+    ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX);
+
+    if (apir_ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) {
+        GGML_ABORT(GGML_VIRTGPU
+                   "%s: the API Remoting backend library couldn't load the GGML backend library. "
+                   "Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
+                   __func__, apir_load_library_error(apir_ret));
+    } else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) {
+        GGML_ABORT(
+            GGML_VIRTGPU
+            "%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. "
+            "Make sure virglrenderer is correctly configured by the hypervisor. (%s)",
+            __func__, apir_load_library_error(apir_ret));
+    } else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) {
+        GGML_ABORT(GGML_VIRTGPU
+                   "%s: the API Remoting backend library couldn't load the GGML backend library (apir code=%d | %s)",
+                   __func__, apir_ret, apir_load_library_error(apir_ret));
+    } else {
+        uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX;
+        GGML_ABORT(GGML_VIRTGPU
+                   "%s: the API Remoting backend library failed to initialize its backend library (apir code=%d)",
+                   __func__, lib_ret);
+    }
+    return ret;
+}
+
+virtgpu * create_virtgpu() {
+    virtgpu * gpu = new virtgpu();
+
+    gpu->use_apir_capset = getenv("GGML_REMOTING_USE_APIR_CAPSET") != nullptr;
+    util_sparse_array_init(&gpu->shmem_array, sizeof(virtgpu_shmem), 1024);
+
+    // Initialize mutex to protect shared data_shmem buffer
+    if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) {
+        delete gpu;
+        GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize data_shmem mutex", __func__);
+        return NULL;
+    }
+
+    if (virtgpu_open(gpu) != APIR_SUCCESS) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to open the virtgpu device\n", __func__);
+        return NULL;
+    }
+
+    if (virtgpu_init_capset(gpu) != APIR_SUCCESS) {
+        if (gpu->use_apir_capset) {
+            GGML_ABORT(GGML_VIRTGPU
+                       "%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library "
+                       "supports it.",
+                       __func__);
+        } else {
+            GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu Venus capset", __func__);
+        }
+        return NULL;
+    }
+
+    if (virtgpu_init_context(gpu) != APIR_SUCCESS) {
+        GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the GPU context", __func__);
+        return NULL;
+    }
+
+    if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) {
+        GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared reply memory pages", __func__);
+        return NULL;
+    }
+
+    if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) {
+        GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared data memory pages", __func__);
+        return NULL;
+    }
+
+    if (virtgpu_handshake(gpu)) {
+        GGML_ABORT(GGML_VIRTGPU "%s: failed to handshake with the virglrenderer library", __func__);
+        return NULL;
+    }
+
+    if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) {
+        GGML_ABORT(GGML_VIRTGPU "%s: failed to load the backend library", __func__);
+        return NULL;
+    }
+
+    return gpu;
+}
+
+static virt_gpu_result_t virtgpu_open(virtgpu * gpu) {
+    drmDevicePtr devs[8];
+    int          count = drmGetDevices2(0, devs, ARRAY_SIZE(devs));
+    if (count < 0) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to enumerate DRM devices\n", __func__);
+        return APIR_ERROR_INITIALIZATION_FAILED;
+    }
+
+    virt_gpu_result_t result = APIR_ERROR_INITIALIZATION_FAILED;
+    for (int i = 0; i < count; i++) {
+        result = virtgpu_open_device(gpu, devs[i]);
+        if (result == APIR_SUCCESS) {
+            break;
+        }
+    }
+
+    drmFreeDevices(devs, count);
+
+    return result;
+}
+
+static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev) {
+    const char * node_path = dev->nodes[DRM_NODE_RENDER];
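+    // DRM_NODE_RENDER is the render node (/dev/dri/renderD*), which can be opened
+    // without DRM master privileges.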
+
+    int fd = open(node_path, O_RDWR | O_CLOEXEC);
+    if (fd < 0) {
+        GGML_ABORT(GGML_VIRTGPU "%s: failed to open %s", __func__, node_path);
+        return APIR_ERROR_INITIALIZATION_FAILED;
+    }
+
+    drmVersionPtr version = drmGetVersion(fd);
+    if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) {
+        if (version) {
+            GGML_LOG_ERROR(GGML_VIRTGPU "%s: unknown DRM driver %s version %d\n", __func__, version->name,
+                           version->version_major);
+        } else {
+            GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get DRM driver version\n", __func__);
+        }
+
+        if (version) {
+            drmFreeVersion(version);
+        }
+        close(fd);
+        return APIR_ERROR_INITIALIZATION_FAILED;
+    }
+
+    gpu->fd = fd;
+
+    drmFreeVersion(version);
+
+    GGML_LOG_INFO(GGML_VIRTGPU "using DRM device %s\n", node_path);
+
+    return APIR_SUCCESS;
+}
+
+static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu) {
+    assert(!gpu->capset.version);
+    const int ret = virtgpu_ioctl_context_init(gpu, gpu->capset.id);
+    if (ret) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to initialize context: %s\n", __func__, strerror(errno));
+        return APIR_ERROR_INITIALIZATION_FAILED;
+    }
+
+    return APIR_SUCCESS;
+}
+
+static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) {
+    if (gpu->use_apir_capset) {
+        GGML_LOG_INFO(GGML_VIRTGPU "Using the APIR capset\n");
+        gpu->capset.id = VIRTGPU_DRM_CAPSET_APIR;
+    } else {
+        GGML_LOG_INFO(GGML_VIRTGPU "Using the Venus capset\n");
+        gpu->capset.id = VIRTGPU_DRM_CAPSET_VENUS;
+    }
+    gpu->capset.version = 0;
+
+    int ret =
+        virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data));
+
+    if (ret) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get the capset (id %u, version %u): %s\n", __func__,
+                       gpu->capset.id, gpu->capset.version, strerror(errno));
+        return APIR_ERROR_INITIALIZATION_FAILED;
+    }
+
+    assert(gpu->capset.data.supports_blob_resources);
+
+    return APIR_SUCCESS;
+}
+
+static int virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id) {
+    drm_virtgpu_context_set_param ctx_set_params[3] = {
+        {
+         .param = VIRTGPU_CONTEXT_PARAM_CAPSET_ID,
+         .value = capset_id,
+         },
+        {
+         .param = VIRTGPU_CONTEXT_PARAM_NUM_RINGS,
+         .value = 1,
+         },
+        {
+         .param = VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK,
+         .value = 0, /* don't generate drm_events on fence signaling */
+        },
+    };
+
+    drm_virtgpu_context_init args = {
+        .num_params     = ARRAY_SIZE(ctx_set_params),
+        .pad            = 0,
+        .ctx_set_params = (uintptr_t) &ctx_set_params,
+    };
+
+    return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_CONTEXT_INIT, &args);
+}
+
+static int virtgpu_ioctl_get_caps(virtgpu *             gpu,
+                                  virgl_renderer_capset id,
+                                  uint32_t              version,
+                                  void *                capset,
+                                  size_t                capset_size) {
+    drm_virtgpu_get_caps args = {
+        .cap_set_id  = id,
+        .cap_set_ver = version,
+        .addr        = (uintptr_t) capset,
+        .size        = (__u32) capset_size,
+        .pad         = 0,
+    };
+
+    return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GET_CAPS, &args);
+}
+
+static uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param) {
+    /* val must be zeroed because kernel only writes the lower 32 bits */
+    uint64_t             val  = 0;
+    drm_virtgpu_getparam args = {
+        .param = param,
+        .value = (uintptr_t) &val,
+    };
+
+    const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GETPARAM, &args);
+    return ret ? 0 : val;
+}
+
+apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags) {
+    /*
+     * Prepare the command encoder and its buffer
+     */
+
+    thread_local char encoder_buffer[4096];
+
+    thread_local apir_encoder enc;
+    enc = {
+        .cur   = encoder_buffer,
+        .start = encoder_buffer,
+        .end   = encoder_buffer + sizeof(encoder_buffer),
+        .fatal = false,
+    };
+
+    /*
+     * Fill the command encoder with the common args:
+     * - cmd_type (int32_t)
+     * - cmd_flags (int32_t)
+     * - reply res id (uint32_t)
+     */
+
+    int32_t cmd_type = apir_cmd_type;
+
+    // for testing during the hypervisor transition
+    if (!gpu->use_apir_capset) {
+        cmd_type += VENUS_COMMAND_TYPE_LENGTH;
+    }
+    apir_encode_int32_t(&enc, &cmd_type);
+    apir_encode_int32_t(&enc, &cmd_flags);
+
+    uint32_t reply_res_id = gpu->reply_shmem.res_id;
+    apir_encode_uint32_t(&enc, &reply_res_id);
+
+    return &enc;
+}
+
+void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec) {
+    UNUSED(gpu);
+
+    if (!enc) {
+        GGML_ABORT(GGML_VIRTGPU "%s: Invalid (null) encoder", __func__);
+    }
+
+    if (!dec) {
+        GGML_ABORT(GGML_VIRTGPU "%s: Invalid (null) decoder", __func__);
+    }
+
+    if (apir_encoder_get_fatal(enc)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: Failed to encode the output parameters.", __func__);
+    }
+
+    if (apir_decoder_get_fatal(dec)) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: Failed to decode the input parameters.", __func__);
+    }
+}
+
+uint32_t remote_call(virtgpu *       gpu,
+                     apir_encoder *  encoder,
+                     apir_decoder ** decoder,
+                     float           max_wait_ms,
+                     long long *     call_duration_ns) {
+    /*
+     * Prepare the reply notification pointer
+     */
+
+    volatile std::atomic_uint * atomic_reply_notif = (volatile std::atomic_uint *) gpu->reply_shmem.mmap_ptr;
+    *atomic_reply_notif                            = 0;
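+
+    // The first word of the reply page acts as the completion flag: the host is expected
+    // to store (status + 1) there once the reply is ready, so a non-zero value both
+    // signals completion and carries the status decoded below.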
+
+    /*
+     * Trigger the execbuf ioctl
+     */
+
+    drm_virtgpu_execbuffer args = {
+        .flags   = VIRTGPU_EXECBUF_RING_IDX,
+        .size    = (uint32_t) (encoder->cur - encoder->start),
+        .command = (uintptr_t) encoder->start,
+
+        .bo_handles     = 0,
+        .num_bo_handles = 0,
+
+        .fence_fd         = 0,
+        .ring_idx         = 0,
+        .syncobj_stride   = 0,
+        .num_in_syncobjs  = 0,
+        .num_out_syncobjs = 0,
+        .in_syncobjs      = 0,
+        .out_syncobjs     = 0,
+    };
+
+    *decoder = NULL;
+
+    int ret = drmIoctl(gpu->fd, DRM_IOCTL_VIRTGPU_EXECBUFFER, &args);
+
+    if (ret != 0) {
+        GGML_ABORT(GGML_VIRTGPU "%s: the virtgpu EXECBUFFER ioctl failed (%d)", __func__, ret);
+    }
+
+    /*
+     * Wait for the response notification
+     */
+    timer_data wait_host_reply_timer = { 0, 0, 0 };
+
+    start_timer(&wait_host_reply_timer);
+
+    timespec ts_start, ts_end;
+    clock_gettime(CLOCK_MONOTONIC, &ts_start);
+    long long start_time = (long long) ts_start.tv_sec * 1000000000LL + ts_start.tv_nsec;
+
+    bool     timedout    = false;
+    uint32_t notif_value = 0;
+    while (true) {
+        notif_value = std::atomic_load_explicit(atomic_reply_notif, std::memory_order_acquire);
+
+        if (notif_value != 0) {
+            break;
+        }
+
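+        // no reply yet: back off briefly before polling the notification word again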
+        int64_t base_sleep_us = 15;
+
+        os_time_sleep(base_sleep_us);
+
+        if (max_wait_ms) {
+            clock_gettime(CLOCK_MONOTONIC, &ts_end);
+            long long end_time    = (long long) ts_end.tv_sec * 1000000000LL + ts_end.tv_nsec;
+            float     duration_ms = (end_time - start_time) / 1e6f;
+
+            if (duration_ms > max_wait_ms) {
+                timedout = true;
+                break;
+            }
+        }
+    }
+
+    if (call_duration_ns) {
+        *call_duration_ns = stop_timer(&wait_host_reply_timer);
+    }
+
+    if (max_wait_ms && timedout) {
+        GGML_LOG_ERROR(GGML_VIRTGPU "%s: timed out waiting for the host answer...\n", __func__);
+        return APIR_FORWARD_TIMEOUT;
+    }
+
+    /*
+     * Prepare the decoder
+     */
+    static apir_decoder response_dec;
+    response_dec.cur = (char *) gpu->reply_shmem.mmap_ptr + sizeof(*atomic_reply_notif);
+    response_dec.end = (char *) gpu->reply_shmem.mmap_ptr + gpu->reply_shmem.mmap_size;
+    *decoder         = &response_dec;
+
+    // extract the actual return value from the notif flag
+    uint32_t returned_value = notif_value - 1;
+    return returned_value;
+}
+
+static void log_call_duration(long long call_duration_ns, const char * name) {
+    double call_duration_ms = (double) call_duration_ns / 1e6;  // 1 millisecond = 1e6 nanoseconds
+    double call_duration_s  = (double) call_duration_ns / 1e9;  // 1 second = 1e9 nanoseconds
+
+    if (call_duration_s > 1) {
+        GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fs for the %s host reply...\n", call_duration_s, name);
+    } else if (call_duration_ms > 1) {
+        GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fms for the %s host reply...\n", call_duration_ms, name);
+    } else {
+        GGML_LOG_INFO(GGML_VIRTGPU "waited %lldns for the %s host reply...\n", call_duration_ns, name);
+    }
+}
diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h
new file mode 100644
index 00000000..f82d8fb5
--- /dev/null
+++ b/ggml/src/ggml-virtgpu/virtgpu.h
@@ -0,0 +1,117 @@
+#pragma once
+
+// clang-format off
+#include "virtgpu-utils.h"
+#include "virtgpu-shm.h"
+#include "virtgpu-apir.h"
+
+#include "backend/shared/api_remoting.h"
+#include "backend/shared/apir_cs.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+#include "ggml-remoting.h"
+
+#define VIRGL_RENDERER_UNSTABLE_APIS 1
+#include "apir_hw.h"
+#include 
+#include "venus_hw.h"
+// clang-format on
+
+#ifndef VIRTGPU_DRM_CAPSET_APIR
+// Will be defined in include/drm/virtgpu_drm.h once
+// https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs
+// is merged
+#    define VIRTGPU_DRM_CAPSET_APIR 10
+#endif
+
+// Mesa/virglrenderer Venus internal. Only necessary during the
+// Venus->APIR transition in virglrenderer.
+#define VENUS_COMMAND_TYPE_LENGTH 331
+
+#ifndef VIRTGPU_DRM_CAPSET_VENUS  // only available with Linux >= v6.16
+#    define VIRTGPU_DRM_CAPSET_VENUS 4
+#endif
+
+typedef uint32_t virgl_renderer_capset;
+
+/* from src/virtio/vulkan/vn_renderer_virtgpu.c */
+#define VIRTGPU_PCI_VENDOR_ID       0x1af4
+#define VIRTGPU_PCI_DEVICE_ID       0x1050
+#define VIRTGPU_BLOB_MEM_GUEST_VRAM 0x0004
+#define VIRTGPU_PARAM_GUEST_VRAM    9
+
+#define SHMEM_DATA_SIZE  0x1830000  // ~24 MiB
+#define SHMEM_REPLY_SIZE 0x4000
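+
+// reply_shmem holds the per-call reply header; data_shmem is the larger shared data
+// buffer, guarded by virtgpu::data_shmem_mutex.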
+
+#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
+
+enum virt_gpu_result_t {
+    APIR_SUCCESS                     = 0,
+    APIR_ERROR_INITIALIZATION_FAILED = -1,
+};
+
+#define PRINTFLIKE(f, a) __attribute__((format(__printf__, f, a)))
+
+struct virtgpu {
+    bool use_apir_capset;
+
+    int fd;
+
+    struct {
+        virgl_renderer_capset      id;
+        uint32_t                   version;
+        virgl_renderer_capset_apir data;
+    } capset;
+
+    util_sparse_array shmem_array;
+
+    /* APIR communication pages */
+    virtgpu_shmem reply_shmem;
+    virtgpu_shmem data_shmem;
+
+    /* Mutex to protect shared data_shmem buffer from concurrent access */
+    mtx_t data_shmem_mutex;
+
+    /* Cached device information to prevent memory leaks and race conditions */
+    struct {
+        char *   description;
+        char *   name;
+        int32_t  device_count;
+        uint32_t type;
+        size_t   memory_free;
+        size_t   memory_total;
+    } cached_device_info;
+
+    /* Cached buffer type information to prevent memory leaks and race conditions */
+    struct {
+        apir_buffer_type_host_handle_t host_handle;
+        char *                         name;
+        size_t                         alignment;
+        size_t                         max_size;
+    } cached_buffer_type;
+};
+
+static inline int virtgpu_ioctl(virtgpu * gpu, unsigned long request, void * args) {
+    return drmIoctl(gpu->fd, request, args);
+}
+
+virtgpu * create_virtgpu();
+
+apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags);
+
+uint32_t remote_call(virtgpu *       gpu,
+                     apir_encoder *  enc,
+                     apir_decoder ** dec,
+                     float           max_wait_ms,
+                     long long *     call_duration_ns);
+
+void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec);
diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt
index de01336c..715a263a 100644
--- a/ggml/src/ggml-vulkan/CMakeLists.txt
+++ b/ggml/src/ggml-vulkan/CMakeLists.txt
@@ -90,7 +90,7 @@ if (Vulkan_FOUND)
     target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
 
     # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
-    # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
+    # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
     if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
         add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
     endif()
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index deed5055..7092361d 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -27,6 +27,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
 #include 
 #include 
 #include 
+#include <deque>
 #include 
 #include 
 #include 
@@ -92,6 +93,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 #define VK_VENDOR_ID_APPLE 0x106b
 #define VK_VENDOR_ID_INTEL 0x8086
 #define VK_VENDOR_ID_NVIDIA 0x10de
+#define VK_VENDOR_ID_QUALCOMM 0x5143
 
 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
 
@@ -187,6 +189,11 @@ struct ggml_backend_vk_buffer_type_context {
 
 struct vk_queue;
 
+struct vk_command_buffer {
+    vk::CommandBuffer buf;
+    bool in_use = false;
+};
+
 // Stores command pool/buffers. There's an instance of this
 // for each (context,queue) pair and for each (device,queue) pair.
 struct vk_command_pool {
@@ -194,10 +201,16 @@ struct vk_command_pool {
     void destroy(vk::Device& device);
 
     vk::CommandPool pool;
-    uint32_t cmd_buffer_idx;
-    std::vector cmd_buffers;
+    // Using deque so the pointers to command buffers
+    // remain valid even if we add more
+    std::deque cmd_buffers;
 
     vk_queue *q;
+
+    size_t buffers_in_use() const {
+        return std::count_if(cmd_buffers.begin(), cmd_buffers.end(),
+            [](const auto& cb) { return cb.in_use; });
+    }
 };
 
 // Prevent simultaneous submissions to the same queue.
@@ -254,6 +267,7 @@ enum vk_device_architecture {
     AMD_RDNA3,
     INTEL_XE2,
     NVIDIA_PRE_TURING,
+    NVIDIA_TURING,
 };
 
 static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
@@ -336,18 +350,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
         const std::vector ext_props = device.enumerateDeviceExtensionProperties();
 
         bool cooperative_matrix = false;
+        bool sm_builtins = false;
 
         // Detect "pre-turing" based on lack of coopmat support.
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
                 cooperative_matrix = true;
-                break;
+            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
+                sm_builtins = true;
             }
         }
 
         if (!cooperative_matrix) {
             return vk_device_architecture::NVIDIA_PRE_TURING;
         }
+
+        if (sm_builtins) {
+            vk::PhysicalDeviceProperties2 props2;
+            vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
+
+            props2.pNext = &sm_props;
+
+            device.getProperties2(&props2);
+
+            // Turing has 32 warps per SM, later architectures have 48
+            if (sm_props.shaderWarpsPerSM == 32) {
+                return vk_device_architecture::NVIDIA_TURING;
+            }
+        }
     }
     return vk_device_architecture::OTHER;
 }
@@ -385,18 +415,20 @@ enum FaCodePath {
 };
 
 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
-
     uint32_t HSK, HSV;
-    bool small_rows, small_cache;
+    uint32_t Br, Bc;
+    uint32_t D_split, row_split;
+    bool shmem_staging;
     FaCodePath path;
+    uint32_t workgroup_size, subgroup_size;
     bool aligned;
     bool f32acc;
+    uint32_t flags;
+    uint32_t limit_occupancy_shmem;
 
     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
+        return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
+               std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
     }
 };
 
@@ -570,6 +602,7 @@ struct vk_device_struct {
     vk_queue transfer_queue;
     bool single_queue;
     bool support_async;
+    bool async_use_transfer_queue;
     uint32_t subgroup_size;
     uint32_t subgroup_size_log2;
     uint32_t shader_core_count;
@@ -669,6 +702,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
+    vk_pipeline pipeline_set_f32;
 
     // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
     vk_pipeline pipeline_add[2][2][2];
@@ -722,6 +756,7 @@ struct vk_device_struct {
 
     // [src/dst 0=fp32,1=fp16]
     vk_pipeline pipeline_exp[2];
+    vk_pipeline pipeline_elu[2];
     vk_pipeline pipeline_gelu[2];
     vk_pipeline pipeline_gelu_erf[2];
     vk_pipeline pipeline_gelu_quick[2];
@@ -740,6 +775,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_ceil[2];
     vk_pipeline pipeline_floor[2];
     vk_pipeline pipeline_trunc[2];
+    vk_pipeline pipeline_sgn[2];
 
     vk_pipeline pipeline_add1_f16_f16;
     vk_pipeline pipeline_add1_f16_f32;
@@ -789,6 +825,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
+    // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
+    vk_pipeline pipeline_gated_delta_net[3][2];
     vk_pipeline pipeline_ssm_scan_f32_d128;
     vk_pipeline pipeline_ssm_scan_f32_d256;
     vk_pipeline pipeline_ssm_conv_f32;
@@ -803,6 +841,8 @@ struct vk_device_struct {
 
     std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
 
+    std::map, vk_pipeline> pipeline_fa_mask_opt;
+
     vk_pipeline pipeline_flash_attn_split_k_reduce;
     vk_pipeline pipeline_count_experts;
 
@@ -852,10 +892,12 @@ struct vk_device_struct {
 };
 
 void vk_command_pool::init(vk_device& device, vk_queue *q_) {
-    cmd_buffer_idx = 0;
+    cmd_buffers.clear();
     q = q_;
 
-    vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index);
+    vk::CommandPoolCreateInfo command_pool_create_info(
+        vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT | VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT),
+        q->queue_family_index);
     pool = device->device.createCommandPool(command_pool_create_info);
 }
 
@@ -903,6 +945,7 @@ struct vk_subbuffer {
 struct vk_event {
     vk::Event event;
     vk::Fence fence;
+    vk_command_buffer* cmd_buffer = nullptr;
 };
 
 struct vk_semaphore {
@@ -911,7 +954,7 @@ struct vk_semaphore {
 };
 
 struct vk_submission {
-    vk::CommandBuffer buffer;
+    vk_command_buffer* buffer = nullptr;
     std::vector wait_semaphores;
     std::vector signal_semaphores;
 };
@@ -922,6 +965,7 @@ struct vk_mat_mat_push_constants {
     uint32_t M; uint32_t N; uint32_t K;
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
+    uint32_t base_work_group_z; uint32_t num_batches;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
     uint32_t padded_N;
@@ -941,6 +985,7 @@ struct vk_mat_vec_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t fusion_flags;
+    uint32_t base_work_group_y;
     uint32_t ne02;
     uint32_t ne12;
     uint32_t broadcast2;
@@ -991,6 +1036,8 @@ struct vk_mat_vec_id_push_constants {
     uint32_t fusion_flags;
     uint32_t nei0;
     uint32_t ne11;
+    uint32_t expert_i1;
+    uint32_t nbi1;
 };
 
 struct vk_flash_attn_push_constants {
@@ -1244,25 +1291,30 @@ struct vk_op_diag_mask_push_constants {
 
 struct vk_op_rope_push_constants {
     uint32_t rope_mode;
-    uint32_t ncols;
     uint32_t nrows;
     uint32_t n_dims;
     float freq_scale;
-    uint32_t p_delta_rows;
     float freq_base;
     float ext_factor;
     float attn_factor;
     float corr_dims[2];
     float theta_scale;
     uint32_t has_ff;
-    uint32_t ne02;
-    uint32_t s1;
-    uint32_t s2;
     int32_t sections[4];
     uint32_t is_imrope;
     uint32_t is_back;
     uint32_t set_rows_stride;
+    uint32_t ne00;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
 };
+static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
 
 // For fused rms_norm+mul+rope(+view+set_rows)
 struct vk_op_rms_norm_mul_rope_push_constants {
@@ -1404,6 +1456,18 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t C;
     uint32_t H;
 };
+struct vk_op_gated_delta_net_push_constants {
+    uint32_t H;
+    uint32_t n_tokens;
+    uint32_t n_seqs;
+    uint32_t s_off;
+    uint32_t sq1, sq2, sq3;
+    uint32_t sv1, sv2, sv3;
+    uint32_t sb1, sb2, sb3;
+    uint32_t neq1, rq3;
+    float scale;
+};
+
 struct vk_op_ssm_scan_push_constants {
     uint32_t nb02, nb03, nb12, nb13;
     uint32_t nb21, nb22, nb31;
@@ -1516,6 +1580,27 @@ struct vk_quantize_q8_1_push_constants {
     uint32_t num_blocks;
 };
 
+struct vk_op_flash_attn_split_k_reduce_push_constants {
+    uint32_t D;
+    uint32_t ne1;
+    uint32_t ne2;
+    uint32_t ne3;
+    uint32_t k_num;
+    uint32_t sinks;
+};
+
+struct vk_op_flash_attn_mask_opt_push_constants {
+    uint32_t nem0;
+    uint32_t nem1;
+    uint32_t nem2;
+    uint32_t nbm1;
+    uint32_t nbm2;
+    uint32_t nbm3;
+    uint32_t nbd1;
+    uint32_t nbd2;
+    uint32_t nbd3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1604,6 +1689,7 @@ static bool vk_perf_logger_concurrent = false;
 static bool vk_enable_sync_logger = false;
 // number of calls between perf logger prints
 static uint32_t vk_perf_logger_frequency = 1;
+static std::string vk_pipeline_stats_filter;
 
 class vk_perf_logger {
   public:
@@ -1724,6 +1810,7 @@ class vk_perf_logger {
                 " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
                 " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
                 " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
+            *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
             return name.str();
         }
         if (node->op == GGML_OP_TOP_K) {
@@ -1802,7 +1889,10 @@ struct ggml_backend_vk_context {
     bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
 
     vk_context_ref compute_ctx;
+
     vk_context_ref transfer_ctx;
+    vk_semaphore transfer_semaphore;
+    uint64_t transfer_semaphore_last_submitted {};
 
     std::vector tensor_ctxs;
 
@@ -2121,7 +2211,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         executableInfo.pipeline = pipeline->pipeline;
 
         auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
+
+        bool print_stats = !vk_pipeline_stats_filter.empty() &&
+                           pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos;
+        if (print_stats) {
+            std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl;
+        }
+
         for (auto & s : statistics) {
+            if (print_stats) {
+                std::cerr << "ggml_vulkan:   " << s.name.data() << ": ";
+                switch (s.format) {
+                    case vk::PipelineExecutableStatisticFormatKHR::eBool32:
+                        std::cerr << (s.value.b32 ? "true" : "false");
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eInt64:
+                        std::cerr << s.value.i64;
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eUint64:
+                        std::cerr << s.value.u64;
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eFloat64:
+                        std::cerr << s.value.f64;
+                        break;
+                }
+                std::cerr << std::endl;
+            }
             // "Register Count" is reported by NVIDIA drivers.
             if (strcmp(s.name, "Register Count") == 0) {
                 VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
@@ -2197,25 +2312,15 @@ static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx
     }
 }
 
-static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
+static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
     VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
-
-    if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
-        // Reuse command buffer
-        return p.cmd_buffers[p.cmd_buffer_idx++];
-    }
-
     vk::CommandBufferAllocateInfo command_buffer_alloc_info(
         p.pool,
         vk::CommandBufferLevel::ePrimary,
         1);
     const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
-    auto buf = cmd_buffers.front();
-
-    p.cmd_buffers.push_back(buf);
-    p.cmd_buffer_idx++;
-
-    return buf;
+    p.cmd_buffers.push_back({ cmd_buffers.front(), true });
+    return &p.cmd_buffers[p.cmd_buffers.size()-1];
 }
 
 static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
@@ -2282,7 +2387,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
                 tl_wait_semaphores[idx].data(),
                 stage_flags[idx].data(),
                 1,
-                &submission.buffer,
+                &submission.buffer->buf,
                 (uint32_t) submission.signal_semaphores.size(),
                 tl_signal_semaphores[idx].data(),
             };
@@ -2406,7 +2511,11 @@ static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p)
 
     // Requires command buffers to be done
     device->device.resetCommandPool(p.pool);
-    p.cmd_buffer_idx = 0;
+    // Don't clear the command buffers; just mark them as not in use
+    // so they can be reused later.
+    for (auto& cmd_buffer : p.cmd_buffers) {
+        cmd_buffer.in_use = false;
+    }
 }
 
 static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
@@ -2415,10 +2524,10 @@ static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
     // Arbitrary frequency to cleanup/reuse command buffers
     static constexpr uint32_t cleanup_frequency = 10;
 
-    if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
+    if (device->compute_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
         ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
     }
-    if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
+    if (device->transfer_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
         ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
     }
 }
@@ -2666,7 +2775,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
         ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
     }
 
-    subctx->s->buffer.pipelineBarrier(
+    subctx->s->buffer->buf.pipelineBarrier(
         subctx->p->q->stage_flags,
         subctx->p->q->stage_flags,
         {},
@@ -2682,7 +2791,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
 static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
     VK_LOG_DEBUG("ggml_vk_set_event()");
 
-    ctx->s->buffer.setEvent(
+    ctx->s->buffer->buf.setEvent(
         event,
         ctx->p->q->stage_flags
     );
@@ -2694,7 +2803,7 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events
         return;
     }
 
-    ctx->s->buffer.waitEvents(
+    ctx->s->buffer->buf.waitEvents(
         events,
         ctx->p->q->stage_flags,
         ctx->p->q->stage_flags,
@@ -2704,78 +2813,218 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events
     );
 }
 
-// number of rows/cols for flash attention shader
-static constexpr uint32_t flash_attention_num_small_rows = 32;
-static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
+struct vk_fa_tuning_params {
+    FaCodePath path;
+    uint32_t workgroup_size;
+    uint32_t subgroup_size;
+    uint32_t block_rows;
+    uint32_t block_cols;
+    uint32_t d_split;
+    uint32_t row_split;
+    bool shmem_staging;
+    bool disable_subgroups;
+    uint32_t limit_occupancy_shmem;
 
-static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
-    if (hsv >= 192) {
-        return 2;
-    } else if ((hsv | hsk) & 8 || small_cache) {
-        return 4;
-    } else {
-        return 8;
+    void print() const {
+        std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size <<
+                     " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split <<
+                     " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups <<
+                     " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl;
     }
-}
+};
 
-// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
-// 128 threads split into four subgroups, each subgroup does 1/4
-// of the Bc dimension.
-static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
-static constexpr uint32_t scalar_flash_attention_Bc = 64;
-static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
 
-static uint32_t get_fa_num_small_rows(FaCodePath path) {
-    if (path == FA_COOPMAT2) {
-        return flash_attention_num_small_rows;
+static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(kv_type);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_SCALAR;
+
+    if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+        // Disable subgroup use due to performance issues when enforcing subgroup sizes
+        result.subgroup_size = 32;
+        result.disable_subgroups = true;
+    } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
+        result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size;
     } else {
-        return scalar_flash_attention_num_small_rows;
+        result.subgroup_size = device->subgroup_size;
     }
-}
 
-static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
-    GGML_UNUSED(clamp);
+    // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers
+    uint32_t row_split_max_hsk = 64;
+    if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) {
+        row_split_max_hsk = n_rows <= 8 ? 64 : 128;
+    }
+    result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4;
 
-    if (path == FA_SCALAR) {
-        if (small_rows) {
-            return {scalar_flash_attention_num_small_rows, 64};
+    if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) {
+        result.workgroup_size = result.subgroup_size * 2;
+    } else {
+        result.workgroup_size = result.subgroup_size * 4;
+    }
+
+    const uint32_t D = hsk | hsv;
+
+    const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL;
+
+    if (n_rows == 1) {
+        result.block_rows = 1;
+        result.block_cols = 64;
+    } else {
+        // row_split 1 means higher register use per row, so block size has to be adjusted
+        if (result.row_split == 1) {
+            result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8);
         } else {
-            if ((hsv | hsk) & 8) {
-                // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
-                // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
-                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
-            } else {
-                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
-            }
+            result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16);
         }
+
+        result.block_cols = (D & 8) ? 64 : 32;
+    }
+
+    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit
+
+    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
+
+    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
+
+    if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
+        result.block_rows /= 2;
+    }
+
+    // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled
+    // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy.
+    // This targets an occupancy of 4 subgroups per SIMD.
+    if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) {
+        if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) {
+            // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size
+            // Values are guessed, tested on RDNA2
+            result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4;
+        } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) {
+            // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD.
+            // Here, low-batch FA with a large head size is affected.
+            // The n_rows < 4 case differs because the workgroup size switches from 128 to 256 there.
+            result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4;
+        }
+    }
+
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(n_rows);
+    GGML_UNUSED(n_kv);
+    GGML_UNUSED(kv_type);
+    GGML_UNUSED(f32acc);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_COOPMAT1;
+
+    const uint32_t D = hsk | hsv;
+
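+    // coopmat1 works on fixed 16x16 cooperative-matrix tiles; the column block is sized for one tile per subgroup.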
+    const uint32_t coopmat_block_rows = 16;
+    const uint32_t coopmat_block_cols = 16;
+
+    const uint32_t num_subgroups = 4;
+
+    result.block_rows = coopmat_block_rows;
+    result.block_cols = coopmat_block_cols * num_subgroups;
+    result.row_split = num_subgroups;
+    result.subgroup_size = device->subgroup_size;
+    result.workgroup_size = num_subgroups * result.subgroup_size;
+
+    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit
+    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
+
+    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
+
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(n_kv);
+    GGML_UNUSED(f32acc);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_COOPMAT2;
+
+    const uint32_t D = hsk | hsv;
+
+    const bool small_rows = n_rows < 32;
+
+    if (small_rows) {
+        result.block_rows = 32;
+        result.block_cols = 32;
+    } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
+        result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
+        result.block_cols = 32;
+    } else {
+        result.block_rows = 64;
+        result.block_cols = 64;
+    }
+
+    result.subgroup_size = device->subgroup_size;
+    result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128;
+
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
+                      device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
+
+    if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
+        // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
+        path = FA_SCALAR;
     }
 
     if (path == FA_COOPMAT1) {
-        if (small_rows) {
-            return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
-        } else {
-            return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
+        bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
+                        (!f32acc && device->coopmat_support_16x16x16_f16acc);
+        const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+        bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
+
+        if (!shape_ok || !shmem_ok) {
+            path = FA_SCALAR;
         }
     }
 
-    // small rows, large cols
-    if (small_rows) {
-        return {get_fa_num_small_rows(FA_COOPMAT2), 32};
+    // scalar is faster than coopmat when N==1
+    if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
+        path = FA_SCALAR;
     }
 
-    // small cols to reduce register count
-    if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
-        if (hsk >= 512 || hsv >= 512) {
-            return {32, 32};
-        } else {
-            return {64, 32};
-        }
+    switch (path) {
+    case FA_SCALAR:
+        return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    case FA_COOPMAT1:
+        return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    case FA_COOPMAT2:
+        return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    default:
+        throw std::runtime_error("unsupported FaCodePath");
     }
-    return {64, 64};
 }
 
-static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
-    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
+static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
+                                                  bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
+    const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
+                                 (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
+
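+    // Pack the boolean shader options into a single bitfield that is passed as a spec constant.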
+    uint32_t flags = (use_mask_opt      ? 1 : 0) |
+                     (use_mask          ? 2 : 0) |
+                     (use_logit_softcap ? 4 : 0) |
+                     (old_amd_windows   ? 8 : 0);
+
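+    // A subgroup size of 0 means the pipeline is created without a required subgroup size (see fa_ds in CREATE_FA).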
+    const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
+
+    return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
+}
+
+static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) {
+    return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
+            state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
 }
 
 static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -3142,60 +3391,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                        align, disable_robustness, require_full_subgroups, required_subgroup_size);
     };
 
-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array {
-        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
-    };
-
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector {
-        // For large number of rows, 128 invocations seems to work best.
-        // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
-        // can't use 256 for D==80.
-        // For scalar, use 128 (arbitrary)
-        // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
-        const uint32_t D = (hsk|hsv);
-        uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
-                            ? scalar_flash_attention_workgroup_size
-                            : ((small_rows && (D % 32) == 0) ? 256 : 128);
-        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
-
-        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
-        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
-        const uint32_t D_lsb = D ^ (D & (D-1));
-        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
-
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
-    };
-
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
         for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
-            uint32_t HSK = fa.first.HSK; \
-            uint32_t HSV = fa.first.HSV; \
-            bool small_rows = fa.first.small_rows; \
-            bool small_cache = fa.first.small_cache; \
             FaCodePath path = fa.first.path; \
+            uint32_t Br = fa.first.Br; \
+            uint32_t Bc = fa.first.Bc; \
             bool aligned = fa.first.aligned; \
             bool f32acc = fa.first.f32acc; \
+            uint32_t fa_sgs = fa.first.subgroup_size; \
+            bool fa_ds = fa.first.subgroup_size == 0; \
             if (path == FAPATH) { \
                 if (aligned) { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } \
                 } else { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1,                                        true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1,                                        true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } \
                 } \
             } \
         }
 
-    CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+    if (device->fp16) {
+        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+    } else {
+        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
+    }
 #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     if (device->coopmat1_fa_support) {
         CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
@@ -3713,10 +3945,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
         && !device->coopmat_bf16_support
 #endif
         ) {
+        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
+
         // use scalar tile sizes
         l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
         m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
-        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
+        s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 };
 
         l_wg_denoms = {128, 128, 1 };
         m_wg_denoms = { 64,  64, 1 };
@@ -3980,7 +4214,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4],   "get_rows_mxfp4_f32",   get_rows_mxfp4_f32_len,   get_rows_mxfp4_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+
+    for (auto &it : device->pipeline_fa_mask_opt) {
+        auto BrBc = it.first;
+        ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
+    }
 
     if (device->subgroup_clustered && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
@@ -4012,7 +4251,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -4113,7 +4352,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -4158,6 +4398,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);  \
     ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+    CREATE_UNARY(elu)
     CREATE_UNARY(gelu)
     CREATE_UNARY(gelu_erf)
     CREATE_UNARY(gelu_quick)
@@ -4176,6 +4417,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(ceil)
     CREATE_UNARY(floor)
     CREATE_UNARY(trunc)
+    CREATE_UNARY(sgn)
 #undef CREATE_UNARY
 
 #define CREATE_UNARY_RTE(name)  \
@@ -4340,6 +4582,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    {
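+        // One gated_delta_net pipeline per head size (32/64/128), each in a base and a _kda variant.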
+        const uint32_t gdn_sizes[] = {32, 64, 128};
+        const char * gdn_names[][2] = {
+            {"gated_delta_net_f32_d32",     "gated_delta_net_f32_d32_kda"},
+            {"gated_delta_net_f32_d64",     "gated_delta_net_f32_d64_kda"},
+            {"gated_delta_net_f32_d128",    "gated_delta_net_f32_d128_kda"},
+        };
+        for (uint32_t si = 0; si < 3; si++) {
+            for (uint32_t kda = 0; kda < 2; kda++) {
+                ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
+                    gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
+                    "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
+                    {1, 1, 1}, {gdn_sizes[si], kda}, 1);
+            }
+        }
+    }
+
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@@ -4348,7 +4607,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
     }
 
-    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
@@ -4460,6 +4719,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 }
 
 static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
+static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
 
 static vk_device ggml_vk_get_device(size_t idx) {
     VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -4676,6 +4936,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->shader_core_count = sm_props.shaderSMCount;
         } else if (amd_shader_core_properties2) {
             device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
+        } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+            device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
         } else {
             device->shader_core_count = 0;
         }
@@ -4719,8 +4981,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
         std::vector queue_family_props = device->physical_device.getQueueFamilyProperties();
 
         // Try to find a non-graphics compute queue and transfer-focused queues
-        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
-        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
+        // On AMD, the graphics queue seems to be faster, so don't avoid it
+        const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
+        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1);
+        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1);
 
         const float priorities[] = { 1.0f, 1.0f };
         device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
@@ -4895,11 +5159,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
 #if defined(VK_KHR_cooperative_matrix)
         device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
-
-        // coopmat1 fa shader currently assumes 32 invocations per subgroup
-        device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
-                                      device->subgroup_size_control && device->subgroup_min_size <= 32 &&
-                                      device->subgroup_max_size >= 32;
+        device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support;
 #endif
 
         if (coopmat2_support) {
@@ -5186,10 +5446,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         if (!device->single_queue) {
             const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
             ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
+
+            device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
         } else {
             // TODO: Use pointer or reference to avoid copy
             device->transfer_queue.copyFrom(device->compute_queue);
             device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
+
+            device->async_use_transfer_queue = false;
         }
 
         device->buffer_type = {
@@ -5467,6 +5731,10 @@ static void ggml_vk_instance_init() {
     vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
     vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
     vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
+    const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS");
+    if (GGML_VK_PIPELINE_STATS != nullptr) {
+        vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS;
+    }
     const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
 
     if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
@@ -5513,22 +5781,30 @@ static void ggml_vk_instance_init() {
 
             if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {
                 // Check if there are two physical devices corresponding to the same GPU
+                // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
+                // see https://github.com/ggml-org/llama.cpp/pull/7582 for the original deduplication.
+                // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,
+                // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip deduplication when both the
+                // old and the new device use the MoltenVK driver.
                 auto old_device = std::find_if(
                     vk_instance.device_indices.begin(),
                     vk_instance.device_indices.end(),
-                    [&devices, &new_id](const size_t k){
+                    [&devices, &new_id, &new_driver](const size_t k){
                         vk::PhysicalDeviceProperties2 old_props;
+                        vk::PhysicalDeviceDriverProperties old_driver;
                         vk::PhysicalDeviceIDProperties old_id;
-                        old_props.pNext = &old_id;
+                        old_props.pNext = &old_driver;
+                        old_driver.pNext = &old_id;
                         devices[k].getProperties2(&old_props);
 
-                        bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
-                        equals = equals || (
+                        bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
+                        same_uuid = same_uuid || (
                             old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
                             std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
                         );
+                        bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);
 
-                        return equals;
+                        return same_uuid && !both_molten_vk;
                     }
                 );
                 if (old_device == vk_instance.device_indices.end()) {
@@ -5565,6 +5841,10 @@ static void ggml_vk_instance_init() {
                             driver_priorities[vk::DriverId::eMesaNvk] = 2;
 #endif
                             break;
+                        case VK_VENDOR_ID_QUALCOMM:
+                            driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
+                            driver_priorities[vk::DriverId::eMesaTurnip] = 2;
+                            break;
                     }
                     driver_priorities[vk::DriverId::eMesaDozen] = 100;
 
@@ -5647,7 +5927,15 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
     ctx->almost_ready_fence = ctx->device->device.createFence({});
 
     ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
-    ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
+    if (ctx->device->async_use_transfer_queue) {
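+        // Create a timeline semaphore used to order transfer-queue submissions before dependent compute work.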
+        vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
+        vk::SemaphoreCreateInfo ci{};
+        ci.setPNext(&tci);
+        ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci);
+        ctx->transfer_semaphore.value = 0;
+
+        ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
+    }
 
     if (vk_perf_logger_enabled) {
         ctx->perf_logger = std::unique_ptr(new vk_perf_logger());
@@ -6100,13 +6388,24 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
     return vk_subbuffer{buffer, offset, size};
 }
 
+// Get a command buffer from the pool, creating a new one if no reusable buffer is available.
+static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
+    for (auto& cmd_buffer : pool.cmd_buffers) {
+        if (!cmd_buffer.in_use) {
+            cmd_buffer.in_use = true;
+            return &cmd_buffer;
+        }
+    }
+    return ggml_vk_create_cmd_buffer(device, pool);
+}
+
 static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
     vk_submission s;
-    s.buffer = ggml_vk_create_cmd_buffer(device, p);
+    s.buffer = ggml_vk_get_or_create_cmd_buffer(device, p);
     if (one_time) {
-        s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
+        s.buffer->buf.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
     } else {
-        s.buffer.begin({ vk::CommandBufferUsageFlags{} });
+        s.buffer->buf.begin({ vk::CommandBufferUsageFlags{} });
     }
 
     return s;
@@ -6159,18 +6458,18 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
     ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
 
-    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
-    subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
-    subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
+    subctx->s->buffer->buf.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
+    subctx->s->buffer->buf.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
+    subctx->s->buffer->buf.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
                                 pipeline->layout,
                                 0,
                                 { descriptor_set },
                                 {});
-    subctx->s->buffer.dispatch(wg0, wg1, wg2);
+    subctx->s->buffer->buf.dispatch(wg0, wg1, wg2);
 }
 
 static void ggml_vk_end_submission(vk_submission& s, std::vector wait_semaphores, std::vector signal_semaphores) {
-    s.buffer.end();
+    s.buffer->buf.end();
 
     s.wait_semaphores = std::move(wait_semaphores);
     s.signal_semaphores = std::move(signal_semaphores);
@@ -6182,7 +6481,7 @@ static void ggml_vk_ctx_end(vk_context& ctx) {
         return;
     }
 
-    ctx->s->buffer.end();
+    ctx->s->buffer->buf.end();
     ctx->s = nullptr;
 }
 
@@ -6196,6 +6495,47 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
     subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
 }
 
+static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
+    if (!ctx->compute_ctx.expired()) {
+        return ctx->compute_ctx.lock();
+    }
+
+    vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
+
+    ctx->compute_ctx = result;
+    ggml_vk_ctx_begin(ctx->device, result);
+
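+    // If transfer-queue work has been signaled since the last wait, make this compute context wait on the transfer timeline semaphore.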
+    if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
+        result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
+        ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
+    }
+
+    return result;
+}
+
+// Submit any pending transfer queue work and signal the transfer semaphore.
+// The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore.
+// Returns true if work was submitted.
+static bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) {
+    if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) {
+        return false;
+    }
+
+    vk_context cpy_ctx = ctx->transfer_ctx.lock();
+    ggml_vk_ctx_end(cpy_ctx);
+
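+    // Flush any deferred host->staging memcpys so the data is in place before the copy commands execute on the GPU.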
+    for (auto& cpy : cpy_ctx->in_memcpys) {
+        memcpy(cpy.dst, cpy.src, cpy.n);
+    }
+
+    ctx->transfer_semaphore.value++;
+    cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore);
+
+    ggml_vk_submit(cpy_ctx, {});
+    ctx->transfer_ctx.reset();
+    return true;
+}
+
 static size_t ggml_vk_align_size(size_t width, size_t align) {
     VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
     return CEIL_DIV(width, align) * align;
@@ -6295,7 +6635,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
         }
 
         ggml_vk_sync_buffers(ctx, subctx);
-        subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
+        subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
         return;
     }
 
@@ -6310,7 +6650,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
     VkBufferCopy buf_copy{ 0, offset, copy_size };
 
     ggml_vk_sync_buffers(ctx, subctx);
-    vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
+    vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
 
     for (uint64_t i3 = 0; i3 < ne3; i3++) {
         for (uint64_t i2 = 0; i2 < ne2; i2++) {
@@ -6359,7 +6699,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
         }
 
         ggml_vk_sync_buffers(nullptr, subctx);
-        subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
+        subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
         return true;
     }
     VK_LOG_DEBUG("STAGING");
@@ -6381,7 +6721,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
         copy_size};
 
     ggml_vk_sync_buffers(nullptr, subctx);
-    vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
+    vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
 
     if (width == spitch) {
         deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
@@ -6467,7 +6807,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
     if (buf != nullptr) {
         // Memory is pinned, use as staging buffer
         ggml_vk_sync_buffers(nullptr, subctx);
-        subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
+        subctx->s->buffer->buf.copyBuffer(src->buffer, buf->buffer, slices);
 
         return true;
     }
@@ -6485,7 +6825,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
     vk_buffer& staging_buffer = src->device->sync_staging;
 
     ggml_vk_sync_buffers(nullptr, subctx);
-    subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
+    subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices);
 
     deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
     return true;
@@ -6532,7 +6872,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
 
     VkBufferCopy bc{ src_offset, dst_offset, size };
 
-    vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
+    vkCmdCopyBuffer(ctx->s->buffer->buf, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
 }
 
 static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
@@ -6570,7 +6910,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
     }
 
     // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
-    ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
+    ctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
 }
 
 static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
@@ -6585,7 +6925,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     std::lock_guard guard(dst->device->mutex);
     vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
     ggml_vk_ctx_begin(dst->device, subctx);
-    subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
+    subctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
     ggml_vk_ctx_end(subctx);
 
     ggml_vk_submit(subctx, dst->device->fence);
@@ -6691,8 +7031,16 @@ static void ggml_vk_matmul(
         uint32_t padded_n) {
         VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
+
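+        // The batch may exceed maxComputeWorkGroupCount[2], so split the dispatch along z.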
+        uint32_t base_work_group_z = 0;
+        while (base_work_group_z < batch) {
+            uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
+
+            const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
+            base_work_group_z += groups_z;
+        }
         return;
     }
 
@@ -6706,9 +7054,17 @@ static void ggml_vk_matmul(
     uint32_t k_split = CEIL_DIV(k, split_k);
     k_split = ROUNDUP_POW2(k_split, 256);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
-    // Make sure enough workgroups get assigned for split k to work
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
+
+    uint32_t base_work_group_z = 0;
+    while (base_work_group_z < batch) {
+        uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
+
+        const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
+        // Make sure enough workgroups get assigned for split k to work
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
+        base_work_group_z += groups_z;
+    }
     ggml_vk_sync_buffers(ctx, subctx);
     const std::array pc2 = { (uint32_t)(m * n * batch), split_k };
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
@@ -7104,7 +7460,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         }
 
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
         if (qx_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
         }
@@ -7274,6 +7629,18 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
             return false;
         }
 
+        if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
+            // Intel Windows proprietary driver tuning
+            switch (src0_type) {
+            case GGML_TYPE_MXFP4:
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+                return false;
+            default:
+                return true;
+            }
+        }
+
         switch (src0_type) {
         // From tests on A770 Linux, may need more tuning
         case GGML_TYPE_Q4_0:
@@ -7402,7 +7769,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         if (quantize_y) {
             ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7497,22 +7863,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
     }
 
-    // compute
-    const vk_mat_vec_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, stride_batch_d,
-        fusion_flags,
-        (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
-    };
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-                              {
-                                d_X,
-                                d_Y,
-                                d_D,
-                                d_F0,
-                                d_F1,
-                              },
-                              pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
+    ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
+
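+    // ne12*ne13 may exceed maxComputeWorkGroupCount[1], so split the dispatch along y.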
+    uint32_t base_work_group_y = 0;
+    while (base_work_group_y < ne12 * ne13) {
+
+        uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+        const vk_mat_vec_push_constants pc = {
+            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+            stride_batch_x, stride_batch_y, stride_batch_d,
+            fusion_flags, base_work_group_y,
+            (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
+        };
+        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+                                  {
+                                    d_X,
+                                    d_Y,
+                                    d_D,
+                                    d_F0,
+                                    d_F1,
+                                  },
+                                  pc, { groups_x, groups_y, groups_z });
+        base_work_group_y += groups_y;
+    }
 
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;
@@ -7750,10 +8123,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
         src1->nb[2] <= src1->nb[1] &&
         src1->nb[1] <= src1->nb[3] &&
         src0->ne[3] == 1 &&
-        src1->ne[3] == 1) {
+        src1->ne[3] == 1 &&
+        src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+        src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
         ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
     } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
-               !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
+               !ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
+               src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
+               src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
+               src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
         ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
     // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
     // when ne12 and ne13 are one.
@@ -8083,8 +8461,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
 
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
-
-    GGML_ASSERT(nei1 == 1);
+    const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
 
     const uint64_t ne20 = dst->ne[0];
     const uint64_t ne21 = dst->ne[1];
@@ -8168,7 +8545,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         if (quantize_y) {
             ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -8226,7 +8603,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     uint32_t stride_batch_y = ne10*ne11;
 
     if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
-        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+        stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
     }
 
     const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
@@ -8262,23 +8639,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
     }
 
-    // compute
-    const vk_mat_vec_id_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
-        fusion_flags,
-        (uint32_t)nei0, (uint32_t)ne11,
-    };
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-        {
-            d_X,
-            d_Y,
-            d_D,
-            d_F0,
-            d_F1,
-            d_ids,
-        },
-        pc, { groups_x, (uint32_t)nei0, groups_z });
+    // Loop over the ids batch dimension (nei1), dispatching once per row of ids.
+    for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
+        const vk_mat_vec_id_push_constants pc = {
+            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+            (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
+            fusion_flags,
+            (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
+        };
+        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+            {
+                d_X,
+                d_Y,
+                d_D,
+                d_F0,
+                d_F1,
+                d_ids,
+            },
+            pc, { groups_x, (uint32_t)nei0, groups_z });
+    }
 
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;
@@ -8292,7 +8671,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
     ggml_tensor * dst = cgraph->nodes[node_idx];
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src2 = dst->src[2];
-    return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
+    return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
 }
 
 static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -8308,55 +8687,70 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
+    GGML_UNUSED(f32acc);
     // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const uint32_t wg_size = params.workgroup_size;
+    const uint32_t Br = params.block_rows;
+    const uint32_t Bc = params.block_cols;
 
+    const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+
+    // tmpsh is overestimated slightly
     const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
 
-    const uint32_t masksh = Bc * Br * sizeof(float);
+    const uint32_t masksh = Bc * (Br + 1) * float_type_size;
 
-    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+    const uint32_t D = std::max(hsk, hsv);
+    const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
+
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
-    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
+    VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
 
     return supported;
 }
 
-static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
     // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = coopmat1_flash_attention_num_large_rows;
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const uint32_t Br = params.block_rows;
+    const uint32_t Bc = params.block_cols;
+
+    const uint32_t MatBr = 16, MatBc = 16;
+
+    const uint32_t row_split = Bc / MatBc;
 
     const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
+    const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16);
 
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
 
-    const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * acctype;
+    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
 
     const uint32_t qstride = hsk_pad / 4 + 2;
     const uint32_t Qf = Br * qstride * f16vec4;
 
+    const uint32_t psh_stride = Br / 4 + 2;
+    const uint32_t Psh = Bc * psh_stride * f16vec4;
+
     const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
     const uint32_t sfsh = Bc * sfshstride * acctype;
 
-    const uint32_t kshstride = hsk_pad / 4 + 2;
-    const uint32_t ksh = Bc * kshstride * f16vec4;
+    const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2;
+    const uint32_t vsh_stride = MatBc / 4 * row_split;
+    const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
 
-    const uint32_t slope = Br * sizeof(float);
+    const uint32_t osh_stride = params.row_split * MatBr / 4;
+    const uint32_t pvsh = MatBc * osh_stride * f16vec4;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
+    const uint32_t slope = Br * acctype;
+
+    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
@@ -8383,6 +8777,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
     GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
+    const uint32_t nem0 = mask ? mask->ne[0] : 0;
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
     const uint32_t nem3 = mask ? mask->ne[3] : 0;
@@ -8416,72 +8811,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     assert(q->type == GGML_TYPE_F32);
     assert(k->type == v->type);
 
-    FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
-                      ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
-
-    if (path == FA_COOPMAT1) {
-        const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
-                                             (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
-
-        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
-
-        if (!coopmat_shape_supported || !coopmat_shmem_supported) {
-            path = FA_SCALAR;
-        }
-    }
-
     uint32_t gqa_ratio = 1;
     uint32_t qk_ratio = neq2 / nek2;
     uint32_t workgroups_x = (uint32_t)neq1;
     uint32_t workgroups_y = (uint32_t)neq2;
     uint32_t workgroups_z = (uint32_t)neq3;
 
-    const bool small_cache = nek1 < 1024;
+    const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
 
     // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
     // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
-    uint32_t max_gqa;
-    switch (path) {
-    case FA_SCALAR:
-    case FA_COOPMAT1:
-        // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
-        break;
-    case FA_COOPMAT2:
-        max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
-        break;
-    default:
-        GGML_ASSERT(0);
-    }
+    vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
+    const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
 
-    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
+    if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
         qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
         // and change addressing calculations to index Q's dimension 2.
         gqa_ratio = qk_ratio;
         N = gqa_ratio;
-        workgroups_y /= N;
+        workgroups_y /= gqa_ratio;
     }
 
-    bool small_rows = N <= get_fa_num_small_rows(path);
-
-    // coopmat1 does not actually support "small rows" (it needs 16 rows).
-    // So use scalar instead.
-    if (small_rows && path == FA_COOPMAT1) {
-        path = FA_SCALAR;
-    }
-
-    // scalar is faster than coopmat2 when N==1
-    if (N == 1 && path == FA_COOPMAT2) {
-        path = FA_SCALAR;
-    }
-
-    // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
-    if (path == FA_SCALAR &&
-        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
-        small_rows = true;
-    }
+    tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
 
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
     uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
@@ -8495,19 +8848,32 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         v_stride /= 4;
     }
 
-    uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
+    const uint32_t alignment = tuning_params.block_cols;
     bool aligned = (KV % alignment) == 0 &&
                    // the "aligned" shader variant will forcibly align strides, for performance
                    (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
 
     // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
-    if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
+    if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {
         aligned = false;
     }
 
-    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
 
-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
+    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
+    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
+    vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
+                                                                   mask != nullptr, use_mask_opt, logit_softcap != 0);
 
     vk_pipeline pipeline = nullptr;
 
@@ -8523,29 +8889,46 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     assert(pipeline);
+    // Compile early to initialize wg_denoms.
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
     uint32_t split_kv = KV;
     uint32_t split_k = 1;
 
-    // Use a placeholder core count if one isn't available. split_k is a big help for perf.
-    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
+    // Intel Alchemist prefers more workgroups
+    const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
 
-    // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && shader_core_count > 0) {
-        // Try to run two workgroups per SM.
-        split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
-        if (split_k > 1) {
-            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
-            // of "align", so recompute split_k based on that.
-            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
-            split_k = CEIL_DIV(KV, split_kv);
-            workgroups_x = split_k;
+    // Use a placeholder core count if one isn't available. split_k is a big help for perf.
+    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
+
+    const uint32_t Br = fa_pipeline_state.Br;
+    const uint32_t Bc = fa_pipeline_state.Bc;
+
+    GGML_ASSERT(Br == pipeline->wg_denoms[0]);
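+    // Tr: number of Br-row tiles needed to cover the N query rows.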
+    const uint32_t Tr = CEIL_DIV(N, Br);
+
+    // Try to use split_k when KV is large enough to be worth the overhead.
+    if (gqa_ratio > 1 && workgroups_x <= Br) {
+        split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
+    } else if (gqa_ratio <= 1) {
+        uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
+        if (total_wgs_no_split < shader_core_count * 2) {
+            split_k = shader_core_count * 2 / total_wgs_no_split;
         }
     }
 
+    if (split_k > 1) {
+        // Try to evenly split KV into split_k chunks, but it needs to be a multiple
+        // of "align", so recompute split_k based on that.
+        split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
+        split_k = CEIL_DIV(KV, split_kv);
+    }
+
     // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
     // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
-    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
+    // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
+    // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
+    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
     if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
         GGML_ABORT("Requested preallocation size is too large");
     }
@@ -8554,24 +8937,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         ggml_vk_preallocate_buffers(ctx, subctx);
     }
 
-    {
-        // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
-        if (split_k > 1) {
-            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+    const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
+    const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
+
+    vk_pipeline pipeline_fa_mask_opt = nullptr;
+    if (use_mask_opt) {
+        std::lock_guard guard(ctx->device->mutex);
+        auto &pipelines = ctx->device->pipeline_fa_mask_opt;
+        auto it = pipelines.find({Br, Bc});
+        if (it != pipelines.end()) {
+            pipeline_fa_mask_opt = it->second;
+        } else {
+            pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
         }
-    }
+        assert(pipeline_fa_mask_opt);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
 
-    float scale         = 1.0f;
-    float max_bias      = 0.0f;
-    float logit_softcap = 0.0f;
-
-    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
-
-    if (logit_softcap != 0) {
-        scale /= logit_softcap;
+        if (ctx->prealloc_size_y < mask_opt_size) {
+            ctx->prealloc_size_y = mask_opt_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
     }
 
     const uint32_t n_head_kv   = neq2;
@@ -8585,8 +8973,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
     vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
+    vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
 
-    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
+    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
+
+    if (use_mask_opt)
+    {
+        const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
+            nem0,
+            nem1,
+            nem2,
+            (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
+            (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
+            (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
+            mask_opt_num_dwords,
+            mask_opt_num_dwords * CEIL_DIV(nem1, Br),
+            mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
+        };
+
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
+                                  { mask_buf, mask_opt_buf }, opt_pc,
+                                  { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
 
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
@@ -8602,28 +9011,40 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
                                               gqa_ratio, split_kv, split_k };
 
     if (split_k > 1) {
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+
         if (ctx->prealloc_split_k_need_sync) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
 
+        // We reuse workgroups_x to mean the number of splits, so we need to
+        // cancel out the divide by wg_denoms[0].
+        uint32_t dispatch_x;
+        if (gqa_ratio > 1) {
+            workgroups_x *= pipeline->wg_denoms[0];
+            dispatch_x = split_k * workgroups_x;
+        } else {
+            dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
+        }
+
         vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
-                                    // We only use split_k when group query attention is enabled, which means
-                                    // there's no more than one tile of rows (i.e. workgroups_x would have been
-                                    // one). We reuse workgroups_x to mean the number of splits, so we need to
-                                    // cancel out the divide by wg_denoms[0].
-                                    pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
+                                    pc, { dispatch_x, workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(ctx, subctx);
-        const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
+        const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
         ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
                                     {split_k_buf, sinks_buf, dst_buf},
-                                    pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
+                                    pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
         ctx->prealloc_split_k_need_sync = true;
     } else {
+        if (gqa_ratio > 1) {
+            // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
+            workgroups_x *= pipeline->wg_denoms[0];
+        }
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
+                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
                                     pc, { workgroups_x, workgroups_y, workgroups_z });
     }
 }
@@ -8668,6 +9089,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_acc_f32;
         }
         return nullptr;
+    case GGML_OP_SET:
+        if (src0->type == src1->type && src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
+            return ctx->device->pipeline_set_f32;
+        }
+        return nullptr;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ -8869,6 +9296,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_EXP:
                 return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_ELU:
+                return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_SILU:
                 return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_GELU:
@@ -8905,6 +9334,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_TRUNC:
                 return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_SGN:
+                return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];
             default:
                 break;
         }
@@ -9098,6 +9529,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv7_f32;
         }
         return nullptr;
+    case GGML_OP_GATED_DELTA_NET:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            const uint32_t S_v = dst->src[2]->ne[0];
+            const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
+            uint32_t si;
+            switch (S_v) {
+                case 32:  si = 0; break;
+                case 64:  si = 1; break;
+                case 128: si = 2; break;
+                default: return nullptr;
+            }
+            return ctx->device->pipeline_gated_delta_net[si][kda];
+        }
+        return nullptr;
     case GGML_OP_SSM_SCAN:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             const uint32_t d_state = src0->ne[0];
@@ -9654,16 +10099,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
+    int nb1 = dst->op_params[0] / src0_type_size; // stride in elements
+    int nb2 = dst->op_params[1] / src0_type_size; // stride in elements
+    int nb3 = dst->op_params[2] / src0_type_size; // stride in elements
+    int offset = dst->op_params[3] / src0_type_size; // offset in elements
 
-    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
         (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] /  dst_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
         0,
         0.0f, 0.0f, offset,
     });
@@ -9928,6 +10373,59 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
     );
 }
 
+static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    const ggml_tensor * src_q     = dst->src[0];
+    const ggml_tensor * src_v     = dst->src[2];
+    const ggml_tensor * src_beta  = dst->src[4];
+
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    const uint32_t S_v      = (uint32_t)src_v->ne[0];
+    const uint32_t H        = (uint32_t)src_v->ne[1];
+    const uint32_t n_tokens = (uint32_t)src_v->ne[2];
+    const uint32_t n_seqs   = (uint32_t)src_v->ne[3];
+
+    const uint32_t s_off = S_v * H * n_tokens * n_seqs;
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
+    GGML_ASSERT(pipeline != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer src_buf[6] = {};
+    for (int i = 0; i < 6; i++) {
+        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
+    }
+
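+    // Tensor strides converted from bytes to float-element units.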
+    const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
+    const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
+    const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
+    const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
+    const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
+    const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
+    const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));
+    const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));
+    const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));
+
+    const uint32_t neq1 = (uint32_t)src_q->ne[1];
+    const uint32_t rq3  = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
+
+    const float scale = 1.0f / sqrtf((float)S_v);
+    const vk_op_gated_delta_net_push_constants pc = {
+        H, n_tokens, n_seqs, s_off,
+        sq1, sq2, sq3,
+        sv1, sv2, sv3,
+        sb1, sb2, sb3,
+        neq1, rq3,
+        scale
+    };
+
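+    // Dispatch across heads (x) and sequences (y).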
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+        pc, { H, n_seqs, 1u });
+}
+
 static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -10335,12 +10833,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
 
     uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
     uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
+    uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
+
+    uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
+    uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
+    uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
 
     vk_op_rope_push_constants rope {
-        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
-        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        has_ff, (uint32_t)src0->ne[2], nb01, nb02,
+        (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
+        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
         { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
+
+        (uint32_t)src0->ne[0],
+        (uint32_t)src0->ne[1],
+        (uint32_t)src0->ne[2],
+        nb01, nb02, nb03,
+        nb11, nb12, nb13,
     };
 
     return rope;
@@ -10467,8 +10975,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
 }
 
 static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
+    const float * op_params = (const float *)dst->op_params;
+    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
+    p.param1 = op_params[0];
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
 }
 
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -11386,7 +11896,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         }
     }
 
-    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
         ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
@@ -11560,7 +12069,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     free(d_chk);
 
     ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
-    ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
 
     ggml_vk_destroy_buffer(d_X);
     ggml_vk_destroy_buffer(d_Y);
@@ -11896,7 +12404,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         // y[i] = i % k;
     }
 
-    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
     if (split_k > 1) {
         ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
 
@@ -11909,7 +12416,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         }
     }
     if (mmq) {
-        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it);
+        vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it);
     }
 
     ggml_pipeline_allocate_descriptor_sets(ctx);
@@ -12145,7 +12653,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
         ggml_vk_submit(subctx, {});
         ctx->submit_pending = true;
         ggml_vk_synchronize(ctx);
+        GGML_ASSERT(ctx->compute_ctx.expired());
         ggml_vk_ctx_begin(ctx->device, subctx);
+        ctx->compute_ctx = subctx;
     }
 
     if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
@@ -12163,6 +12673,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
             ggml_vk_destroy_buffer(ctx->prealloc_y);
         }
         ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
+        ctx->prealloc_y_last_tensor_used = nullptr;
     }
     if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
         VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
@@ -12191,6 +12702,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
         return false;
     }
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return false;
+    }
 
     VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
     ctx->semaphore_idx = 0;
@@ -12215,15 +12729,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         }
     }
 
-    vk_context compute_ctx;
-
-    if (ctx->compute_ctx.expired()) {
-        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->compute_ctx = compute_ctx;
-        ggml_vk_ctx_begin(ctx->device, compute_ctx);
-    } else {
-        compute_ctx = ctx->compute_ctx.lock();
-    }
+    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
 
     {
         // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers
@@ -12294,7 +12800,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
             if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
                 ctx->query_node_idx[ctx->query_idx] = node_idx;
-                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
+                compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
             }
         }
         // Add all fused nodes to the unsynchronized lists.
@@ -12337,6 +12843,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ACC:
+    case GGML_OP_SET:
         ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
 
         break;
@@ -12471,6 +12978,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         }
 
         switch (ggml_get_unary_op(node)) {
+        case GGML_UNARY_OP_ELU:
         case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
@@ -12489,6 +12997,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_CEIL:
         case GGML_UNARY_OP_FLOOR:
         case GGML_UNARY_OP_TRUNC:
+        case GGML_UNARY_OP_SGN:
             ggml_vk_unary(ctx, compute_ctx, src0, node);
             break;
         case GGML_UNARY_OP_XIELU:
@@ -12633,6 +13142,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
 
+    case GGML_OP_GATED_DELTA_NET:
+        ggml_vk_gated_delta_net(ctx, compute_ctx, node);
+
+        break;
+
     case GGML_OP_SSM_SCAN:
         ggml_vk_ssm_scan(ctx, compute_ctx, node);
 
@@ -12740,7 +13254,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
     ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
 
     ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
-    ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
+    if (ctx->device->async_use_transfer_queue) {
+        ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
+    }
 
     for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
         ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
@@ -12769,7 +13285,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
 static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
     // discard any unsubmitted command buffers
-    ctx->transfer_ctx.reset();
+    ctx->compute_ctx.reset();
     // wait for any pending command buffers to finish
     ggml_vk_synchronize(ctx);
 
@@ -12802,7 +13318,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     ctx->descriptor_sets.clear();
 
     ctx->compute_cmd_pool.destroy(ctx->device->device);
-    ctx->transfer_cmd_pool.destroy(ctx->device->device);
+    if (ctx->device->async_use_transfer_queue) {
+        ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s);
+
+        ctx->transfer_cmd_pool.destroy(ctx->device->device);
+    }
     if (vk_perf_logger_enabled) {
         ctx->perf_logger->print_timings(true);
     }
@@ -12861,6 +13381,10 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g
     ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
     vk_buffer buf = buf_ctx->dev_buffer;
 
+    if (size == 0) {
+        return;
+    }
+
     uint32_t val32 = (uint32_t)value * 0x01010101;
     ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
 }
@@ -12870,6 +13394,10 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
     ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
     vk_buffer buf = buf_ctx->dev_buffer;
 
+    if (size == 0) {
+        return;
+    }
+
     ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
 }
 
@@ -12877,12 +13405,20 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
     VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
     ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
 
+    if (size == 0) {
+        return;
+    }
+
     vk_buffer buf = buf_ctx->dev_buffer;
 
     ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
 }
 
 static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+    if (ggml_nbytes(src) == 0) {
+        return true;
+    }
+
     if (ggml_backend_buffer_is_vk(src->buffer)) {
         ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
         ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
@@ -13072,36 +13608,44 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
     GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
 
+    if (size == 0) {
+        return;
+    }
+
     ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
 
-    vk_context transfer_ctx;
+    vk_context cpy_ctx;
 
-    if (ctx->transfer_ctx.expired()) {
-        // Initialize new transfer context
-        transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->transfer_ctx = transfer_ctx;
-        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
+    if (ctx->device->async_use_transfer_queue) {
+        if (ctx->transfer_ctx.expired()) {
+            // Initialize new transfer context
+            cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
+            ctx->transfer_ctx = cpy_ctx;
+            ggml_vk_ctx_begin(ctx->device, cpy_ctx);
+        } else {
+            cpy_ctx = ctx->transfer_ctx.lock();
+        }
     } else {
-        transfer_ctx = ctx->transfer_ctx.lock();
+        cpy_ctx = ggml_vk_get_compute_ctx(ctx);
     }
 
     vk_buffer buf = buf_ctx->dev_buffer;
 
     auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
 
-    bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size);
+    bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size);
 
     if (!ret) {
         ggml_vk_ensure_sync_staging_buffer(ctx, size);
-        ggml_vk_sync_buffers(nullptr, transfer_ctx);
+        ggml_vk_sync_buffers(nullptr, cpy_ctx);
 
         vk::BufferCopy buffer_cpy;
         buffer_cpy.srcOffset = 0;
         buffer_cpy.dstOffset = dst_offset;
         buffer_cpy.size = size;
 
-        transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
-        deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys);
+        cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
+        deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);
         ggml_vk_synchronize(ctx);
     }
 }
@@ -13111,101 +13655,156 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
     GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
 
+    if (size == 0) {
+        return;
+    }
+
     ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
 
-    vk_context transfer_ctx;
-
-    if (ctx->transfer_ctx.expired()) {
-        // Initialize new transfer context
-        transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->transfer_ctx = transfer_ctx;
-        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-    } else {
-        transfer_ctx = ctx->transfer_ctx.lock();
-    }
+    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
 
     vk_buffer buf = buf_ctx->dev_buffer;
 
     auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
-    bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size);
+    bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
 
     // If that failed, copy synchronously through a staging buffer
     if (!ret) {
         ggml_vk_ensure_sync_staging_buffer(ctx, size);
-        ggml_vk_sync_buffers(nullptr, transfer_ctx);
+        ggml_vk_sync_buffers(nullptr, compute_ctx);
 
         vk::BufferCopy buffer_cpy;
         buffer_cpy.srcOffset = src_offset;
         buffer_cpy.dstOffset = 0;
         buffer_cpy.size = size;
 
-        transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
-        deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys);
+        compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
+        deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
         ggml_vk_synchronize(ctx);
     }
 }
 
-static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
-    VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-    if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
-        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
-        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
+    VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
+    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
 
-        vk_context transfer_ctx;
-
-        if (ctx->transfer_ctx.expired()) {
-            // Initialize new transfer context
-            transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-            ctx->transfer_ctx = transfer_ctx;
-            ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-        } else {
-            transfer_ctx = ctx->transfer_ctx.lock();
-        }
-
-        vk_buffer src_buf = src_buf_ctx->dev_buffer;
-        vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
-
-        ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
+    // Skip zero-size tensors
+    if (ggml_nbytes(src) == 0) {
         return true;
     }
 
+    if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {
+        return false;
+    }
+
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
+
+    if (ggml_backend_buffer_is_vk(src->buffer)) {
+        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
+
+        // Async copy only works within the same device
+        if (src_buf_ctx->dev_buffer->device != dst_buf->device) {
+            return false;
+        }
+
+        vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
+
+        ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs,
+                                   src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs,
+                                   ggml_nbytes(src));
+        return true;
+    }
+
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        vk_buffer pinned_buf = nullptr;
+        size_t pinned_offset = 0;
+        ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset);
+        if (pinned_buf == nullptr) {
+            return false;
+        }
+
+        vk_context cpy_ctx;
+        if (ctx->device->async_use_transfer_queue) {
+            if (ctx->transfer_ctx.expired()) {
+                cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
+                ctx->transfer_ctx = cpy_ctx;
+                ggml_vk_ctx_begin(ctx->device, cpy_ctx);
+            } else {
+                cpy_ctx = ctx->transfer_ctx.lock();
+            }
+        } else {
+            cpy_ctx = ggml_vk_get_compute_ctx(ctx);
+        }
+
+        return ggml_vk_buffer_write_async(cpy_ctx, dst_buf,
+                                          vk_tensor_offset(dst) + dst->view_offs,
+                                          src->data, ggml_nbytes(src));
+    }
+
+    GGML_UNUSED(backend_src);
     return false;
 }
 
 static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
     VK_LOG_DEBUG("ggml_vk_synchronize()");
 
-    bool do_transfer = !ctx->transfer_ctx.expired();
+    bool do_transfer = !ctx->compute_ctx.expired();
 
-    vk_context transfer_ctx;
+    if (ggml_vk_submit_transfer_ctx(ctx)) {
+        ctx->submit_pending = true;
+    }
+
+    vk_context compute_ctx;
+    vk_command_buffer* cmd_buf = nullptr;
     if (do_transfer) {
-        transfer_ctx = ctx->transfer_ctx.lock();
+        compute_ctx = ctx->compute_ctx.lock();
+        if (compute_ctx->s) {
+            cmd_buf = compute_ctx->s->buffer;
+        }
 
-        ggml_vk_ctx_end(transfer_ctx);
+        ggml_vk_ctx_end(compute_ctx);
 
-        for (auto& cpy : transfer_ctx->in_memcpys) {
+        for (auto& cpy : compute_ctx->in_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
 
-        ggml_vk_submit(transfer_ctx, {});
+        ggml_vk_submit(compute_ctx, {});
         ctx->submit_pending = true;
     }
 
     if (ctx->submit_pending) {
-        {
+        if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
+            vk::TimelineSemaphoreSubmitInfo tl_info{
+                1, &ctx->transfer_semaphore.value,
+                0, nullptr,
+            };
+            vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags;
+            vk::SubmitInfo si{
+                1, &ctx->transfer_semaphore.s, &stage,
+                0, nullptr,
+                0, nullptr,
+            };
+            si.setPNext(&tl_info);
+            std::lock_guard guard(queue_mutex);
+            ctx->device->compute_queue.queue.submit({ si }, ctx->fence);
+            ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
+        } else {
             std::lock_guard guard(queue_mutex);
             ctx->device->compute_queue.queue.submit({}, ctx->fence);
         }
         ggml_vk_wait_for_fence(ctx);
         ctx->submit_pending = false;
+        if (cmd_buf) {
+            cmd_buf->in_use = false;
+        }
     }
 
     if (do_transfer) {
-        for (auto& cpy : transfer_ctx->out_memcpys) {
+        for (auto& cpy : compute_ctx->out_memcpys) {
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
-        ctx->transfer_ctx.reset();
+        ctx->compute_ctx.reset();
     }
 }
 
@@ -13505,12 +14104,11 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
     return true;
 }
 
-// Check whether the tensors overlap in memory but are not equal.
-// Fusions can potenitally overwrite src tensors in ways that are not prevented
-// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
-// to overlap if they are exactly equal.
-// XXX TODO this check is probably missing from several fusion optimizations.
-static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
+// Check whether the tensors overlap in memory.
+// Fusions can potentially overwrite src tensors in ways that are not prevented
+// by ggml-alloc. If the fusion src is being applied in a way that's elementwise
+// with the destination, then it's OK for them to overlap if they are exactly equal.
+static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) {
     ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
     vk_buffer a_buf = a_buf_ctx->dev_buffer;
     ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
@@ -13521,7 +14119,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g
         auto b_base = vk_tensor_offset(b) + b->view_offs;
         auto b_size = ggml_nbytes(b);
 
-        if (a_base == b_base && a_size == b_size) {
+        if (elementwise && a_base == b_base && a_size == b_size) {
             return false;
         }
 
@@ -13559,13 +14157,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
         return false;
     }
 
-    // must not overwrite srcs in a way that's not elementwise
-    ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
-    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
-        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
-        return false;
-    }
-
     // conditions for pipeline creation
     if (!(ctx->device->float_controls_rte_fp16 &&
         sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
@@ -13627,6 +14218,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
     return num_adds;
 }
 
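+// Returns the index of the lowest set bit of x, or -1 if x is zero.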
+static int32_t find_first_set(uint32_t x) {
+    int32_t ret = 0;
+    if (!x) {
+        return -1;
+    }
+    while (!(x & 1)) {
+        x >>= 1;
+        ret++;
+    }
+    return ret;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -13645,7 +14248,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     int last_node = cgraph->n_nodes - 1;
 
     // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
-    while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
+    while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
         last_node -= 1;
     }
 
@@ -13655,6 +14258,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     bool first_node_in_batch = true; // true if next node will be first node in a batch
     int submit_node_idx = 0; // index to first node in a batch
 
+    ggml_vk_submit_transfer_ctx(ctx);
+
     vk_context compute_ctx;
     if (vk_perf_logger_enabled) {
         // allocate/resize the query pool
@@ -13680,11 +14285,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
 
         GGML_ASSERT(ctx->compute_ctx.expired());
-        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->compute_ctx = compute_ctx;
-        ggml_vk_ctx_begin(ctx->device, compute_ctx);
+        compute_ctx = ggml_vk_get_compute_ctx(ctx);
         ctx->query_idx = 0;
-        compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
+        compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
     }
 
     ctx->prealloc_y_last_pipeline_used = nullptr;
@@ -13692,13 +14295,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     if (ctx->prealloc_size_add_rms_partials) {
         ggml_vk_preallocate_buffers(ctx, nullptr);
-        if (ctx->compute_ctx.expired()) {
-            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-            ctx->compute_ctx = compute_ctx;
-            ggml_vk_ctx_begin(ctx->device, compute_ctx);
-        } else {
-            compute_ctx = ctx->compute_ctx.lock();
-        }
+        compute_ctx = ggml_vk_get_compute_ctx(ctx);
         // initialize partial sums to zero.
         ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
         ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -13725,6 +14322,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             total_mul_mat_bytes += bytes;
         }
 
+        // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
+        // the fused result in an elementwise way. This affects whether the memory for
+        // the src is allowed to overlap the memory for the destination.
+        // The array is sized to handle the largest fusion (asserted later).
+        bool op_srcs_fused_elementwise[12];
+
         ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
         ctx->fused_topk_moe_scale = false;
         const char *fusion_string {};
@@ -13733,39 +14336,68 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             if (num_adds) {
                 ctx->num_additional_fused_ops = num_adds - 1;
                 fusion_string = "MULTI_ADD";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true);
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "MUL_MAT_ADD_ADD";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ADD";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ID_ADD_ID";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "MUL_MAT_ID_MUL";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
                        ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
                        ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
                 ctx->num_additional_fused_ops = 4;
                 fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = false;
+                op_srcs_fused_elementwise[2] = false;
+                op_srcs_fused_elementwise[3] = false;
+                op_srcs_fused_elementwise[4] = false;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
                        ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "RMS_NORM_MUL_ROPE";
+                // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = true;
+                op_srcs_fused_elementwise[2] = true;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
                 fusion_string = "RMS_NORM_MUL";
+                // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before
+                // they are overwritten, and one workgroup per row. So close enough.
+                op_srcs_fused_elementwise[0] = true;
+                op_srcs_fused_elementwise[1] = true;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
                        ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
                 ctx->num_additional_fused_ops = 2;
                 fusion_string = "ROPE_VIEW_SET_ROWS";
+                op_srcs_fused_elementwise[0] = false;
+                op_srcs_fused_elementwise[1] = false;
+                op_srcs_fused_elementwise[2] = false;
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
@@ -13774,6 +14406,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->fused_ops_write_mask |= 1 << 3;
                 ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
@@ -13782,6 +14415,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->fused_ops_write_mask |= 1 << 4;
                 ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
                 fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
@@ -13790,6 +14424,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->fused_ops_write_mask |= 1 << 3;
                 ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
@@ -13798,6 +14433,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->fused_ops_write_mask |= 1 << 1;
                 ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
                 fusion_string = "TOPK_MOE_LATE_SOFTMAX";
+                std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
             }
             if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
                 // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
@@ -13805,11 +14441,73 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                     ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
                     ctx->fused_topk_moe_scale = true;
                     ctx->num_additional_fused_ops++;
+                    op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false;
                 }
             }
         }
+        GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0])));
         ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
 
+        // Check whether fusion would overwrite src operands while they're still in use.
+        // If so, disable fusion.
+        if (ctx->num_additional_fused_ops) {
+            // There are up to two output nodes - topk_moe has two.
+            uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops);
+            ggml_tensor *output_nodes[2] {};
+            output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops];
+            if (bits) {
+                int output_idx = find_first_set(bits);
+                GGML_ASSERT(bits == (1u << output_idx));
+                output_nodes[1] = cgraph->nodes[i + output_idx];
+            }
+
+            bool need_disable = false;
+
+            // topk_moe often overwrites the source, but for a given row all the src values are
+            // loaded before anything is stored. If there's only one row, this is safe, so treat
+            // this as a special case.
+            bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT &&
+                                          ggml_nrows(cgraph->nodes[i]->src[0]) == 1;
+
+            if (!is_topk_moe_single_row) {
+                for (int j = 0; j < 2; ++j) {
+                    ggml_tensor *dst = output_nodes[j];
+                    if (!dst) {
+                        continue;
+                    }
+                    // Loop over all srcs of all nodes in the fusion. If the src overlaps
+                    // the destination and the src is not an intermediate node that's being
+                    // elided, then disable fusion.
+                    for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) {
+                        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
+                            ggml_tensor *src = cgraph->nodes[i + k]->src[s];
+                            if (!src || src->op == GGML_OP_NONE) {
+                                continue;
+                            }
+                            if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) {
+                                bool found = false;
+                                for (int n = 0; n < k; ++n) {
+                                    if (cgraph->nodes[i + n] == src) {
+                                        found = true;
+                                        break;
+                                    }
+                                }
+                                if (!found) {
+                                    need_disable = true;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            if (need_disable) {
+                ctx->num_additional_fused_ops = 0;
+                ctx->fused_ops_write_mask = 1;
+                ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
+                ctx->fused_topk_moe_scale = false;
+            }
+        }
+
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
@@ -13820,18 +14518,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
 
         if (vk_perf_logger_enabled && enqueued) {
-            if (ctx->compute_ctx.expired()) {
-                compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-                ctx->compute_ctx = compute_ctx;
-                ggml_vk_ctx_begin(ctx->device, compute_ctx);
-            } else {
-                compute_ctx = ctx->compute_ctx.lock();
-            }
+            compute_ctx = ggml_vk_get_compute_ctx(ctx);
             if (!vk_perf_logger_concurrent) {
                 // track a single node/fusion for the current query
                 ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
                 ctx->query_fusion_names[ctx->query_idx] = fusion_string;
-                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
+                compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
             } else {
                 // track a fusion string and number of fused ops for the current node_idx
                 ctx->query_fusion_names[i] = fusion_string;
@@ -13874,6 +14566,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         ggml_vk_submit(compute_ctx, ctx->device->fence);
         VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
         ctx->device->device.resetFences({ ctx->device->fence });
+        ctx->compute_ctx.reset();
 
         // Get the results and pass them to the logger
         std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
@@ -14160,29 +14853,24 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
     vk_event *vkev = (vk_event *)event->context;
 
-    vk_context transfer_ctx;
+    ggml_vk_submit_transfer_ctx(ctx);
 
-    if (ctx->transfer_ctx.expired()) {
-        // Initialize new transfer context
-        transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->transfer_ctx = transfer_ctx;
-        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-    } else {
-        transfer_ctx = ctx->transfer_ctx.lock();
-    }
+    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
+    auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset
 
     // the backend interface doesn't have an explicit reset, so reset it here
     // before we record the command to set it
     ctx->device->device.resetEvent(vkev->event);
     ctx->device->device.resetFences({ vkev->fence });
 
-    ggml_vk_set_event(transfer_ctx, vkev->event);
+    ggml_vk_set_event(compute_ctx, vkev->event);
 
-    ggml_vk_ctx_end(transfer_ctx);
+    ggml_vk_ctx_end(compute_ctx);
 
-    ggml_vk_submit(transfer_ctx, {vkev->fence});
+    ggml_vk_submit(compute_ctx, {vkev->fence});
     ctx->submit_pending = true;
-    ctx->transfer_ctx.reset();
+    vkev->cmd_buffer = cmd_buf;
+    ctx->compute_ctx.reset();
 }
 
 static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
@@ -14190,20 +14878,11 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
     vk_event *vkev = (vk_event *)event->context;
 
-    vk_context transfer_ctx;
+    vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
 
-    if (ctx->transfer_ctx.expired()) {
-        // Initialize new transfer context
-        transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
-        ctx->transfer_ctx = transfer_ctx;
-        ggml_vk_ctx_begin(ctx->device, transfer_ctx);
-    } else {
-        transfer_ctx = ctx->transfer_ctx.lock();
-    }
-
-    ggml_vk_wait_events(transfer_ctx, {vkev->event});
-    ggml_vk_ctx_end(transfer_ctx);
-    ctx->transfer_ctx.reset();
+    ggml_vk_wait_events(compute_ctx, {vkev->event});
+    ggml_vk_ctx_end(compute_ctx);
+    ctx->compute_ctx.reset();
 }
 
 // TODO: enable async and synchronize
@@ -14212,7 +14891,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
     /* .free                    = */ ggml_backend_vk_free,
     /* .set_tensor_async        = */ ggml_backend_vk_set_tensor_async,
     /* .get_tensor_async        = */ ggml_backend_vk_get_tensor_async,
-    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
+    /* .cpy_tensor_async        = */ ggml_backend_vk_cpy_tensor_async,
     /* .synchronize             = */ ggml_backend_vk_synchronize,
     /* .graph_plan_create       = */ NULL,
     /* .graph_plan_free         = */ NULL,
@@ -14413,13 +15092,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
     const vk_device& device = ggml_vk_get_device(ctx->device);
 
+    const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) &&
+                          device->shader_int64 && device->buffer_device_address;
+
+    auto const & tensor_size_supported = [&](size_t tensor_size) {
+        if (tensor_size > device->max_buffer_size) {
+            return false;
+        }
+        // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply.
+        // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply.
+        if (!uses_bda && !device->shader_64b_indexing) {
+            if (tensor_size > device->properties.limits.maxStorageBufferRange) {
+                return false;
+            }
+        }
+        return true;
+    };
     // reject any tensors larger than the max buffer size
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) {
+        if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) {
             return false;
         }
     }
-    if (ggml_nbytes(op) > device->max_buffer_size) {
+    if (!tensor_size_supported(ggml_nbytes(op))) {
         return false;
     }
 
@@ -14427,6 +15122,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
@@ -14445,6 +15141,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_UNARY_OP_CEIL:
                 case GGML_UNARY_OP_FLOOR:
                 case GGML_UNARY_OP_TRUNC:
+                case GGML_UNARY_OP_SGN:
                     return ggml_is_contiguous(op->src[0]) &&
                            (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                            (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14707,6 +15404,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_REPEAT_BACK:
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ROPE_BACK:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
@@ -14717,8 +15415,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_L2_NORM:
+            return ggml_is_contiguous_rows(op->src[0]) &&
+                   op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -14781,7 +15481,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
-            return op->src[0]->type == GGML_TYPE_F32;
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_SET:
+            return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
+                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
         case GGML_OP_CONCAT:
             return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
         case GGML_OP_ADD1:
@@ -14855,6 +15558,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true; // all inputs are contiguous, see ggml.c
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                const uint32_t S_v = op->src[2]->ne[0];
+                if (S_v != 32 && S_v != 64 && S_v != 128) {
+                    return false;
+                }
+                for (int i = 0; i < 6; i++) {
+                    if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {
+                        return false;
+                    }
+                }
+                return op->type == GGML_TYPE_F32;
+            }
         case GGML_OP_SSM_SCAN:
             {
                 for (int i = 0; i < 6; i++) {
@@ -14926,11 +15642,25 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba
     return buft_ctx->device->idx == ctx->device;
 }
 
+static int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_GET_ROWS:
+            return 0;
+        case GGML_OP_MUL_MAT:
+            return op->ne[1];
+        case GGML_OP_MUL_MAT_ID:
+        case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+            return op->ne[2];
+        default:
+            return ggml_nrows(op);
+    }
+}
+
 static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;
 
-    return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+    return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
 }
 
 static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
@@ -14972,6 +15702,10 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
     vk_event *vkev = (vk_event *)event->context;
 
     VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
+    // Finished using the current command buffer, so flag it for reuse
+    if (vkev->cmd_buffer) {
+        vkev->cmd_buffer->in_use = false;
+    }
 }
 
 static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
@@ -15190,6 +15924,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
     }
 }
 
+static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {
+    VkPhysicalDeviceProperties2 props = vkdev.getProperties2();
+
+    if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {
+        return 0;
+    }
+
+    const uint32_t device_id = props.properties.deviceID;
+
+    switch (device_id) {
+    case 0x56A6:  // A310
+        return 6;
+    case 0x5693:  // A370M
+    case 0x56A5:  // A380
+    case 0x56B1:  // Pro A40/A50
+        return 8;
+    case 0x5697:  // A530M
+        return 12;
+    case 0x5692:  // A550M
+    case 0x56B3:  // Pro A60
+        return 16;
+    case 0x56A2:  // A580
+        return 24;
+    case 0x5691:  // A730M
+    case 0x56A1:  // A750
+        return 28;
+    case 0x56A0:  // A770
+    case 0x5690:  // A770M
+        return 32;
+    case 0xE212:  // Pro B50
+        return 16;
+    case 0xE20C:  // B570
+        return 18;
+    case 0xE20B:  // B580
+        return 20;
+    default:
+        return 0;
+    }
+}
+
 // checks
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
@@ -15403,7 +16177,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
         } else if (tensor->op == GGML_OP_FILL) {
             const float value = ggml_get_op_params_f32(tensor, 0);
-            tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value);
+            tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value);
         } else if (tensor->op == GGML_OP_SQR) {
             tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_SQRT) {
@@ -15432,6 +16206,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
         } else if (tensor->op == GGML_OP_ACC) {
             tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
+        } else if (tensor->op == GGML_OP_SET) {
+            tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
         } else if (tensor->op == GGML_OP_NORM) {
             tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
         } else if (tensor->op == GGML_OP_GROUP_NORM) {
@@ -15488,6 +16264,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             case GGML_UNARY_OP_EXP:
                 tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
                 break;
+            case GGML_UNARY_OP_ELU:
+                tensor_clone = ggml_elu(ggml_ctx, src_clone[0]);
+                break;
             case GGML_UNARY_OP_SILU:
                 tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
                 break;
@@ -15546,6 +16325,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             case GGML_UNARY_OP_TRUNC:
                 tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
                 break;
+            case GGML_UNARY_OP_SGN:
+                tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);
+                break;
             default:
                 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
                 GGML_ABORT("fatal error");
@@ -15666,6 +16448,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else if (tensor->op == GGML_OP_RWKV_WKV7) {
             tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
             src_clone[4], src_clone[5], src_clone[6]);
+        } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
+            tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
+            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
         } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
             src_clone[0]->flags = tensor->src[0]->flags;
             tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
@@ -15864,7 +16649,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
         ggml_vk_print_graph_origin(tensor, done);
     }
 
-    if (avg_err > 0.5 || std::isnan(avg_err)) {
+    if (avg_err > 0.01 || std::isnan(avg_err)) {
         std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
         std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
         if (src0 != nullptr) {
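For reference, a minimal standalone sketch of the revised offload heuristic above, using simplified stand-in types rather than the real ggml structs (the enum, Tensor struct, and numbers below are illustrative only):

#include <cstdint>
#include <cstdio>

// Simplified stand-ins for the real ggml types, for illustration only.
enum class Op { GET_ROWS, MUL_MAT, MUL_MAT_ID, ROPE, ROPE_BACK, OTHER };

struct Tensor {
    Op op;
    int64_t ne[4];                                   // dimensions, ggml-style
    int64_t nrows() const { return ne[1] * ne[2] * ne[3]; }
};

// Mirrors the per-op batch-size selection: GET_ROWS is never offloaded,
// MUL_MAT batches along dim 1, MUL_MAT_ID/ROPE/ROPE_BACK along dim 2,
// and everything else falls back to the total row count.
static int64_t op_batch_size(const Tensor & t) {
    switch (t.op) {
        case Op::GET_ROWS:   return 0;
        case Op::MUL_MAT:    return t.ne[1];
        case Op::MUL_MAT_ID:
        case Op::ROPE:
        case Op::ROPE_BACK:  return t.ne[2];
        default:             return t.nrows();
    }
}

static bool should_offload(const Tensor & t, int64_t min_batch) {
    return op_batch_size(t) >= min_batch;
}

int main() {
    Tensor mm  { Op::MUL_MAT, { 4096, 512, 1, 1 } };
    Tensor rope{ Op::ROPE,    { 128, 32, 8, 1 } };
    std::printf("mul_mat offload: %d\n", should_offload(mm, 32));
    std::printf("rope offload:    %d\n", should_offload(rope, 32));
}
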
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
index 5084a70e..6ba3d1d8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
@@ -3,6 +3,9 @@
 #include "types.glsl"
 #include "generic_binary_head.glsl"
 
+// false for SET, true for ACC
+layout(constant_id = 1) const bool ACC = true;
+
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
@@ -13,17 +16,22 @@ void main() {
 
     const uint offset = p.param3;
     const uint src1_i = idx - offset;
-    const uint oz = src1_i / p.nb02;
-    const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
-    const uint ox = src1_i % p.nb01;
+    const uint i3 = src1_i / p.nb03;
+    const uint rem2 = src1_i - i3 * p.nb03;
+    const uint i2 = rem2 / p.nb02;
+    const uint rem1 = rem2 - i2 * p.nb02;
+    const uint i1 = rem1 / p.nb01;
+    const uint i0 = rem1 % p.nb01;
 
     uint i00, i01, i02, i03;
-    get_indices(idx, i00, i01, i02, i03);
 
-    if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
+    if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
+        if (ACC) {
+            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+        } else {
+            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+        }
     } else {
-        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
+        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
     }
 }
-
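A host-side sketch of the flat-index decomposition the reworked acc.comp now performs; the nb01/nb02/nb03 names follow the shader's push constants, while the concrete extents and the sample index are made up for the demonstration:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t ne10 = 8, ne11 = 4, ne12 = 4, ne13 = 3; // src1 extents
    const uint32_t nb01 = 8;                               // elements per dst row
    const uint32_t nb02 = nb01 * 5;                        // elements per dst plane
    const uint32_t nb03 = nb02 * 4;                        // elements per dst volume

    const uint32_t idx = 2*nb03 + 3*nb02 + 1*nb01 + 6;     // arbitrary flat dst index

    // Decompose idx into (i3, i2, i1, i0), exactly as the shader does.
    const uint32_t i3   = idx / nb03;
    const uint32_t rem2 = idx - i3 * nb03;
    const uint32_t i2   = rem2 / nb02;
    const uint32_t rem1 = rem2 - i2 * nb02;
    const uint32_t i1   = rem1 / nb01;
    const uint32_t i0   = rem1 % nb01;

    // Inside the src1 window the shader accumulates (ACC) or copies (SET);
    // outside it, the destination is simply a copy of src0.
    const bool in_window = i0 < ne10 && i1 < ne11 && i2 < ne12 && i3 < ne13;
    std::printf("i3=%u i2=%u i1=%u i0=%u in_window=%d\n", i3, i2, i1, i0, in_window);
}
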
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp
new file mode 100644
index 00000000..84dcbd8c
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp
@@ -0,0 +1,27 @@
+#version 450
+
+#include "generic_head.glsl"
+#include "types.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    float x = float(data_a[i]);
+
+    if (x < 0.0f) {
+        x = exp(x) - 1;
+    }
+
+    data_d[i] = D_TYPE(x);
+}
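A reference ELU in C++ matching the math of the new shader (identity for x >= 0, exp(x) - 1 for x < 0, i.e. alpha implicitly 1); the sample inputs are arbitrary:

#include <cmath>
#include <cstdio>

// Reference ELU mirroring the shader: x < 0 maps to exp(x) - 1, otherwise x.
static float elu_ref(float x) {
    return x < 0.0f ? std::exp(x) - 1.0f : x;
}

int main() {
    for (float x : { -2.0f, -0.5f, 0.0f, 1.5f }) {
        std::printf("elu(% .2f) = % .6f\n", x, elu_ref(x));
    }
}
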
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 0379e5d5..ec48f5b1 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -3,9 +3,13 @@
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
 
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
+#ifdef FLOAT16
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_subgroup_extended_types_float16 : require
+#endif
+
 #extension GL_KHR_shader_subgroup_shuffle : enable
 #extension GL_KHR_shader_subgroup_vote : enable
 
@@ -15,8 +19,10 @@
 const uint32_t HSK_per_thread = HSK / D_split;
 const uint32_t HSV_per_thread = HSV / D_split;
 
-const uint32_t cols_per_iter = WorkGroupSize / D_split;
+const uint32_t rows_per_thread = Br / row_split;
+const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
+const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;
 
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
@@ -27,20 +33,22 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 
-// Store the output when doing grouped query attention.
-// Rows index by Q's dimension 2, and the first N rows are valid.
-D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    uint32_t offset = (iq2 + r) * HSV + c;
-    data_o[o_offset + offset] = D_TYPE(elem);
-    return elem;
-}
+// If SubGroupSize is set to 0 then only use shmem reductions
+const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
+shared float tmpsh[tmpsh_size];
+shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
 
-shared FLOAT_TYPE tmpsh[WorkGroupSize];
-shared vec4 tmpshv4[WorkGroupSize];
+const uint32_t masksh_stride = Br + 1;
+shared FLOAT_TYPE masksh[Bc * masksh_stride];
 
-shared float masksh[Bc][Br];
-shared vec4 Qf[Br][HSK / 4];
+const uint32_t qf_stride = HSK / 4 + 1;
+shared FLOAT_TYPEV4 Qf[Br * qf_stride];
+
+const uint32_t D = HSK > HSV ? HSK : HSV;
+const uint32_t kvsh_stride = D / 4 + 1;
+shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
+
+shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
 
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -50,50 +58,70 @@ void main() {
     init_indices();
 
     const uint32_t tid = gl_LocalInvocationIndex;
+    const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
+    const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
+    const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
-    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
+    const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
 
-    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+    if (LIMIT_OCCUPANCY_SHMEM > 0) {
+        // This just exists to avoid the occupancy_limiter array getting optimized out
+        occupancy_limiter[tid] = vec4(tid);
+
+        barrier();
+
+        if (occupancy_limiter[tid] == vec4(99999.0)) {
+            data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]);
+        }
+    }
+
+#define tile_row(r) (row_tid * rows_per_thread + (r))
+
+    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
         uint32_t d = (idx + tid) % (HSK / 4);
         uint32_t r = (idx + tid) / (HSK / 4);
         if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
-            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
+            Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
         }
     }
     barrier();
 
-    vec4 Of[Br][HSV_per_thread / 4];
+    FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            Of[r][d] = vec4(0.0);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            Of[r][d] = FLOAT_TYPEV4(0.0);
         }
     }
 
-    float Lf[Br], Mf[Br];
+    float Lf[rows_per_thread], Mf[rows_per_thread];
 
     // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
     const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
 
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lf[r] = 0;
         Mf[r] = NEG_FLT_MAX_OVER_2;
     }
 
-    float slope[Br];
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        slope[r] = 1.0;
+    ACC_TYPE slope[rows_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        slope[r] = ACC_TYPE(1.0);
     }
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
         }
     }
 
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
 #if BLOCK_SIZE > 1
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -101,69 +129,149 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
-    uint32_t m_offset = 0;
+    uint32_t m_offset = gqa_iq1*KV;
     if (p.nem2 != 1 || p.nem3 != 1) {
-        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+    uint32_t mask_opt_bits = 0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
+        if (MASK_ENABLE) {
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+            }
+            mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
+                continue;
+            }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+                float max_mask = NEG_FLT_MAX_OVER_2;
+                barrier();
+                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                    uint32_t c = (idx + tid) % Bc;
+                    uint32_t r = (idx + tid) / Bc;
+                    if (idx + tid < Bc * Br) {
+                        if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                            FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                            masksh[c * masksh_stride + r] = m;
+                            max_mask = max(max_mask, float(m));
+                        } else {
+                            masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
+                        }
+                    }
+                }
+                // skip the block if the mask is entirely -inf
+                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+                barrier();
+                if (gl_SubgroupInvocationID == 0) {
+                    tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+                }
+                barrier();
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    max_mask = max(max_mask, tmpsh[s]);
+                }
+                if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
+            }
+        }
 
-            float max_mask = NEG_FLT_MAX_OVER_2;
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
-                        masksh[c][r] = m;
-                        max_mask = max(max_mask, m);
+        ACC_TYPE Sf[rows_per_thread][cols_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                Sf[r][c] = ACC_TYPE(0.0);
+            }
+        }
+
+        if (SHMEM_STAGING != 0) {
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSK / 4);
+                uint32_t c = (idx + tid) / (HSK / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
+                    FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+#endif
+                    }
+
+                    kvsh[c * kvsh_stride + d] = K_Tf;
+                }
+            }
+            barrier();
+        }
+
+        // More d iterations means Q register caching becomes relevant
+        // Few iterations means the additional registers needed are worse than the speed-up from caching
+        if (HSK_per_thread / 4 > 4) {
+            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                FLOAT_TYPEV4 Q_cache[rows_per_thread];
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
+                }
+
+                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                    if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                        continue;
+                    }
+
+                    FLOAT_TYPEV4 K_Tf;
+                    if (SHMEM_STAGING != 0) {
+                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
                     } else {
-                        masksh[c][r] = float(0);
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
+                    }
+                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                        Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
                     }
                 }
             }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
-            }
-        }
-
-        float Sf[Br][cols_per_thread];
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        } else {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                Sf[r][c] = 0.0;
-            }
-        }
+                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                    continue;
+                }
 
-
-        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                continue;
-            }
-            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                    FLOAT_TYPEV4 K_Tf;
+                    if (SHMEM_STAGING != 0) {
+                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                    } else {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
 #else
-                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
+                    }
+                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                        Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
+                    }
                 }
             }
         }
@@ -171,89 +279,109 @@ void main() {
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             // Compute sum across the D_split
             [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                     Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
                 }
             }
         }
 
-        if (p.logit_softcap != 0.0f) {
-            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        if (LOGIT_SOFTCAP) {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                 [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
+                    Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
                 }
             }
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    float mvf = masksh[c * cols_per_iter + col_tid][r];
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
 
                     Sf[r][c] += slope[r]*mvf;
                 }
             }
-            barrier();
         }
 
-        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            rowmaxf[r] = NEG_FLT_MAX_OVER_2;
+        float eMf[rows_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            float rowmaxf = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                     continue;
                 }
-                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
+                rowmaxf = max(rowmaxf, float(Sf[r][c]));
             }
-            Moldf[r] = Mf[r];
+            float Moldf = Mf[r];
 
             // M = max(rowmax, Mold)
             // P = e^(S - M)
             // eM = e^(Mold - M)
-            Mf[r] = max(rowmaxf[r], Moldf[r]);
-            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                Pf[r][c] = exp(Sf[r][c] - Mf[r]);
-            }
-            eMf[r] = exp(Moldf[r] - Mf[r]);
-
-            // Compute sum across row of P
-            rowsumf[r] = 0.0;
-            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                    continue;
-                }
-                rowsumf[r] += Pf[r][c];
-            }
-
-            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
+            Mf[r] = max(rowmaxf, Moldf);
+            eMf[r] = exp(Moldf - Mf[r]);
+            Lf[r] = eMf[r]*Lf[r];
         }
 
         [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                Of[r][d] = eMf[r] * Of[r][d];
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d];
             }
         }
 
+        if (SHMEM_STAGING != 0) {
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSV / 4);
+                uint32_t c = (idx + tid) / (HSV / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
+                    FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+                        V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
+#endif
+                    }
+
+                    kvsh[c * kvsh_stride + d] = V_Tf;
+                }
+            }
+            barrier();
+        }
+
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                 continue;
             }
+
+            FLOAT_TYPE Pf[rows_per_thread];
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
+                Lf[r] += Pf[r];
+            }
+
             [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                FLOAT_TYPEV4 Vf;
+                if (SHMEM_STAGING != 0) {
+                    Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                } else {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+                    uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                    uint ib = coord / BLOCK_SIZE;
+                    uint iqs = (coord % BLOCK_SIZE);
+                    Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
 #else
-                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+                    Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    Of[r][d] += Pf[r][c] * Vf;
+                }
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);
                 }
             }
         }
-
-        barrier();
     }
 
     // prevent race on tmpsh
@@ -261,58 +389,115 @@ void main() {
 
     // reduce across threads
 
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        float rowmaxf, eMf;
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        float rowmaxf = Mf[r];
 
-        tmpsh[tid] = Mf[r];
         // Compute max across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-            if (tid < s) {
-                tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
+        if (SubGroupSize > 0) {
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
             }
+            if (row_split == 1) {
+                // Reduce inside workgroup with shmem
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
+                }
+                barrier();
+                rowmaxf = tmpsh[d_tid];
+                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                    rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
+                }
+            }
+        } else {
             barrier();
+            tmpsh[tid] = rowmaxf;
+            barrier();
+            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                if (rowgroup_tid < s) {
+                    tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
+                }
+                barrier();
+            }
+            rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
         }
-        rowmaxf = tmpsh[d_tid];
-        barrier();
 
         float Moldf = Mf[r];
 
         // M = max(rowmax, Mold)
         // eM = e^(Mold - M)
         Mf[r] = max(rowmaxf, Moldf);
-        eMf = exp(Moldf - Mf[r]);
+        float eMf = exp(Moldf - Mf[r]);
 
         Lf[r] = eMf*Lf[r];
 
-        tmpsh[tid] = Lf[r];
-
         // Compute sum across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-            if (tid < s) {
-                tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
+        if (SubGroupSize > 0) {
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                Lf[r] += subgroupShuffleXor(Lf[r], s);
             }
+            if (row_split == 1) {
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
+                }
+                barrier();
+                Lf[r] = tmpsh[d_tid];
+                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                    Lf[r] += tmpsh[s * D_split + d_tid];
+                }
+            }
+        } else {
             barrier();
-        }
-        Lf[r] = tmpsh[d_tid];
-        barrier();
-
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-
-            Of[r][d] = eMf * Of[r][d];
-            tmpshv4[tid] = Of[r][d];
-
+            tmpsh[tid] = Lf[r];
             barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                if (tid < s) {
-                    Of[r][d] += tmpshv4[tid + s];
-                    tmpshv4[tid] = Of[r][d];
+            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                if (rowgroup_tid < s) {
+                    tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
                 }
                 barrier();
             }
-            Of[r][d] = tmpshv4[d_tid];
-            barrier();
+            Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
+        }
+
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+            Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d];
+
+            if (SubGroupSize > 0) {
+                [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                    if (!OLD_AMD_WINDOWS) {
+                        Of[r][d] += subgroupShuffleXor(Of[r][d], s);
+                    } else {
+                        // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below.
+                        // Shuffle full vec4 as workaround.
+                        // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697
+                        Of[r][d] += FLOAT_TYPEV4(subgroupShuffleXor(vec4(Of[r][d]), s));
+                    }
+                }
+                if (row_split == 1) {
+                    barrier();
+                    if (gl_SubgroupInvocationID == d_tid) {
+                        tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
+                    }
+                    barrier();
+                    Of[r][d] = tmpshv4[d_tid];
+                    [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                        Of[r][d] += tmpshv4[s * D_split + d_tid];
+                    }
+                }
+            } else {
+                barrier();
+                tmpshv4[tid] = Of[r][d];
+                barrier();
+                [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                    if (rowgroup_tid < s) {
+                        Of[r][d] += tmpshv4[tid ^ s];
+                        tmpshv4[tid] = Of[r][d];
+                    }
+                    barrier();
+                }
+                Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];
+            }
         }
     }
 
@@ -320,32 +505,53 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
 
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                if (row < N) {
+                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                        gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
                     }
                 }
             }
-        }
 
-        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
-                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                if (row < N) {
+                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+                }
+            }
+        } else {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                const uint global_row = i * Br + row;
+
+                if (global_row < N) {
+                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
+
+                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                        data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
+                    }
+                }
+
+                if (global_row < N && d_tid == 0 && col_tid == 0) {
+                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
+                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
+                }
             }
         }
-
         return;
     }
 
     if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
 
             float ms = 1.0f;
             float vs = 1.0f;
@@ -354,7 +560,7 @@ void main() {
                 ms = exp(Mf[r] - sink);
 
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    Of[r][d] *= ms;
+                    Of[r][d] *= FLOAT_TYPE(ms);
                 }
             } else {
                 vs = exp(sink - Mf[r]);
@@ -364,39 +570,37 @@ void main() {
         }
     }
 
-    float Lfrcp[Br];
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    float Lfrcp[rows_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }
 
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            Of[r][d] *= Lfrcp[r];
-#if defined(ACC_TYPE_MAX)
-            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            Of[r][d] *= FLOAT_TYPE(Lfrcp[r]);
+#if defined(FLOAT_TYPE_MAX)
+            Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
 #endif
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+    uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
 
     if (p.gqa_ratio > 1) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
-                    }
+                    gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
                 }
             }
         }
     } else {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (i * Br + r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (i * Br + row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
-                    }
+                    data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
                 }
             }
         }
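A scalar sketch of the per-row streaming-softmax bookkeeping this shader performs: keep a running max M and running sum L, and rescale the accumulated output by eM = exp(Mold - M) whenever M grows. This is illustrative only and ignores the shader's tiling, vectorization, and data layout; the score and value arrays are made up:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> scores = { 0.3f, 2.1f, -1.0f, 3.7f, 0.9f };
    std::vector<float> values = { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f };

    const float NEG_FLT_MAX_OVER_2 = -1.7e38f; // stand-in for the shader's constant
    float M = NEG_FLT_MAX_OVER_2, L = 0.0f, O = 0.0f;

    for (size_t i = 0; i < scores.size(); ++i) {
        const float Mold = M;
        M = std::max(M, scores[i]);
        const float eM = std::exp(Mold - M);    // rescale factor for the old state
        const float P  = std::exp(scores[i] - M);
        L = eM * L + P;                         // running denominator
        O = eM * O + P * values[i];             // running weighted output
    }
    std::printf("softmax-weighted value = %f\n", O / L);
}
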
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
index eb93903c..172d38f0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -1,13 +1,23 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
-layout (constant_id = 1) const uint32_t Br = 1;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t HSK = 32;
-layout (constant_id = 4) const uint32_t HSV = 32;
-layout (constant_id = 5) const uint32_t Clamp = 0;
-layout (constant_id = 6) const uint32_t D_split = 16;
+layout (constant_id =  0) const uint32_t WorkGroupSize = 128;
+layout (constant_id =  1) const uint32_t Br = 1;
+layout (constant_id =  2) const uint32_t Bc = 32;
+layout (constant_id =  3) const uint32_t HSK = 32;
+layout (constant_id =  4) const uint32_t HSV = 32;
+layout (constant_id =  5) const uint32_t Clamp = 0;
+layout (constant_id =  6) const uint32_t D_split = 16;
+layout (constant_id =  7) const uint32_t row_split = 1;
+layout (constant_id =  8) const uint32_t SubGroupSize = 32;
+layout (constant_id =  9) const uint32_t SHMEM_STAGING = 0;
+layout (constant_id = 10) const uint32_t Flags = 0;
+layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
+
+const bool USE_MASK_OPT    = (Flags & 1) != 0;
+const bool MASK_ENABLE     = (Flags & 2) != 0;
+const bool LOGIT_SOFTCAP   = (Flags & 4) != 0;
+const bool OLD_AMD_WINDOWS = (Flags & 8) != 0;
 
 // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
 const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -57,12 +67,17 @@ layout (push_constant) uniform parameter {
 } p;
 
 #define SINK_ENABLE_BIT (1<<24)
-#define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF
 
 layout (binding = 4) readonly buffer S {float data_s[];};
 
 layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
+layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};
+
+layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
 
 #define BINDING_IDX_K 0
 #define BINDING_IDX_V 1
@@ -74,17 +89,21 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
 #endif
 
+#ifndef BLOCK_SIZE
+#define BLOCK_SIZE 1
+#endif
+
 #if defined(DATA_A_F32)
 #undef BLOCK_SIZE
 #define BLOCK_SIZE 4
 #define BLOCK_BYTE_SIZE 16
 
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     // iqs is currently always zero in the flash attention shaders
     if (binding_idx == BINDING_IDX_K) {
-        return k_packed.k_data_packed[a_offset + ib];
+        return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
     } else {
-        return v_packed.v_data_packed[a_offset + ib];
+        return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
     }
 }
 #endif
@@ -92,7 +111,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
 #if defined(DATA_A_Q4_0)
 #define BLOCK_BYTE_SIZE 18
 
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     if (binding_idx == BINDING_IDX_K) {
         uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
         uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -100,7 +119,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
         vui_lo >>= shift;
         vui_hi >>= shift;
 
-        return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+        return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
     } else {
         uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
         uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -108,24 +127,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
         vui_lo >>= shift;
         vui_hi >>= shift;
 
-        return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+        return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
     }
 }
 #endif
 
 #if defined(DATA_A_Q8_0)
 #define BLOCK_BYTE_SIZE 34
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     if (binding_idx == BINDING_IDX_K) {
         const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
         const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
 
-        return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+        return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
     } else {
         const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
         const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
 
-        return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+        return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
     }
 }
 #endif
@@ -165,7 +184,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC
 }
 
 uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
-         iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+         gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
          q_stride, k_stride, v_stride, m_stride;
 
 void init_indices()
@@ -173,12 +192,25 @@ void init_indices()
     N = p.N;
     KV = p.KV;
 
-    i = gl_WorkGroupID.x;
-    split_k_index = 0;
-
     if (p.k_num > 1) {
+        if (p.gqa_ratio > 1) {
+            i = 0;
+            // batch and split_k share gl_WorkGroupID.x
+            gqa_iq1 = gl_WorkGroupID.x / p.k_num;
+            split_k_index = gl_WorkGroupID.x % p.k_num;
+        } else {
+            gqa_iq1 = 0;
+            split_k_index = gl_WorkGroupID.x % p.k_num;
+            i = gl_WorkGroupID.x / p.k_num;
+        }
+    } else if (p.gqa_ratio > 1) {
         i = 0;
-        split_k_index = gl_WorkGroupID.x;
+        gqa_iq1 = gl_WorkGroupID.x;
+        split_k_index = 0;
+    } else {
+        i = gl_WorkGroupID.x;
+        gqa_iq1 = 0;
+        split_k_index = 0;
     }
 
     Tr = CEIL_DIV(N, Br);
@@ -218,3 +250,15 @@ void init_indices()
     // and breaking the alignment detection.
     m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
 }
+
+// Bias applied to softmax to stay in fp16 range.
+// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
+const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
+
+// Store the output when doing grouped query attention.
+// Rows are indexed by Q's dimension 2, and the first N rows are valid.
+void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    uint32_t offset = (iq2 + r) * HSV / 4 + c;
+    data_ov4[o_offset + offset] = D_TYPEV4(elems);
+}
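
flash_attn_base.glsl now receives a packed Flags specialization constant instead of reading a mask-enable bit from the push constant; the shader decodes bit 0 as USE_MASK_OPT, bit 1 as MASK_ENABLE, bit 2 as LOGIT_SOFTCAP and bit 3 as OLD_AMD_WINDOWS. A hypothetical host-side helper that packs such a word when specializing the pipeline could look like the following (the function name and its use are illustrative, not part of ggml-vulkan):

#include <cstdint>

// Pack the per-pipeline feature switches into the single Flags specialization
// constant decoded at the top of flash_attn_base.glsl.
uint32_t pack_fa_flags(bool use_mask_opt, bool mask_enable,
                       bool logit_softcap, bool old_amd_windows) {
    uint32_t flags = 0;
    if (use_mask_opt)    flags |= 1u << 0;
    if (mask_enable)     flags |= 1u << 1;
    if (logit_softcap)   flags |= 1u << 2;
    if (old_amd_windows) flags |= 1u << 3;
    return flags;
}

Keeping these switches in a specialization constant lets the compiler eliminate the dead branches when the pipeline is built, rather than testing them per invocation.
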
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index c995ab14..526e8da3 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
 #extension GL_KHR_shader_subgroup_vote : enable
 #extension GL_KHR_memory_scope_semantics : enable
 #extension GL_KHR_cooperative_matrix : enable
@@ -14,12 +15,12 @@
 #include "types.glsl"
 #include "flash_attn_base.glsl"
 
-const uint32_t HSK_per_thread = HSK / D_split;
-const uint32_t HSV_per_thread = HSV / D_split;
+// MatBr and MatBc must be N,M values supported by a MatBc x MatBr x 16 coopMatMulAdd
+const uint32_t MatBr = 16;
+const uint32_t MatBc = 16;
 
-const uint32_t row_split = 4;
 const uint32_t rows_per_thread = Br / row_split;
-const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
+const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
 
 
@@ -31,33 +32,28 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 
-// Store the output when doing grouped query attention.
-// Rows index by Q's dimension 2, and the first N rows are valid.
-D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    uint32_t offset = (iq2 + r) * HSV + c;
-    data_o[o_offset + offset] = D_TYPE(elem);
-    return elem;
-}
-
-// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
-const uint32_t MatBr = 16;
-const uint32_t MatBc = 16;
-
-shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
-shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
+shared float tmpsh[row_split];
 
 const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
+const uint psh_stride = Br / 4 + 2;
+shared f16vec4 Psh[Bc * psh_stride];
+
 // Avoid padding for hsk==256 to make it fit in 48KB shmem.
-const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
-shared ACC_TYPE sfsh[Bc * sfshstride];
+const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
+shared ACC_TYPEV4 sfsh[Bc * sfshstride];
 
-const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
-shared f16vec4 ksh[Bc * kshstride];
+const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
+const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
+const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
+const uint vsh_stride = v_cols;
+shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
 
-shared float slope[Br];
+const uint32_t osh_stride = row_split * MatBr / 4;
+shared f16vec4 pvsh[MatBc * osh_stride];
+
+shared ACC_TYPE slope[Br];
 
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -69,9 +65,9 @@ void main() {
     const uint32_t tid = gl_LocalInvocationIndex;
 
     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
+    const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;
     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
-    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
-    const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
+    const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
@@ -82,15 +78,10 @@ void main() {
                 Qf[i + tid] = f16vec4(0);
             }
         }
-        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
-            if (i + tid < Bc * kshstride) {
-                ksh[i + tid] = f16vec4(0);
-            }
-        }
         barrier();
     }
 
-    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+    uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
         uint32_t d = (idx + tid) % (HSK / 4);
@@ -102,10 +93,10 @@ void main() {
     }
     barrier();
 
-    ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
-    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Of[r][d] = ACC_TYPEV4(0.0);
+    f16vec4 Of[rows_per_thread][d_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
+            Of[r][d] = f16vec4(0.0);
         }
     }
 
@@ -125,15 +116,17 @@ void main() {
             uint r = tid;
             slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
         }
-        barrier();
     } else {
         if (tid < Br) {
             uint r = tid;
-            slope[r] = 1.0;
+            slope[r] = ACC_TYPE(1.0);
         }
-        barrier();
     }
 
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
 #if BLOCK_SIZE > 1
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
@@ -141,65 +134,114 @@ void main() {
     uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
     uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
 #endif
-    uint32_t m_offset = 0;
+    uint32_t m_offset = gqa_iq1*KV;
     if (p.nem2 != 1 || p.nem3 != 1) {
-        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+    uint32_t mask_opt_bits = 0;
+    f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        float mask_cache[Bc * Br / WorkGroupSize];
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+        [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
+            mask_cache[idx] = f16vec4(0);
+        }
 
-            float max_mask = NEG_FLT_MAX_OVER_2;
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
-                        mask_cache[idx / WorkGroupSize] = m;
-                        max_mask = max(max_mask, m);
-                    }
-                }
+        if (MASK_ENABLE) {
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
             }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+            mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
                 continue;
             }
-        }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-        [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
-            uint32_t d = (idx + tid) % (HSK / 4);
-            uint32_t c = (idx + tid) / (HSK / 4);
-            if (c < Bc && d < HSK / 4) {
-                f16vec4 K_Tf = f16vec4(0);
-                if (!KV_bounds_check || j * Bc + c < KV) {
-#if BLOCK_SIZE > 1
-                    uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
-                    uint ib = coord / BLOCK_SIZE;
-                    uint iqs = (coord % BLOCK_SIZE);
-                    K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
-#else
-                    K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
-#endif
+                float max_mask = NEG_FLT_MAX_OVER_2;
+                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                    uint32_t c = (idx + tid) / (Br / 4);
+                    uint32_t r = (idx + tid) % (Br / 4);
+                    if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                        if ((!KV_bounds_check || j * Bc + c < KV)) {
+                            f16vec4 m;
+                            if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
+                                max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
+                            } else if (i * Br + r * 4 + 2 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                            0.0);
+                                max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
+                            } else if (i * Br + r * 4 + 1 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                            0.0,
+                                            0.0);
+                                max_mask = max(max(max_mask, float(m[0])), float(m[1]));
+                            } else if (i * Br + r * 4 < p.nem1) {
+                                m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                            0.0,
+                                            0.0,
+                                            0.0);
+                                max_mask = max(max_mask, float(m[0]));
+                            } else {
+                                m = f16vec4(0.0);
+                            }
+                            mask_cache[idx / WorkGroupSize] = m;
+                        }
+                    }
+                }
+                // skip the block if the mask is entirely -inf
+                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+                barrier();
+                if (gl_SubgroupInvocationID == 0) {
+                    tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+                }
+                barrier();
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    max_mask = max(max_mask, tmpsh[s]);
+                }
+                if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                    continue;
                 }
-
-                ksh[c * kshstride + d] = K_Tf;
             }
         }
-        barrier();
+
+        if (SHMEM_STAGING != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSK_pad / 4);
+                uint32_t c = (idx + tid) / (HSK_pad / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
+                    f16vec4 K_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+#endif
+                    }
+
+                    kvsh[c * kvsh_stride + d] = K_Tf;
+                }
+            }
+            barrier();
+        }
 
         // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
         // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
@@ -208,11 +250,59 @@ void main() {
         coopmat KMat;
         coopmat QMat;
 
-        for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
-            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
+        [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
+            // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
+            // If not, f16 K is loaded directly from global memory if aligned, otherwise
+            // staged through a Bc * MatBr size staging buffer.
+            // If K is not type f16, then it is always staged for dequantization.
+            if (SHMEM_STAGING == 0) {
+#if BLOCK_SIZE == 1
+            if (KV_bounds_check || d * 16 + 16 > HSK) {
+#endif
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t col_vec = (idx + tid) % (MatBr / 4);
+                uint32_t row = (idx + tid) / (MatBr / 4);
+                if (idx + tid < Bc * MatBr / 4) {
+                    f16vec4 K_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
+#endif
+                    }
 
-            uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
-            coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+                    kvsh[row * kvsh_stride + col_vec] = K_Tf;
+                }
+            }
+            barrier();
+#if BLOCK_SIZE == 1
+            }
+#endif
+
+#if BLOCK_SIZE == 1
+            if (KV_bounds_check || d * 16 + 16 > HSK)
+#endif
+            {
+                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
+                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+            }
+#if BLOCK_SIZE == 1
+            else {
+                const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
+                coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
+            }
+#endif
+            } else {
+                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
+                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+            }
+
+            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
 
             SfMat = coopMatMulAdd(KMat, QMat, SfMat);
         }
@@ -221,27 +311,27 @@ void main() {
         coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
         barrier();
 
-        if (p.logit_softcap != 0.0f) {
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) / Br;
-                uint32_t r = (idx + tid) % Br;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
+        if (LOGIT_SOFTCAP) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) / (Br / 4);
+                uint32_t r = (idx + tid) % (Br / 4);
+                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                    sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
                 }
             }
             barrier();
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float f = mask_cache[idx / WorkGroupSize];
-                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) / (Br / 4);
+                uint32_t r = (idx + tid) % (Br / 4);
+                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+                        // Mask nem1 bounds check is handled when loading masks
+                        ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);
+                        ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);
+                        sfsh[c * sfshstride + r] += slopes * masks;
                     }
                 }
             }
@@ -250,143 +340,237 @@ void main() {
 
         float eMf[rows_per_thread];
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint r_vec  = tile_row(r) / 4;
+            const uint r_comp = tile_row(r) % 4;
+
             float rowmaxf = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                     continue;
                 }
-                rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
+                rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
             }
             float Moldf = Mf[r];
 
+            // Compute max across the row
+            rowmaxf = subgroupMax(rowmaxf);
+
             // M = max(rowmax, Mold)
             // P = e^(S - M)
             // eM = e^(Mold - M)
             Mf[r] = max(rowmaxf, Moldf);
             eMf[r] = exp(Moldf - Mf[r]);
-        }
 
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
-            }
-        }
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             Lf[r] = eMf[r]*Lf[r];
         }
 
-        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                continue;
-            }
-            float Pf[rows_per_thread];
+        [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+            const uint d_local = d0 / threads_per_rowgroup;
             [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
-                Lf[r] += Pf[r];
+                Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
             }
-            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-#if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
-#else
-                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
-#endif
-                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                    Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
+        }
+
+        // Calculate and store Pf in Psh
+        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+            const uint col = c * cols_per_iter + col_tid;
+
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
+                const uint row = tile_row(r);
+                if (KV_bounds_check && j * Bc + col >= KV) {
+                    Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
+                } else {
+                    const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
+                    const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
+                    [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
+                        Lf[r + vec_idx] += Pf[vec_idx];
+                    }
+                    Psh[col * psh_stride + row / 4] = Pf;
                 }
             }
         }
 
-        barrier();
-    }
+        if (SHMEM_STAGING != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSV_pad / 4);
+                uint32_t c = (idx + tid) / (HSV_pad / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
+                    f16vec4 V_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+                        V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
+#endif
+                    }
 
-    // prevent race on tmpsh
-    barrier();
-
-    // reduce across threads
-
-    float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        FLOAT_TYPE M = Mf[r];
-        tmpsh[tid] = M;
-        // Compute max across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-            M = max(M, tmpsh[tid ^ s]);
-            barrier();
-            tmpsh[tid] = M;
-            barrier();
-        }
-        rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
-        barrier();
-    }
-
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        Moldf[r] = Mf[r];
-
-        // M = max(rowmax, Mold)
-        // eM = e^(Mold - M)
-        Mf[r] = max(rowmaxf[r], Moldf[r]);
-        eMf[r] = exp(Moldf[r] - Mf[r]);
-
-        Lf[r] = eMf[r]*Lf[r];
-    }
-
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        FLOAT_TYPE L = Lf[r];
-        tmpsh[tid] = L;
-        // Compute sum across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-            L += tmpsh[tid ^ s];
-            barrier();
-            tmpsh[tid] = L;
-            barrier();
-        }
-        Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
-        barrier();
-    }
-
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-
-            Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
-            tmpshv4[tid] = Of[r][d];
-
-            barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-                Of[r][d] += tmpshv4[tid ^ s];
-                barrier();
-                tmpshv4[tid] = Of[r][d];
-                barrier();
+                    kvsh[c * kvsh_stride + d] = V_Tf;
+                }
             }
-            Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
-            barrier();
         }
-    }
+        barrier();
 
-    // If there is split_k, then the split_k resolve shader does the final
-    // division by L. Store the intermediate O value and per-row m and L values.
-    if (p.k_num > 1) {
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
+        const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
 
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
+        // Each subgroup handles HSV/4 columns
+        [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
+            const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
+
+            coopmat PVMat = coopmat(0);
+
+            // Preload V tiles for [Bc, 16 * num subgroups]
+            const uint v_rows = Bc;
+            const uint v_total = v_rows * v_cols;
+            const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
+
+            // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
+            // If not, f16 V is loaded directly from global memory if aligned, otherwise
+            // staged through a Bc * MatBr size staging buffer.
+            // If V is not type f16, then it is always staged for dequantization.
+            if (SHMEM_STAGING == 0) {
+#if BLOCK_SIZE == 1
+            // For f16, only preload if not aligned
+            if (KV_bounds_check) {
+#endif
+            [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
+                const uint idx = i * gl_WorkGroupSize.x + tid;
+                const uint row = idx / v_cols;
+                const uint col = idx % v_cols;
+
+                const uint v_row = j * Bc + row;
+                const uint v_col = hsv_tile * MatBc * row_split + col * 4;
+
+                const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
+                const uint ib = coord / BLOCK_SIZE;
+                const uint iqs = coord % BLOCK_SIZE;
+
+                if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
+#if BLOCK_SIZE > 1
+                    kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+                    kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
+#endif
+                } else {
+                    kvsh[row * vsh_stride + col] = f16vec4(0.0f);
+                }
+            }
+
+#if BLOCK_SIZE == 1
+            }
+#endif
+            }
+            barrier();
+
+            const uint o_offset = gl_SubgroupID * MatBr / 4;
+
+            if (hsv_offset < HSV_pad) {
+                [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
+                    coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
+
+                    if (SHMEM_STAGING == 0) {
+#if BLOCK_SIZE == 1
+                    if (!KV_bounds_check) {
+                        // F16 values can be loaded directly from global memory
+                        const uint v_tile_row = j * Bc + bc_chunk * MatBc;
+                        const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
+                        coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
+                    } else
+#endif
+                    {
+                        const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
+                        coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                    }
+                    } else {
+                        const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
+                        coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                    }
+
+                    PVMat = coopMatMulAdd(KMat, QMat, PVMat);
+                }
+
+                // Store PVMat to pvsh and load into Of
+                coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
+            }
+
+            barrier();
+
+            const uint hsv_per_tile = row_split * MatBc;
+            const uint hsv_base = hsv_tile * hsv_per_tile;
+            const uint d_values_per_tile = hsv_per_tile / 4;
+
+            const uint d_start = hsv_tile * d_values_per_tile;
+            const uint d_end = min(d_start + d_values_per_tile, HSV / 4);
+
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+
+                [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {
+                    const uint d = d_local * threads_per_rowgroup + col_tid;
+                    const uint hsv_col = 4 * d;
+
+                    if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
+                        const uint local_hsv = (hsv_col - hsv_base) / 4;
+                        Of[r][d_local] += pvsh[row * osh_stride + local_hsv];
                     }
                 }
             }
         }
 
-        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            if (tile_row(r) < N) {
-                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+        barrier();
+    }
+
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        Lf[r] = subgroupAdd(Lf[r]);
+    }
+
+    // If there is split_k, then the split_k resolve shader does the final
+    // division by L. Store the intermediate O value and per-row m and L values.
+    if (p.k_num > 1) {
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
+
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                if (tile_row(r) < N) {
+                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                        const uint d = d0 + col_tid;
+                        if (d >= HSV/4) break;
+                        const uint d_local = d0 / threads_per_rowgroup;
+                        gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
+                    }
+                }
+            }
+
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                if (tile_row(r) < N) {
+                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+                }
+            }
+        } else {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                const uint global_row = i * Br + row;
+
+                if (global_row < N) {
+                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
+
+                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                        const uint d = d0 + col_tid;
+                        if (d >= HSV/4) break;
+                        data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
+                    }
+                }
+
+                if (global_row < N && col_tid == 0) {
+                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
+                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
+                }
             }
         }
 
@@ -403,8 +587,9 @@ void main() {
             if (sink > Mf[r]) {
                 ms = exp(Mf[r] - sink);
 
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    Of[r][d] *= ACC_TYPE(ms);
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d_local = d0 / threads_per_rowgroup;
+                    Of[r][d_local] *= float16_t(ms);
                 }
             } else {
                 vs = exp(sink - Mf[r]);
@@ -419,34 +604,37 @@ void main() {
         Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }
 
-    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+        const uint d_local = d0 / threads_per_rowgroup;
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Of[r][d] *= ACC_TYPE(Lfrcp[r]);
-#if defined(ACC_TYPE_MAX)
-            Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+            Of[r][d_local] *= float16_t(Lfrcp[r]);
+#if defined(FLOAT_TYPE_MAX)
+            Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
 #endif
         }
     }
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+    uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
-                    }
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d = d0 + col_tid;
+                    if (d >= HSV / 4) break;
+                    const uint d_local = d0 / threads_per_rowgroup;
+                    gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
                 }
             }
         }
     } else {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (i * Br + tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
-                    }
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d = d0 + col_tid;
+                    if (d >= HSV / 4) break;
+                    const uint d_local = d0 / threads_per_rowgroup;
+                    data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);
                 }
             }
         }
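
The cm1 loop above consults a per-tile mask summary before touching the mask itself: each uint32_t in data_mask_opt covers 16 consecutive Bc-wide KV blocks, two bits per block, where 1 means the block is entirely -inf (the whole iteration can be skipped) and 2 means it is entirely zero (the mask load and add can be skipped). A small CPU sketch of that decode, assuming mask_opt_words holds the buffer produced by the mask-optimization pass:

#include <cstdint>
#include <vector>

// Two-bit codes stored per Bc-wide KV block, 16 blocks per uint32_t.
enum MaskOptCode : uint32_t { MASK_NONE = 0, MASK_ALL_NEG_INF = 1, MASK_ALL_ZERO = 2 };

uint32_t mask_opt_code_for_block(const std::vector<uint32_t> &mask_opt_words,
                                 uint32_t mo_offset, uint32_t j) {
    const uint32_t word = mask_opt_words[mo_offset + j / 16];
    return (word >> ((j % 16) * 2)) & 0x3u; // same shift/mask as the shader
}
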
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 9a719963..0ea18134 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
     return max(elem0, elem1);
 }
 
-#if defined(BLOCK_SIZE)
+#if BLOCK_SIZE > 1
 #define DECODEFUNC , DEQUANTFUNC
 #else
 #define DECODEFUNC
@@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
+// Store O values for non-GQA split_k. Rows are tokens, not heads.
+D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
+    uint32_t global_row = i * Br + r;
+    if (global_row < N && c < HSV) {
+        uint32_t o_off = HSV * p.ne1
+            * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+        data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Store L/M values for non-GQA split_k.
+ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
+    uint32_t global_row = i * Br + r;
+    if (global_row < N && c == 0) {
+        uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
+            + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+        data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
+    }
+    return elem;
+}
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
@@ -85,7 +107,7 @@ void main() {
 
     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
-#if defined(BLOCK_SIZE)
+#if BLOCK_SIZE > 1
     tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
     tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
 #endif
@@ -98,7 +120,7 @@ void main() {
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
     {
         q_stride &= ~7;
-#if !defined(BLOCK_SIZE)
+#if BLOCK_SIZE == 1
         k_stride &= ~7;
         v_stride &= ~7;
 #endif
@@ -111,13 +133,13 @@ void main() {
     coopmat Q;
     coopmat Qf16;
 
-    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
+    uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
     coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
 
     Qf16 = coopmat(Q);
     Qf16 *= float16_t(p.scale);
 
-    coopmat O = coopmat(0);
+    coopmat O = coopmat(0);
 
     coopmat L, M;
 
@@ -138,48 +160,67 @@ void main() {
         coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
     }
 
-    uint32_t m_offset = 0;
+    const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc);
+    // mo_offset will point to the tile starting at row i*Br and col 0
+    uint32_t mo_offset = mo_stride * i;
+
+    uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
     if (p.nem2 != 1 || p.nem3 != 1) {
-        m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+        m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+        mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride;
     }
 
+    uint32_t mask_opt = 0;
+    uint32_t mask_opt_idx = ~0;
+
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        coopmat mv;
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+        coopmat mv = coopmat(0);
+        if (MASK_ENABLE) {
 
-            if (nem1_bounds_check) {
-                tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
-                tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+            }
+            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
+                continue;
+            }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-                coopmat mvmax;
+                if (nem1_bounds_check) {
+                    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                    tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
 
-                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                    coopmat mvmax;
 
-                // skip the block if the mask is entirely -inf
-                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                    continue;
-                }
-            } else {
-                tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
-                // Don't clamp against nem1 when GQA is enabled
-                uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
-                tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                    // skip the block if the mask is entirely -inf
+                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                        continue;
+                    }
+                } else {
+                    tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+                    // Don't clamp against nem1 when GQA is enabled
+                    uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+                    tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+                    tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
-                coopmat mvmax;
+                    coopmat mvmax;
 
-                coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
-
-                // skip the block if the mask is entirely -inf
-                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                    continue;
+                    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+                    // skip the block if the mask is entirely -inf
+                    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                        continue;
+                    }
                 }
             }
         }
@@ -192,14 +233,14 @@ void main() {
         coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
         S = coopMatMulAdd(Qf16, K_T, S);
 
-        if (p.logit_softcap != 0.0f) {
+        if (LOGIT_SOFTCAP) {
             [[unroll]]
             for (int k = 0; k < S.length(); ++k) {
                 S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
             }
         }
 
-        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+        if (MASK_ENABLE) {
             S += slopeMat*coopmat(mv);
         }
 
@@ -218,6 +259,8 @@ void main() {
 
         coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
 
+        rowmax += coopmat(FATTN_KQ_MAX_OFFSET);
+
         coopmat Mold = M;
 
         // M = max(rowmax, Mold)
@@ -260,11 +303,8 @@ void main() {
         // resize eM by using smear/reduce
         coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
 
-        // multiply with fp16 accumulation, then add to O.
-        coopmat PV = coopmat(0);
-        PV = coopMatMulAdd(P_A, V, PV);
-
-        O = eMdiag * O + coopmat(PV);
+        O *= coopmat(eMdiag);
+        O = coopMatMulAdd(P_A, V, O);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
@@ -272,12 +312,19 @@ void main() {
     if (p.k_num > 1) {
         coopmat O_D = coopmat(O);
 
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
-        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
 
-        o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
-        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
-        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
+            coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+        } else {
+            coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
+            coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
+            coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
+        }
         return;
     }
 
@@ -305,7 +352,7 @@ void main() {
             if (sink > Mr[i]) {
                 ms = exp(Mr[i] - sink);
 
-                O[i] *= ms;
+                O[i] *= float16_t(ms);
             } else {
                 vs = exp(sink - Mr[i]);
             }
@@ -319,15 +366,16 @@ void main() {
         Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
     }
 
-    O = Ldiag*O;
+    coopmat O_D = coopmat(O);
+
+    O_D = coopmat(Ldiag)*O_D;
 
 #if defined(ACC_TYPE_MAX)
-    [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+    [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); }
 #endif
 
-    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
 
-    coopmat O_D = coopmat(O);
     if (p.gqa_ratio > 1) {
         coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
     } else {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
new file mode 100644
index 00000000..0e417708
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp
@@ -0,0 +1,162 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 128;
+layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
+layout (constant_id = 2) const uint Br = 32;
+layout (constant_id = 3) const uint Bc = 32;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {float16_t data_a[];};
+layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
+layout (binding = 1) writeonly buffer D {uint data_d[];};
+
+layout (push_constant) uniform parameter {
+    uint nem0;
+    uint nem1;
+    uint nem2;
+    uint nbm1;
+    uint nbm2;
+    uint nbm3;
+    uint nbd1;
+    uint nbd2;
+    uint nbd3;
+};
+
+#define MASK_OPT_ALL_NEG_INF 1
+#define MASK_OPT_ALL_ZERO 2
+
+shared float minsh[NUM_SUBGROUPS];
+shared float maxsh[NUM_SUBGROUPS];
+
+float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
+
+void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) {
+    const uint tid = gl_LocalInvocationIndex;
+
+    [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+        float min_v = FLT_MAX_OVER_2;
+        float max_v = -FLT_MAX_OVER_2;
+        [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
+            uint j0 = (i + tid) % (Bc / 4);
+            uint j1 = (i + tid) / (Bc / 4);
+
+            j0 *= 4;
+            j0 += (i0 * 16 + block_x) * Bc;
+            j1 += i1 * Br;
+
+            if (!need_bounds_check || j0 + 3 < nem0) {
+                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    min_v = min(min_v, f[c]);
+                    max_v = max(max_v, f[c]);
+                }
+            } else {
+                [[unroll]] for (int c = 0; c < 4; ++c) {
+                    if (j0 + c < nem0) {
+                        float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
+                        min_v = min(min_v, f);
+                        max_v = max(max_v, f);
+                    }
+                }
+            }
+        }
+        min_v = subgroupMin(min_v);
+        max_v = subgroupMax(max_v);
+        if (gl_SubgroupInvocationID == 0) {
+            minsh[gl_SubgroupID] = min_v;
+            maxsh[gl_SubgroupID] = max_v;
+        }
+        barrier();
+        if (tid == 0) {
+            [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+                min_v = min(min_v, minsh[i]);
+                max_v = max(max_v, maxsh[i]);
+            }
+            if (max_v <= -FLT_MAX_OVER_2) {
+                result |= 1 << (2*block_x);
+            }
+            if (min_v == 0.0f && max_v == 0.0f) {
+                result |= 2 << (2*block_x);
+            }
+        }
+        barrier();
+    }
+}
+
+// For each Br x Bc block of the mask (input) buffer, read all values and check
+// if it's all -inf or all zero. Write out a two-bit code indicating which it is
+// (or zero for neither). Each workgroup processes 16 tiles and writes out a
+// 32-bit result mask.
+//
+// TODO: This is a lot of work per workgroup; it might make sense to split this
+// into more workgroups in the future.
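+//
+// Illustrative decode on the consumer side (a sketch only; `result_word` is an
+// assumed name for the 32-bit word this shader writes, not actual consumer code):
+//   uint code = (result_word >> (2 * block_x)) & 3u;
+//   (code & MASK_OPT_ALL_NEG_INF) != 0  -> the whole Br x Bc tile is -inf
+//   (code & MASK_OPT_ALL_ZERO)    != 0  -> the whole Br x Bc tile is zero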
+void main() {
+    // Each workgroup handles a row
+    const uint tid = gl_LocalInvocationIndex;
+    const uint i0 = gl_WorkGroupID.x;
+    const uint i1 = gl_WorkGroupID.y;
+    const uint i2 = gl_WorkGroupID.z % nem2;
+    const uint i3 = gl_WorkGroupID.z / nem2;
+
+    uint result = 0;
+
+    // Fast path for fully in-bounds blocks where we can do f16vec4 loads
+    if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
+        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
+        if ((i0 + 1) * 16 * Bc <= nem0) {
+            loadvec4(result, i0, i1, i2, i3, false);
+        } else {
+            loadvec4(result, i0, i1, i2, i3, true);
+        }
+    } else {
+        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
+            float min_v = FLT_MAX_OVER_2;
+            float max_v = -FLT_MAX_OVER_2;
+            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
+                if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {
+                    continue;
+                }
+                uint j0 = (i + tid) % Bc;
+                uint j1 = (i + tid) / Bc;
+
+                j0 += (i0 * 16 + block_x) * Bc;
+                j1 += i1 * Br;
+
+                if (j0 < nem0 && j1 < nem1) {
+                    float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
+                    min_v = min(min_v, f);
+                    max_v = max(max_v, f);
+                }
+            }
+            min_v = subgroupMin(min_v);
+            max_v = subgroupMax(max_v);
+            if (gl_SubgroupInvocationID == 0) {
+                minsh[gl_SubgroupID] = min_v;
+                maxsh[gl_SubgroupID] = max_v;
+            }
+            barrier();
+            if (tid == 0) {
+                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
+                    min_v = min(min_v, minsh[i]);
+                    max_v = max(max_v, maxsh[i]);
+                }
+                if (max_v <= -FLT_MAX_OVER_2) {
+                    result |= 1 << (2*block_x);
+                }
+                if (min_v == 0.0f && max_v == 0.0f) {
+                    result |= 2 << (2*block_x);
+                }
+            }
+            barrier();
+        }
+    }
+
+    if (tid == 0) {
+        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
index 4eaddd31..68917fc0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
@@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];};
 
 layout (push_constant) uniform parameter {
     uint D;
-    uint N;
+    uint ne1;
+    uint ne2;
     uint ne3;
     uint k_num;
     uint sinks;
@@ -24,15 +25,15 @@ void main() {
     // Each workgroup handles a row
     const uint n = gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
-    const uint iq3 = gl_WorkGroupID.z;
+    const uint i2 = gl_WorkGroupID.z % p.ne2;
+    const uint i3 = gl_WorkGroupID.z / p.ne2;
 
     uint D = p.D;
-    uint N = p.N;
     uint k_num = p.k_num;
 
-    uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
-    uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
-    uint lm_stride = N * 2;
+    uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n;
+    uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n;
+    uint lm_stride = p.ne1 * 2;
 
     // Compute the max m value for the row
     float m_max = -1.0/0.0;
@@ -99,7 +100,7 @@ void main() {
     if (d < D) {
         float O = 0.0;
         [[unroll]] for (uint k = 0; k < k_num; ++k) {
-            uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
+            uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d;
             float m = data_a[m_offset + k * lm_stride];
             O += exp(m - m_max) * data_a[o_offset];
         }
@@ -115,6 +116,6 @@ void main() {
         const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
         O = clamp(O, -FLT_MAX, FLT_MAX);
 
-        data_d[iq3 * D * N + D * n + d] = O;
+        data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;
     }
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
new file mode 100644
index 00000000..f008859b
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
@@ -0,0 +1,128 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(constant_id = 0) const uint S_V = 128;
+layout(constant_id = 1) const uint KDA = 0;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint H;
+    uint n_tokens;
+    uint n_seqs;
+    uint s_off;
+    uint sq1, sq2, sq3;
+    uint sv1, sv2, sv3;
+    uint sb1, sb2, sb3;
+    uint neq1, rq3;
+    float scale;
+};
+
+layout(binding = 0) readonly  buffer QBuf     { FLOAT_TYPE data_q[];     };
+layout(binding = 1) readonly  buffer KBuf     { FLOAT_TYPE data_k[];     };
+layout(binding = 2) readonly  buffer VBuf     { FLOAT_TYPE data_v[];     };
+layout(binding = 3) readonly  buffer GBuf     { FLOAT_TYPE data_g[];     };
+layout(binding = 4) readonly  buffer BetaBuf  { FLOAT_TYPE data_beta[];  };
+layout(binding = 5) readonly  buffer StateBuf { FLOAT_TYPE data_state[]; };
+layout(binding = 6)           buffer DstBuf   { FLOAT_TYPE data_dst[];   };
+
+shared FLOAT_TYPE s_k[S_V];
+shared FLOAT_TYPE s_q[S_V];
+shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
+
+void main() {
+    const uint head_id = gl_WorkGroupID.x;
+    const uint seq_id  = gl_WorkGroupID.y;
+    const uint col     = gl_LocalInvocationID.x;
+
+    const uint iq1 = head_id % neq1;
+    const uint iq3 = seq_id / rq3;
+
+    const uint state_size = S_V * S_V;
+    const uint state_base = (seq_id * H + head_id) * state_size;
+
+    FLOAT_TYPE state[S_V];
+    [[unroll]] for (uint i = 0; i < S_V; i++) {
+        state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
+    }
+
+    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
+
+    for (uint t = 0; t < n_tokens; t++) {
+        const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const uint k_off = q_off;
+        const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
+
+        s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
+        s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
+
+        const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
+
+        if (KDA != 0) {
+            const uint g_base = gb_off * S_V;
+            s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
+        }
+
+        barrier();
+
+        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
+        const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
+
+        if (KDA == 0) {
+            const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
+
+            FLOAT_TYPE kv_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                kv_col += dot(
+                    vec4(state[i], state[i+1], state[i+2], state[i+3]),
+                    vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
+                );
+            }
+
+            FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
+
+            FLOAT_TYPE attn_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                sv = g_val * sv + kv * delta_col;
+                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+            }
+
+            data_dst[attn_off + col] = attn_col * scale;
+        } else {
+            FLOAT_TYPE kv_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                kv_col += dot(gv * sv, kv);
+            }
+
+            FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
+
+            FLOAT_TYPE attn_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                sv = gv * sv + kv * delta_col;
+                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+            }
+
+            data_dst[attn_off + col] = attn_col * scale;
+        }
+
+        attn_off += S_V * H;
+        barrier();
+    }
+
+    [[unroll]] for (uint i = 0; i < S_V; i++) {
+        data_dst[s_off + state_base + col * S_V + i] = state[i];
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
index 83ef2f87..f9af4674 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
@@ -1,6 +1,6 @@
 #version 450
 
-#include "generic_head.glsl"
+#include "generic_unary_head.glsl"
 #include "types.glsl"
 
 #extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,22 @@
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
     const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
 
+    const uint i3 = row / (p.ne11 * p.ne12);
+    const uint i3_offset = i3 * p.ne12 * p.ne11;
+    const uint i2 = (row - i3_offset) / p.ne11;
+    const uint i2_offset = i2 * p.ne11;
+    const uint i1 = row - i3_offset - i2_offset;
+
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+    [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
         sum[tid] += xi * xi;
     }
 
@@ -33,9 +36,9 @@ void main() {
         barrier();
     }
 
-    const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
+    const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
+        data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
     }
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
index dfb78659..4aeda68c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
@@ -29,7 +29,10 @@ layout (push_constant) uniform parameter
 #ifdef MUL_MAT_ID
     uint nei0;
     uint ne11;
+    uint expert_i1;
+    uint nbi1;
 #else
+    uint base_work_group_y;
     uint ne02;
     uint ne12;
     uint broadcast2;
@@ -43,9 +46,9 @@ uint expert_id;
 
 void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.y;
+    const uint expert_i0 = gl_WorkGroupID.y;
 #else
-    const uint batch_idx = gl_GlobalInvocationID.y;
+    const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y;
 #endif
 
 #ifndef MUL_MAT_ID
@@ -60,7 +63,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
         batch_idx_a = i03 * p.ne02 + i02;
     }
 #else
-    expert_id = data_ids[expert_idx];
+    expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
 #endif
 
     a_offset =
@@ -71,13 +74,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 #endif
     b_offset =
 #ifdef MUL_MAT_ID
-            (expert_idx % p.ne11) * p.stride_b;
+            (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
 #else
             batch_idx * p.batch_stride_b;
 #endif
     d_offset =
 #ifdef MUL_MAT_ID
-            expert_idx * p.stride_d;
+            expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
 #else
             batch_idx * p.batch_stride_d;
 #endif
@@ -103,12 +106,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
                     temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -158,12 +161,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                     temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -203,12 +206,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                     tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index 775e9a70..23f3bd8d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -90,6 +90,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -139,7 +141,7 @@ void main() {
     const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
+    const uint expert_idx = gl_WorkGroupID.z;
     if (ic * BN >= data_expert_count[expert_idx]) {
         return;
     }
@@ -149,7 +151,7 @@ void main() {
 #endif
 
 #ifndef MUL_MAT_ID
-    const uint batch_idx = gl_GlobalInvocationID.z;
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -366,7 +368,7 @@ void main() {
     const uint dc = ic * BN + warp_c * WN;
 
 #ifndef MUL_MAT_ID
-    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
 #ifdef COOPMAT
@@ -375,6 +377,7 @@ void main() {
         [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
             coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
 
+            barrier();
             [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                 const uint row_i = dc + cm_col * TN + col + store_c;
                 if (row_i >= _ne1) break;
@@ -385,6 +388,7 @@ void main() {
                     data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                 }
             }
+            barrier();
         }
     }
 #else
@@ -402,18 +406,22 @@ void main() {
                 // Full coopMat is within bounds, but stride_d is not aligned
                 coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
 
+                controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
                 [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                     data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                 }
+                controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
             } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
                 // Partial coopMat is within bounds
                 coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
 
+                controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
                 [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                     if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
                         data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                     }
                 }
+                controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease);
             }
         }
     }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
index b6614d2f..497a18ff 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -53,6 +53,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -165,7 +167,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
         uint id = ids[iter++];
         uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
 
-        ballots_sh[gl_SubgroupID] = ballot;
+        if (gl_SubgroupInvocationID == 0) {
+            ballots_sh[gl_SubgroupID] = ballot;
+        }
         barrier();
 
         uint subgroup_base = 0;
@@ -197,7 +201,7 @@ void main() {
     const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
+    const uint expert_idx = gl_WorkGroupID.z;
     if (ic * BN >= data_expert_count[expert_idx]) {
         return;
     }
@@ -215,7 +219,7 @@ void main() {
 #endif
 
 #ifndef MUL_MAT_ID
-    const uint batch_idx = gl_GlobalInvocationID.z;
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -255,7 +259,7 @@ void main() {
 #else
     uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
     uint pos_b = batch_idx * p.batch_stride_b;
-    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
     uint stride_a = p.stride_a / QUANT_K;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
index 743004ff..26c5c12a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
@@ -43,7 +43,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
         uint id = ids[iter++];
         uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
 
-        ballots_sh[gl_SubgroupID] = ballot;
+        if (gl_SubgroupInvocationID == 0) {
+            ballots_sh[gl_SubgroupID] = ballot;
+        }
         barrier();
 
         uint subgroup_base = 0;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 335d7f6a..aae1c2e8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -57,6 +57,8 @@ layout (push_constant) uniform parameter
     uint nbi1;
     uint ne11;
 #else
+    uint base_work_group_z;
+    uint num_batches;
     uint k_split;
     uint ne02;
     uint ne12;
@@ -108,7 +110,7 @@ void main() {
     const uint ic = gl_WorkGroupID.y;
 
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
+    const uint expert_idx = gl_WorkGroupID.z;
     if (ic * BN >= data_expert_count[expert_idx]) {
         return;
     }
@@ -118,7 +120,7 @@ void main() {
 #endif
 
 #ifndef MUL_MAT_ID
-    const uint batch_idx = gl_GlobalInvocationID.z;
+    const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z;
 
     const uint i13 = batch_idx / p.ne12;
     const uint i12 = batch_idx % p.ne12;
@@ -276,7 +278,7 @@ void main() {
     const uint dc = ic * BN + warp_c * WN;
 
 #ifndef MUL_MAT_ID
-    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches;
 #endif
 
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index 7f32dadf..9c297d1c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
         const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8      ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
                                                      (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147
 
-        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
+        buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
     }
 }
 
@@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
                               (data_a[ib_k].scales[is+4] >>  4) | ((data_a[ib_k].scales[is  ] & 0xC0) >> 2));
         }
 
-        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
     }
 }
 
@@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
         const uint is = iqs_k / 4;
         const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;
 
-        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
+        buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));
     }
 }
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
index 9d6d3665..55b89f19 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
@@ -112,12 +112,11 @@ void rms_norm(uint num_iters) {
 #if RMS_NORM_ROPE_FUSION
     barrier();
     rope_params rp = p.rope;
-    uint rope_row = (samp*nchannels + channel)*nrows + row;
     for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) {
         if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) {
-            rope_neox(t, rope_row, rp);
+            rope_neox(t, row, channel, samp, rp);
         } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) {
-            rope_norm(t, rope_row, rp);
+            rope_norm(t, row, channel, samp, rp);
         }
     }
 #endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
index aacec984..2e534599 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
@@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) {
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
-uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) {
+uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) {
 #if RMS_NORM_ROPE_FUSION
     // Per-row offset in shared memory
     const uint ix = i0;
 #else
-    const uint ix = i02*p.nb02 + i01*p.nb01 + i0;
+    const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0;
 #endif
     return ix;
 }
@@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out
     sin_theta = sin(theta) * mscale;
 }
 
-void rope_norm(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-
-    if (i0 >= ne0) {
+void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    uint idst = i1*ne0 + i0;
-    const uint ix = rope_a_coord(i0, i01, i02, p);
+    uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0, i1, i2, i3, p);
 
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
-        idst = i01*ne0 + i0;
-        idst += rope_data_i[i02].x * p.set_rows_stride;
+        idst = i1*p.nb11 + i0;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
     }
 
     if (i0 >= p.n_dims) {
@@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
         return;
     }
 
-    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
 
@@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
     rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
 }
 
-void rope_neox(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-
-    if (i0 >= ne0) {
+void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
-        idst = i01*ne0 + i0/2;
-        idst += rope_data_i[i02].x * p.set_rows_stride;
+        idst = i1*p.nb11 + i0/2;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
     }
 
     if (i0 >= p.n_dims) {
@@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
         return;
     }
 
-    const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f);
+    const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f);
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
 
@@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
 }
 
 
-void rope_multi(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-    uint ne2 = p.ne02;
-
-    if (i0 >= ne0) {
+void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
     // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
-        idst = i01*ne0 + i0/2;
-        idst += rope_data_i[i02].x * p.set_rows_stride;
+        idst = i1*p.nb11 + i0/2;
+        idst += rope_data_i[i2].x * p.set_rows_stride;
     }
 
     if (i0 >= p.n_dims) {
@@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     float theta_base = 0.0;
     if (p.is_imrope != 0) {
         if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
-            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
         } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
         } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
-            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
         } else {
-            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
         }
     } else {
         if (sector < p.sections[0]) {
-            theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= p.sections[0] && sector < sec_w) {
-            theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f);
         }
         else if (sector >= sec_w + p.sections[2]) {
-            theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
+            theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f);
         }
     }
 
@@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta);
 }
 
-void rope_vision(const uint i0, const uint i1, rope_params p) {
-    uint ne0 = p.ncols;
-    uint ne1 = p.p_delta_rows;
-    uint ne2 = p.ne02;
-
-    if (i0 >= ne0) {
+void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) {
+    if (i0 >= p.ne00) {
         return;
     }
 
-    const uint i01 = i1 % ne1;
-    const uint i02 = i1 / ne1;
-
-    const uint idst = i1*ne0 + i0/2;
-    const uint ix = rope_a_coord(i0/2, i01, i02, p);
+    const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13;
+    const uint ix = rope_a_coord(i0/2, i1, i2, i3, p);
 
     const int sect_dims = p.sections[0] + p.sections[1];
     const int sec_w = p.sections[1] + p.sections[0];
@@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) {
     float theta_base = 0.0;
     if (sector < p.sections[0]) {
         const uint p0 = sector;
-        theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0);
+        theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0);
     }
     else if (sector >= p.sections[0] && sector < sec_w) {
         const uint p0 = sector - p.sections[0];
-        theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0);
+        theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0);
     }
 
     const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
index f7587468..1528fbee 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
@@ -5,10 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
-    if (i1 >= pc.nrows) {
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
         return;
     }
-    rope_multi(i0, i1, pc);
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_multi(i0, i1, i2, i3, pc);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
index acb8ed78..ad089609 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
@@ -5,10 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
-    if (i1 >= pc.nrows) {
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
         return;
     }
-    rope_neox(i0, i1, pc);
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_neox(i0, i1, i2, i3, pc);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
index 0033cdb2..11220817 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
@@ -5,10 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
-    if (i1 >= pc.nrows) {
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
         return;
     }
-    rope_norm(i0, i1, pc);
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_norm(i0, i1, i2, i3, pc);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
index 939cf3c5..ec6ceaca 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
@@ -5,24 +5,29 @@
 
 struct rope_params {
     uint rope_mode;
-    uint ncols;
     uint nrows;
     uint n_dims;
     float freq_scale;
-    uint p_delta_rows;
     float freq_base;
     float ext_factor;
     float attn_factor;
     float corr_dims[2];
     float theta_scale;
     uint has_ff;
-    uint ne02;
-    uint nb01;
-    uint nb02;
     int sections[4];
     uint is_imrope;
     uint is_back;
     uint set_rows_stride;
+
+    uint ne00;
+    uint ne01;
+    uint ne02;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    uint nb11;
+    uint nb12;
+    uint nb13;
 };
 
 #endif // !defined(GGML_ROPE_PARAMS)
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
index d93800b5..ca71efb2 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
@@ -5,10 +5,13 @@
 
 void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;
-    // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
-    if (i1 >= pc.nrows) {
+    const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (row >= pc.nrows) {
         return;
     }
-    rope_vision(i0, i1, pc);
+    const uint i3 = row / (pc.ne01*pc.ne02);
+    const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01;
+    const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01);
+
+    rope_vision(i0, i1, i2, i3, pc);
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp
new file mode 100644
index 00000000..a9c147bf
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp
@@ -0,0 +1,21 @@
+#version 450
+
+#include "generic_head.glsl"
+#include "types.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    data_d[i] = D_TYPE(sign(float(data_a[i])));
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
index d62696bc..6802b1fc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
@@ -5,8 +5,9 @@
 #include "types.glsl"
 
 layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+layout(constant_id = 1) const uint TOKENS_PER_WG = 16;
 
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;
 
 layout(binding = 0) readonly buffer Src0 { float src0[]; };
 layout(binding = 1) readonly buffer Src1 { float src1[]; };
@@ -20,25 +21,30 @@ layout(push_constant) uniform PushConstants {
 };
 
 void main() {
-    const uint global_thread_id = gl_GlobalInvocationID.x;
-    const uint i2 = gl_WorkGroupID.y;
+    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y;
     const uint i3 = gl_WorkGroupID.z;
 
-    if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
+    if (i1 >= nr || i2 >= n_t || i3 >= n_s) {
         return;
     }
 
-    const uint i1 = global_thread_id;
     const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
     const uint src1_base = i1 * (nb11 / 4);
-    const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
 
     float sum = 0.0;
-    [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
-        const uint src0_idx = src0_base + i0;
-        const uint src1_idx = src1_base + i0;
-        sum += src0[src0_idx] * src1[src1_idx];
+
+    if (nc == 4) {
+        sum = dot(
+            vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]),
+            vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3])
+        );
+    } else {
+        [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
+            sum += src0[src0_base + i0] * src1[src1_base + i0];
+        }
     }
 
+    const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
     dst[dst_idx] = sum;
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index bbdbf9dc..4b00ba3d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -330,7 +330,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
         std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path};
     #endif
 
-    // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
+    // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734
     // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
     // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
     if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
@@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 }
 
 void process_shaders() {
-    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
-
     // matmul
     for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
         // No coopmats
@@ -622,49 +620,63 @@ void process_shaders() {
         }
     }
 
-    // flash attention
-    for (const auto& f16acc : {false, true}) {
-        std::map fa_base_dict = base_dict;
-        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
-        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
-        if (f16acc) {
-            fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
+    for (const bool& fp16 : {false, true}) {
+        std::map<std::string, std::string> base_dict;
+        if (fp16) {
+            base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
+        } else {
+            base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
         }
 
-        for (const auto& tname : type_names) {
-            if (tname == "bf16") continue;
-
-#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
-            } else {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
+        // flash attention
+        for (const bool& f16acc : {false, true}) {
+            std::map<std::string, std::string> fa_base_dict = base_dict;
+            fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
+            fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
+            if (fp16 && f16acc) {
+                fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
             }
+
+            for (const auto& tname : type_names) {
+                if (tname == "bf16") continue;
+
+                if (fp16) {
+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
+                } else {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
+                }
 #endif
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-            }
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
+                } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
+                }
 #endif
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
-            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+                }
+
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
+                } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
+                }
             }
         }
     }
 
+    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
+
     for (const auto& tname : type_names) {
         // mul mat vec
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
@@ -790,6 +802,8 @@ void process_shaders() {
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
     string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
 
+    string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {});
+
     string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
     string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
 
@@ -853,8 +867,12 @@ void process_shaders() {
     string_to_spv("hardswish_f32",  "hardswish.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("abs_f16",        "abs.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("abs_f32",        "abs.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("elu_f16",        "elu.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("elu_f32",        "elu.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("xielu_f16",      "xielu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("xielu_f32",      "xielu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("sgn_f16",        "sgn.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("sgn_f32",        "sgn.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
     string_to_spv("tri_f16",        "tri.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("tri_f32",        "tri.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
@@ -969,6 +987,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
+
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
index 7fdb4c8c..3d7e59fd 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
@@ -1,20 +1,273 @@
 #ifndef GGML_WEBGPU_SHADER_LIB_HPP
 #define GGML_WEBGPU_SHADER_LIB_HPP
 
+#include "ggml-wgsl-shaders.hpp"
 #include "ggml.h"
 #include "pre_wgsl.hpp"
 
+#include 
+
+#include 
+#include 
 #include 
+#include 
 #include 
 
 #define GGML_WEBGPU_F16_SIZE_BYTES                   2
 #define GGML_WEBGPU_F32_SIZE_BYTES                   4
+#define GGML_WEBGPU_I32_SIZE_BYTES                   4
 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE     128u
 // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
 #define GGML_WEBGPU_KV_SEQ_PAD                       256u
 
-struct ggml_webgpu_flash_attn_shader_lib_context {
+#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
+
+// Matrix multiplication parameters
+
+// Register tiling parameters
+#define WEBGPU_MUL_MAT_TILE_M    8
+#define WEBGPU_MUL_MAT_TILE_N    8
+#define WEBGPU_MUL_MAT_WG_SIZE_M 8
+#define WEBGPU_MUL_MAT_WG_SIZE_N 8
+#define WEBGPU_MUL_MAT_TILE_K    32
+
+// Subgroup matrix parameters
+// The number of subgroups in the M dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_M        2
+// The number of subgroups in the N dimension
+#define WEBGPU_MUL_MAT_SUBGROUP_N        2
+// The number of subgroup matrices each subgroup accumulates over
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
+#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
+
+// Matrix-vector multiplication parameters
+#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
+
+// Must be multiple of 4 to work with vectorized paths, and must divide
+// mul_mat_vec wg size
+#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K         256
+
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64
+#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K         256
+
+// Requires 32 threads per output (wg_size/outputs_per_wg == 32)
+#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8
+// Requires at least two (and multiple of 2) k-quant blocks per tile
+#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K         512
+
+// default size for legacy matrix multiplication
+#define WEBGPU_MUL_MAT_WG_SIZE 256
+
+// Same hash combine function as in boost
+template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
+    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
+struct ggml_webgpu_shader_lib_context {
+    ggml_tensor * src0;
+    ggml_tensor * src1;
+    ggml_tensor * src2;
+    ggml_tensor * src3;
+    ggml_tensor * src4;
+    ggml_tensor * dst;
+
+    uint32_t max_wg_size;
+    size_t   wg_mem_limit_bytes       = 0;
+    bool     inplace                  = false;
+    bool     overlap                  = false;
+    bool     src_overlap              = false;
+    bool     supports_subgroup_matrix = false;
+    uint32_t sg_mat_m                 = 0;
+    uint32_t sg_mat_n                 = 0;
+    uint32_t sg_mat_k                 = 0;
+    uint32_t max_subgroup_size        = 0;
+};
+
+struct webgpu_pipeline {
+    wgpu::ComputePipeline pipeline;
+    std::string           name;
+    std::shared_ptr<void> context = nullptr;
+};
+
+struct ggml_webgpu_generic_shader_decisions {
+    uint32_t wg_size = 0;
+};
+
+/** Argsort **/
+
+struct ggml_webgpu_argsort_shader_lib_context {
+    uint32_t max_wg_size;
+    size_t   wg_mem_limit_bytes;
+    int32_t  order;
+};
+
+/** Set Rows **/
+
+struct ggml_webgpu_set_rows_pipeline_key {
+    int dst_type;
+    int vec4;
+    int i64_idx;
+
+    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
+        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
+    }
+};
+
+struct ggml_webgpu_set_rows_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.dst_type);
+        ggml_webgpu_hash_combine(seed, key.vec4);
+        ggml_webgpu_hash_combine(seed, key.i64_idx);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_set_rows_shader_decisions {
+    bool     vec4;
+    bool     i64_idx;
+    uint32_t wg_size;
+};
+
+/** Get Rows **/
+
+struct ggml_webgpu_get_rows_pipeline_key {
+    ggml_type src_type;
+    int       vectorized;
+
+    bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const {
+        return src_type == other.src_type && vectorized == other.vectorized;
+    }
+};
+
+struct ggml_webgpu_get_rows_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src_type);
+        ggml_webgpu_hash_combine(seed, key.vectorized);
+        return seed;
+    }
+};
+
+/** Pad **/
+struct ggml_webgpu_pad_pipeline_key {
+    bool circular;
+
+    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
+};
+
+struct ggml_webgpu_pad_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.circular);
+        return seed;
+    }
+};
+
+/** Scale **/
+
+struct ggml_webgpu_scale_pipeline_key {
+    int inplace;
+
+    bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; }
+};
+
+struct ggml_webgpu_scale_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
+/** Concat **/
+
+struct ggml_webgpu_concat_pipeline_key {
+    int type;
+
+    bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
+};
+
+struct ggml_webgpu_concat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        return seed;
+    }
+};
+
+/** Repeat **/
+
+struct ggml_webgpu_repeat_pipeline_key {
+    int type;
+
+    bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; }
+};
+
+struct ggml_webgpu_repeat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        return seed;
+    }
+};
+
+/** Binary **/
+
+struct ggml_webgpu_binary_pipeline_key {
+    int  type;
+    int  op;
+    bool inplace;
+    bool overlap;
+    bool src_overlap;
+
+    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
+        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap &&
+               src_overlap == other.src_overlap;
+    }
+};
+
+struct ggml_webgpu_binary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        ggml_webgpu_hash_combine(seed, key.overlap);
+        ggml_webgpu_hash_combine(seed, key.src_overlap);
+        return seed;
+    }
+};
+
+/** Unary **/
+
+struct ggml_webgpu_unary_pipeline_key {
+    int  type;
+    int  op;
+    bool is_unary;  // many unary operators fall under the GGML_OP_UNARY umbrella
+    bool inplace;
+
+    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
+        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
+    }
+};
+
+struct ggml_webgpu_unary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.is_unary);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        return seed;
+    }
+};
+
+/** FlashAttention **/
+
+struct ggml_webgpu_flash_attn_pipeline_key {
     ggml_type kv_type;
     uint32_t  head_dim_qk;
     uint32_t  head_dim_v;
@@ -22,11 +275,35 @@ struct ggml_webgpu_flash_attn_shader_lib_context {
     bool      has_mask;
     bool      has_sinks;
     bool      uses_logit_softcap;
-    uint32_t  sg_mat_m;
-    uint32_t  sg_mat_n;
-    uint32_t  sg_mat_k;
-    size_t    wg_mem_limit_bytes;
-    uint32_t  max_subgroup_size;
+
+    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
+        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
+               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
+               uses_logit_softcap == other.uses_logit_softcap;
+    }
+};
+
+struct ggml_webgpu_flash_attn_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.kv_type);
+        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
+        ggml_webgpu_hash_combine(seed, key.head_dim_v);
+        ggml_webgpu_hash_combine(seed, key.kv_direct);
+        ggml_webgpu_hash_combine(seed, key.has_mask);
+        ggml_webgpu_hash_combine(seed, key.has_sinks);
+        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_flash_attn_shader_lib_context {
+    ggml_webgpu_flash_attn_pipeline_key key;
+    uint32_t                            sg_mat_m;
+    uint32_t                            sg_mat_n;
+    uint32_t                            sg_mat_k;
+    size_t                              wg_mem_limit_bytes;
+    uint32_t                            max_subgroup_size;
 };
 
 struct ggml_webgpu_flash_attn_shader_decisions {
@@ -35,12 +312,6 @@ struct ggml_webgpu_flash_attn_shader_decisions {
     uint32_t wg_size = 0;
 };
 
-struct ggml_webgpu_processed_shader {
-    std::string                             wgsl;
-    std::string                             variant;
-    ggml_webgpu_flash_attn_shader_decisions decisions;
-};
-
 // This is exposed because it's necessary in supports_op
 inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                   uint32_t kv_tile,
@@ -65,105 +336,1039 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
     return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
 }
 
-static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
-    const size_t limit_bytes  = context.wg_mem_limit_bytes;
-    const size_t q_tile       = context.sg_mat_m;
-    const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
-                                2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
-    size_t bytes_per_kv = 0;
-    if (!context.kv_direct) {
-        bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
+/** Matrix Multiplication **/
+
+struct ggml_webgpu_legacy_mul_mat_pipeline_key {
+    ggml_type src0_type;
+    ggml_type src1_type;
+
+    bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const {
+        return src0_type == other.src0_type && src1_type == other.src1_type;
     }
-    if (context.has_mask) {
-        bytes_per_kv += q_tile;
+};
+
+struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src0_type);
+        ggml_webgpu_hash_combine(seed, key.src1_type);
+        return seed;
     }
-    bytes_per_kv += q_tile;
-    bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
-    const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
-    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
-}
+};
 
-inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
-    pre_wgsl::Preprocessor &                          preprocessor,
-    const char *                                      shader_src,
-    const ggml_webgpu_flash_attn_shader_lib_context & context) {
-    std::vector<std::string> defines;
-    std::string              variant = "flash_attn";
+struct ggml_webgpu_mul_mat_vec_pipeline_key {
+    ggml_type src0_type;
+    ggml_type src1_type;
+    int       vectorized;
 
-    switch (context.kv_type) {
-        case GGML_TYPE_F32:
-            defines.push_back("KV_F32");
-            break;
-        case GGML_TYPE_F16:
-            defines.push_back("KV_F16");
-            break;
-        case GGML_TYPE_Q4_0:
-            defines.push_back("KV_Q4_0");
-            break;
-        case GGML_TYPE_Q8_0:
-            defines.push_back("KV_Q8_0");
-            break;
-        default:
-            GGML_ABORT("Unsupported KV type for flash attention shader");
+    bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const {
+        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized;
     }
-    variant += std::string("_") + ggml_type_name(context.kv_type);
+};
 
-    if (context.has_mask) {
-        defines.push_back("MASK");
-        variant += "_mask";
+struct ggml_webgpu_mul_mat_vec_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src0_type);
+        ggml_webgpu_hash_combine(seed, key.src1_type);
+        ggml_webgpu_hash_combine(seed, key.vectorized);
+        return seed;
     }
-    if (context.has_sinks) {
-        defines.push_back("SINKS");
-        variant += "_sinks";
+};
+
+struct ggml_webgpu_mul_mat_vec_shader_decisions {
+    uint32_t wg_size;
+    uint32_t tile_k;
+    uint32_t outputs_per_wg;
+    uint32_t vec_size;
+};
+
+struct ggml_webgpu_mul_mat_pipeline_key {
+    ggml_type src0_type;
+    ggml_type src1_type;
+    int       vectorized;
+    int       use_subgroup_matrix;
+
+    bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const {
+        return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized &&
+               use_subgroup_matrix == other.use_subgroup_matrix;
     }
-    if (context.uses_logit_softcap) {
-        defines.push_back("LOGIT_SOFTCAP");
-        variant += "_lgsc";
+};
+
+struct ggml_webgpu_mul_mat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.src0_type);
+        ggml_webgpu_hash_combine(seed, key.src1_type);
+        ggml_webgpu_hash_combine(seed, key.vectorized);
+        ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix);
+        return seed;
     }
+};
 
-    if (context.kv_direct) {
-        defines.push_back("KV_DIRECT");
-        variant += "_kvdirect";
-    }
+struct ggml_webgpu_mul_mat_shader_decisions {
+    uint32_t tile_k;
+    uint32_t wg_size_m;
+    uint32_t wg_size_n;
+    uint32_t wg_size;
+    uint32_t outputs_per_wg;
+    int      use_subgroup_matrix;
 
-    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
-    variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
+    uint32_t tile_m;
+    uint32_t tile_n;
 
-    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
-    variant += std::string("_hsv") + std::to_string(context.head_dim_v);
+    // Subgroup matrix parameters
+    uint32_t subgroup_m;
+    uint32_t subgroup_n;
+    uint32_t subgroup_matrix_m;
+    uint32_t subgroup_matrix_n;
 
-    // For now these are not part of the variant name
-    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
-    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
-    defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+    uint32_t mul_mat_wg_size;
+};
 
-    // Add chosen Q/KV tile sizes
-    uint32_t q_tile  = context.sg_mat_m;
-    uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
-                                context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
-    if (context.kv_direct) {
-        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
-        // Avoids having to use bounds-checks and decreasing performance for direct KV loads
-        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
-            kv_tile -= context.sg_mat_n;
+class ggml_webgpu_shader_lib {
+    wgpu::Device           device;
+    pre_wgsl::Preprocessor preprocessor;
+
+    std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines;       // key is fixed, no variants yet
+    std::unordered_map<int, webgpu_pipeline> argmax_pipelines;         // key is vec4
+    std::unordered_map<int, webgpu_pipeline> argsort_pipelines;        // key is order
+    std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines;  // key is order
+    std::unordered_map<int, webgpu_pipeline> cumsum_pipelines;         // key is fixed, no variants yet
+    std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash>
+        get_rows_pipelines;                                            // src_type, vectorized
+    std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash>
+        unary_pipelines;                                               // type/op/inplace
+    std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
+        scale_pipelines;                                               // inplace
+    std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
+        pad_pipelines;                                                 // circular/non-circular
+    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
+        binary_pipelines;                                              // type/op/inplace/overlap
+    std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
+        concat_pipelines;                                              // type
+    std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
+        repeat_pipelines;                                              // type
+    std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
+        flash_attn_pipelines;
+    std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
+        mul_mat_legacy_pipelines;  // legacy mul_mat (non-subgroup/non-regtile/non-vec)
+    std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash>
+        mul_mat_vec_pipelines;     // fast mat-vec (n==1)
+    std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash>
+        mul_mat_fast_pipelines;    // fast mat-mat (reg-tile or subgroup)
+
+    std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
+        set_rows_pipelines;
+
+  public:
+    ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
+
+    webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        auto it = sum_rows_pipelines.find(1);
+        if (it != sum_rows_pipelines.end()) {
+            return it->second;
         }
+        std::vector<std::string> defines;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed        = preprocessor.preprocess(wgsl_sum_rows, defines);
+        sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows");
+        return sum_rows_pipelines[1];
     }
 
-    defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
-    defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
+    webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        bool vec4 = context.src0->ne[0] % 4 == 0;
 
-    // workgroup size
-    uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        auto it = argmax_pipelines.find(vec4);
+        if (it != argmax_pipelines.end()) {
+            return it->second;
+        }
+        std::string              variant = "argmax";
+        std::vector<std::string> defines;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+        if (vec4) {
+            defines.push_back("VEC4");
+            variant += "_vec4";
+        }
 
-    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        auto processed         = preprocessor.preprocess(wgsl_argmax, defines);
+        argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant);
+        return argmax_pipelines.at(vec4);
+    }
 
-    ggml_webgpu_processed_shader result;
-    result.wgsl              = preprocessor.preprocess(shader_src, defines);
-    result.variant           = variant;
-    result.decisions.q_tile  = q_tile;
-    result.decisions.kv_tile = kv_tile;
-    result.decisions.wg_size = wg_size;
-    return result;
-}
+    webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type,
+                                                  .vec4     = context.src0->ne[0] % 4 == 0,
+                                                  .i64_idx  = context.src1->type == GGML_TYPE_I64 };
+
+        auto it = set_rows_pipelines.find(key);
+        if (it != set_rows_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "set_rows";
+
+        switch (context.dst->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("DST_F32");
+                variant += "_dstf32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("DST_F16");
+                variant += "_dstf16";
+                break;
+            default:
+                GGML_ABORT("Unsupported dst type for set_rows shader");
+        }
+
+        if (key.vec4) {
+            defines.push_back("VEC4");
+            variant += "_vec4";
+        }
+        if (key.i64_idx) {
+            defines.push_back("I64_IDX");
+            variant += "_i64idx";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed                  = preprocessor.preprocess(wgsl_set_rows, defines);
+        auto decisions                  = std::make_shared<ggml_webgpu_set_rows_shader_decisions>();
+        decisions->vec4                 = key.vec4;
+        decisions->i64_idx              = key.i64_idx;
+        decisions->wg_size              = context.max_wg_size;
+        set_rows_pipelines[key]         = ggml_webgpu_create_pipeline(device, processed, variant);
+        set_rows_pipelines[key].context = decisions;
+        return set_rows_pipelines[key];
+    }
+
+    webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        auto it = cumsum_pipelines.find(1);
+        if (it != cumsum_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed      = preprocessor.preprocess(wgsl_cumsum, defines);
+        cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum");
+        return cumsum_pipelines[1];
+    }
+
+    webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        bool          is_top_k = context.dst->op == GGML_OP_TOP_K;
+        // ascending order is 0, descending order is 1
+        const int32_t order =
+            is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
+
+        auto it = argsort_pipelines.find(order);
+        if (it != argsort_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "argsort";
+        defines.push_back(std::string("ORDER=") + std::to_string(order));
+        variant += std::string("_order") + std::to_string(order);
+        uint32_t wg_size = 1;
+        while (wg_size * 2 <= context.max_wg_size &&
+               wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
+            wg_size *= 2;
+        }
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        auto processed                   = preprocessor.preprocess(wgsl_argsort, defines);
+        auto decisions                   = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size               = wg_size;
+        argsort_pipelines[order]         = ggml_webgpu_create_pipeline(device, processed, variant);
+        argsort_pipelines[order].context = decisions;
+        return argsort_pipelines[order];
+    }
+
+    webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        bool          is_top_k = context.dst->op == GGML_OP_TOP_K;
+        // ascending order is 0, descending order is 1
+        const int32_t order =
+            is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0);
+
+        auto it = argsort_merge_pipelines.find(order);
+        if (it != argsort_merge_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "argsort_merge";
+        defines.push_back(std::string("ORDER=") + std::to_string(order));
+        variant += std::string("_order") + std::to_string(order);
+        uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+
+        auto processed                 = preprocessor.preprocess(wgsl_argsort_merge, defines);
+        argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant);
+        return argsort_merge_pipelines[order];
+    }
+
+    webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const bool vectorized                 = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0;
+        ggml_webgpu_get_rows_pipeline_key key = {
+            .src_type   = context.src0->type,
+            .vectorized = (int) vectorized,
+        };
+
+        auto it = get_rows_pipelines.find(key);
+        if (it != get_rows_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "get_rows";
+
+        const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type);
+        const char *                    type_str    = type_traits->type_name;
+
+        switch (key.src_type) {
+            case GGML_TYPE_F32:
+                if (key.vectorized) {
+                    defines.push_back("F32_VEC");
+                    defines.push_back("SRC_TYPE=vec4");
+                    defines.push_back("DST_TYPE=vec4");
+                    defines.push_back("BLOCK_SIZE=4u");
+                } else {
+                    defines.push_back("F32");
+                    defines.push_back("SRC_TYPE=f32");
+                    defines.push_back("DST_TYPE=f32");
+                    defines.push_back("BLOCK_SIZE=1u");
+                }
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("F16");
+                defines.push_back("SRC_TYPE=f16");
+                defines.push_back("DST_TYPE=f32");
+                defines.push_back("BLOCK_SIZE=1u");
+                variant += "_f16";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("I32");
+                defines.push_back("SRC_TYPE=i32");
+                defines.push_back("DST_TYPE=i32");
+                defines.push_back("BLOCK_SIZE=1u");
+                variant += "_i32";
+                break;
+            default:
+                {
+                    std::string type_upper = type_str;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back(type_upper + "_T");
+                    defines.push_back(type_upper);
+                    defines.push_back(type_upper + "_SCALE_MIN");
+                    defines.push_back(type_upper + "_TABLES");
+                    defines.push_back(type_upper + "_GRID");
+
+                    variant += "_";
+                    variant += type_str;
+
+                    defines.push_back(std::string("SRC_TYPE=") + type_str);
+                    defines.push_back("DST_TYPE=f32");
+
+                    if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
+                        key.src_type == GGML_TYPE_IQ4_NL) {
+                        defines.push_back("BLOCK_SIZE=32u");
+                    } else if (key.src_type >= GGML_TYPE_Q2_K) {
+                        defines.push_back("BLOCK_SIZE=256u");
+                    } else {
+                        defines.push_back("BLOCK_SIZE=1u");
+                    }
+                    break;
+                }
+        }
+
+        if (key.vectorized) {
+            variant += "_vec";
+        }
+
+        defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_get_rows, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        get_rows_pipelines[key]  = pipeline;
+        return get_rows_pipelines[key];
+    }
+
+    webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace };
+
+        auto it = scale_pipelines.find(key);
+        if (it != scale_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "scale";
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_scale, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        scale_pipelines[key]     = pipeline;
+        return scale_pipelines[key];
+    }
+
+    webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
+
+        auto it = pad_pipelines.find(key);
+        if (it != pad_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "pad";
+
+        if (key.circular) {
+            defines.push_back("CIRCULAR");
+            variant += "_circular";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_pad, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        pad_pipelines[key]       = pipeline;
+        return pad_pipelines[key];
+    }
+
+    webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_mul_mat_vec_pipeline_key key = {
+            .src0_type  = context.src0->type,
+            .src1_type  = context.src1->type,
+            // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float
+            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 &&
+                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
+                              1 :
+                              0,
+        };
+
+        auto it = mul_mat_vec_pipelines.find(key);
+        if (it != mul_mat_vec_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "mul_mat_vec";
+
+        // src0 type (matrix row)
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC0_INNER_TYPE=f32");
+                defines.push_back("MUL_ACC_FLOAT");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC0_INNER_TYPE=f16");
+                defines.push_back("MUL_ACC_FLOAT");
+                variant += "_f16";
+                break;
+            default:
+                {
+                    // Quantized types: use helpers but accumulate in f16
+                    const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+                    std::string                     src0_name   = src0_traits->type_name;
+                    std::string                     type_upper  = src0_name;
+                    variant += "_" + src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back("MUL_ACC_" + type_upper);
+
+                    // For the fast path we always dequantize to f16 inside the shader
+                    defines.push_back("SRC0_INNER_TYPE=f16");
+                    break;
+                }
+        }
+
+        // src1 type (vector)
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC1_INNER_TYPE=f32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC1_INNER_TYPE=f16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat_vec shader");
+        }
+
+        // VEC/SCALAR controls
+        defines.push_back(key.vectorized ? "VEC" : "SCALAR");
+
+        uint32_t wg_size        = WEBGPU_MUL_MAT_VEC_WG_SIZE;
+        uint32_t tile_k         = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K;
+        uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
+
+        if (key.src0_type >= GGML_TYPE_Q2_K) {
+            tile_k         = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K;
+            outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
+        } else if (key.src0_type >= GGML_TYPE_Q4_0) {
+            tile_k         = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K;
+            outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+        defines.push_back(std::string("TILE_K=") + std::to_string(tile_k));
+        defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg));
+
+        auto processed            = preprocessor.preprocess(wgsl_mul_mat_vec, defines);
+        auto decisions            = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>();
+        decisions->wg_size        = wg_size;
+        decisions->tile_k         = tile_k;
+        decisions->outputs_per_wg = outputs_per_wg;
+        decisions->vec_size       = key.vectorized ? 4 : 1;
+
+        webgpu_pipeline pipeline   = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context           = decisions;
+        mul_mat_vec_pipelines[key] = pipeline;
+        return mul_mat_vec_pipelines[key];
+    }
+
+    webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_mul_mat_pipeline_key key = {
+            .src0_type  = context.src0->type,
+            .src1_type  = context.src1->type,
+            .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 &&
+                           (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ?
+                              1 :
+                              0,
+            .use_subgroup_matrix = context.supports_subgroup_matrix
+        };
+
+        auto it = mul_mat_fast_pipelines.find(key);
+        if (it != mul_mat_fast_pipelines.end()) {
+            return it->second;
+        }
+
+        const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile;
+        std::vector<std::string> defines;
+        std::string              variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile";
+
+        // src1 type
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC1_INNER_TYPE=f32");
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC1_INNER_TYPE=f16");
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat fast shader");
+        }
+
+        // src0 type
+        const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+        const char *                    src0_name   = src0_traits->type_name;
+
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC0_INNER_TYPE=f32");
+                defines.push_back("FLOAT");
+                defines.push_back("MUL_ACC_FLOAT");
+                defines.push_back("INIT_SRC0_SHMEM_FLOAT");
+                defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC0_INNER_TYPE=f16");
+                defines.push_back("FLOAT");
+                defines.push_back("MUL_ACC_FLOAT");
+                defines.push_back("INIT_SRC0_SHMEM_FLOAT");
+                defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+                variant += "_f16";
+                break;
+            default:
+                {
+                    std::string type_upper = src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back("MUL_ACC_" + type_upper);
+                    defines.push_back("INIT_SRC0_SHMEM_" + type_upper);
+                    defines.push_back("INIT_SRC1_SHMEM_FLOAT");
+
+                    // Use f16 inside the shader for quantized types
+                    defines.push_back("SRC0_INNER_TYPE=f16");
+
+                    variant += std::string("_") + src0_name;
+                    break;
+                }
+        }
+
+        // VEC/SCALAR controls
+        defines.push_back(key.vectorized ? "VEC" : "SCALAR");
+
+        // Tiles
+        defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u");
+        defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u");
+        defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u");
+
+        // Subgroup matrix specifics
+        if (key.use_subgroup_matrix) {
+            defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u");
+            defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u");
+            defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u");
+            defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u");
+            defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u");
+            defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u");
+            defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u");
+            defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u");
+        }
+
+        // variant suffix for src1 type
+        variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16");
+        if (key.vectorized) {
+            variant += "_vectorized";
+        }
+
+        if (!key.use_subgroup_matrix) {
+            defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u");
+            defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u");
+        }
+
+        auto processed = preprocessor.preprocess(shader_src, defines);
+
+        auto decisions                 = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>();
+        decisions->tile_k              = WEBGPU_MUL_MAT_TILE_K;
+        decisions->tile_m              = WEBGPU_MUL_MAT_TILE_M;
+        decisions->tile_n              = WEBGPU_MUL_MAT_TILE_N;
+        decisions->use_subgroup_matrix = key.use_subgroup_matrix;
+        if (key.use_subgroup_matrix) {
+            decisions->subgroup_m        = WEBGPU_MUL_MAT_SUBGROUP_M;
+            decisions->subgroup_n        = WEBGPU_MUL_MAT_SUBGROUP_N;
+            decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M;
+            decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N;
+            decisions->wg_size           = context.max_subgroup_size;
+        } else {
+            decisions->wg_size_m       = WEBGPU_MUL_MAT_WG_SIZE_M;
+            decisions->wg_size_n       = WEBGPU_MUL_MAT_WG_SIZE_N;
+            decisions->wg_size         = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N;
+            decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE;
+        }
+
+        webgpu_pipeline pipeline    = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context            = decisions;
+        mul_mat_fast_pipelines[key] = pipeline;
+        return mul_mat_fast_pipelines[key];
+    }
+
+    webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type,
+                                                        .src1_type = context.src1->type };
+
+        auto it = mul_mat_legacy_pipelines.find(key);
+        if (it != mul_mat_legacy_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "mul_mat";
+
+        switch (context.src1->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC1_TYPE=f32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC1_TYPE=f16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported src1 type for mul_mat legacy shader");
+        }
+
+        const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type);
+        const char *                    src0_name   = src0_traits->type_name;
+
+        switch (context.src0->type) {
+            case GGML_TYPE_F32:
+                defines.push_back("SRC0_TYPE=f32");
+                defines.push_back("FLOAT");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("SRC0_TYPE=f16");
+                defines.push_back("FLOAT");
+                variant += "_f16";
+                break;
+            default:
+                {
+                    // quantized types
+                    std::string type_upper = src0_name;
+                    std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
+
+                    defines.push_back(std::string("SRC0_TYPE=") + src0_name);
+                    defines.push_back("BYTE_HELPERS");
+                    defines.push_back(type_upper + "_T");
+                    defines.push_back(type_upper);
+                    defines.push_back(type_upper + "_SCALE_MIN");
+                    defines.push_back(type_upper + "_TABLES");
+                    defines.push_back(type_upper + "_GRID");
+
+                    variant += std::string("_") + src0_name;
+                    break;
+                }
+        }
+
+        auto processed = preprocessor.preprocess(wgsl_mul_mat, defines);
+
+        auto decisions     = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE;
+
+        webgpu_pipeline pipeline      = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context              = decisions;
+        mul_mat_legacy_pipelines[key] = pipeline;
+        return mul_mat_legacy_pipelines[key];
+    }
+
+    webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const bool                     is_unary = context.dst->op == GGML_OP_UNARY;
+        const int                      op       = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op;
+        ggml_webgpu_unary_pipeline_key key      = {
+                 .type     = context.dst->type,
+                 .op       = op,
+                 .is_unary = is_unary,
+                 .inplace  = context.inplace,
+        };
+
+        auto it = unary_pipelines.find(key);
+        if (it != unary_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant =
+            key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op);
+        defines.push_back(variant);
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("TYPE_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for unary shader");
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_unary, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        unary_pipelines[key]     = pipeline;
+        return unary_pipelines[key];
+    }
+
+    webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_binary_pipeline_key key = {
+            .type        = context.dst->type,
+            .op          = context.dst->op,
+            .inplace     = context.inplace,
+            .overlap     = context.overlap,
+            .src_overlap = context.src_overlap,
+        };
+
+        auto it = binary_pipelines.find(key);
+        if (it != binary_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              op_name = ggml_op_name((ggml_op) key.op);
+        std::string              variant = op_name;
+
+        defines.push_back(std::string("OP_") + op_name);
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("TYPE_F16");
+                variant += "_f16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for binary shader");
+        }
+
+        if (key.inplace) {
+            defines.push_back("INPLACE");
+            variant += "_inplace";
+        } else if (key.overlap) {
+            defines.push_back("OVERLAP");
+            variant += "_overlap";
+        } else if (key.src_overlap) {
+            defines.push_back("SRC_OVERLAP");
+            variant += "_src_overlap";
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_binary, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        binary_pipelines[key]    = pipeline;
+        return binary_pipelines[key];
+    }
+
+    webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_concat_pipeline_key key = {
+            .type = context.dst->type,
+        };
+
+        auto it = concat_pipelines.find(key);
+        if (it != concat_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "concat";
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("TYPE_I32");
+                variant += "_i32";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for concat shader");
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_concat, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        concat_pipelines[key]    = pipeline;
+        return concat_pipelines[key];
+    }
+
+    webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_repeat_pipeline_key key = {
+            .type = context.dst->type,
+        };
+
+        auto it = repeat_pipelines.find(key);
+        if (it != repeat_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "repeat";
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("TYPE_I32");
+                variant += "_i32";
+                break;
+            case GGML_TYPE_I16:
+                defines.push_back("TYPE_I16");
+                variant += "_i16";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for repeat shader");
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed           = preprocessor.preprocess(wgsl_repeat, defines);
+        auto decisions           = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size       = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context         = decisions;
+        repeat_pipelines[key]    = pipeline;
+        return repeat_pipelines[key];
+    }
+
+    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        const bool has_mask  = context.src3 != nullptr;
+        const bool has_sinks = context.src4 != nullptr;
+
+        bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
+                         (context.src1->ne[1] % context.sg_mat_n == 0);
+
+        ggml_webgpu_flash_attn_pipeline_key key = {
+            .kv_type            = context.src1->type,
+            .head_dim_qk        = (uint32_t) context.src0->ne[0],
+            .head_dim_v         = (uint32_t) context.src2->ne[0],
+            .kv_direct          = kv_direct,
+            .has_mask           = has_mask,
+            .has_sinks          = has_sinks,
+            .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
+        };
+
+        auto it = flash_attn_pipelines.find(key);
+        if (it != flash_attn_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string              variant = "flash_attn";
+
+        switch (key.kv_type) {
+            case GGML_TYPE_F32:
+                defines.push_back("KV_F32");
+                break;
+            case GGML_TYPE_F16:
+                defines.push_back("KV_F16");
+                break;
+            case GGML_TYPE_Q4_0:
+                defines.push_back("KV_Q4_0");
+                break;
+            case GGML_TYPE_Q8_0:
+                defines.push_back("KV_Q8_0");
+                break;
+            default:
+                GGML_ABORT("Unsupported KV type for flash attention shader");
+        }
+        variant += std::string("_") + ggml_type_name(key.kv_type);
+
+        if (key.has_mask) {
+            defines.push_back("MASK");
+            variant += "_mask";
+        }
+        if (key.has_sinks) {
+            defines.push_back("SINKS");
+            variant += "_sinks";
+        }
+        if (key.uses_logit_softcap) {
+            defines.push_back("LOGIT_SOFTCAP");
+            variant += "_lgsc";
+        }
+        if (key.kv_direct) {
+            defines.push_back("KV_DIRECT");
+            variant += "_kvdirect";
+        }
+
+        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
+        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
+
+        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
+        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
+
+        defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
+        defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
+        defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+
+        uint32_t q_tile = context.sg_mat_m;
+        uint32_t kv_tile =
+            std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
+                                                          context.wg_mem_limit_bytes, context.max_subgroup_size }),
+                     context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+        if (key.kv_direct) {
+            while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
+                kv_tile -= context.sg_mat_n;
+            }
+        }
+
+        defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
+        defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
+
+        uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+
+        auto processed     = preprocessor.preprocess(wgsl_flash_attn, defines);
+        auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
+        decisions->q_tile  = q_tile;
+        decisions->kv_tile = kv_tile;
+        decisions->wg_size = wg_size;
+
+        webgpu_pipeline pipeline  = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context          = decisions;
+        flash_attn_pipelines[key] = pipeline;
+        return flash_attn_pipelines[key];
+    }
+
+  private:
+    static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
+                                                       std::string    shader_code,
+                                                       std::string    label) {
+        wgpu::ShaderSourceWGSL shader_source;
+        shader_source.code = shader_code.c_str();
+
+        wgpu::ShaderModuleDescriptor shader_desc;
+        shader_desc.nextInChain = &shader_source;
+
+        wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
+
+        wgpu::ComputePipelineDescriptor pipeline_desc;
+        pipeline_desc.label              = label.c_str();
+        pipeline_desc.compute.module     = shader_module;
+        pipeline_desc.compute.entryPoint = "main";   // Entry point in the WGSL code
+        pipeline_desc.layout             = nullptr;  // nullptr means auto layout
+        return { device.CreateComputePipeline(&pipeline_desc), label };
+    }
+
+    static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
+        const size_t limit_bytes = context.wg_mem_limit_bytes;
+        const size_t q_tile      = context.sg_mat_m;
+        const size_t base_q_bytes =
+            (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+            2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+        size_t bytes_per_kv = 0;
+        if (!context.key.kv_direct) {
+            bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
+        }
+        if (context.key.has_mask) {
+            bytes_per_kv += q_tile;
+        }
+        bytes_per_kv += q_tile;
+        bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
+        const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
+        return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
+    }
+};
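+
+// Worked example for the KV-tile sizing above (hypothetical limits, not taken
+// from any particular adapter): with head_dim_qk = head_dim_v = 128,
+// sg_mat_m = sg_mat_n = 8, a 16 KiB workgroup-memory budget, kv_direct = false
+// and has_mask = true:
+//   base_q_bytes = (128 + 128) * 8 * 2 + 2 * 8 * 4 = 4160
+//   bytes_per_kv = (128 + 8 + 8) * 2               = 288
+//   max_kv_tile  = ((16384 - 4160) / 288), rounded down to a multiple of 8 = 40
+// The pipeline getter then takes the minimum of this value and
+// sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES.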
 
 #endif  // GGML_WEBGPU_SHADER_LIB_HPP
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 5b8f7f72..128b7dc3 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -8,8 +8,6 @@
 #include "ggml-backend-impl.h"
 #include "ggml-impl.h"
 #include "ggml-webgpu-shader-lib.hpp"
-#include "ggml-wgsl-shaders.hpp"
-#include "pre_wgsl.hpp"
 
 #ifdef __EMSCRIPTEN__
 #    include 
@@ -21,16 +19,30 @@
 #include 
 #include 
 #include 
-#include 
+#ifdef GGML_WEBGPU_GPU_PROFILE
+#    include 
+#endif
+#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE)
+#    include 
+#endif
 #include 
+#include 
 #include 
 #include 
 #include 
+#include 
 #include 
 
 #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))
 #define CEIL_DIV(M, N)        (((M) + (N) - 1) / (N))
 
+// Return a rectangular grid of workgroups with minimal over-provisioned workgroups.
+// Assumes that the total number of workgroups does not exceed max_per_dim^2.
+static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) {
+    wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim));
+    wg_x = CEIL_DIV(total_wg, wg_y);
+}
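+// For example (illustrative values): total_wg = 70000 with max_per_dim = 65535
+// gives wg_y = 2 and wg_x = 35000, i.e. exactly 70000 workgroups and no waste:
+//   uint32_t wg_x, wg_y;
+//   compute_2d_workgroups(70000, 65535, wg_x, wg_y);  // wg_x == 35000, wg_y == 2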
+
 #ifdef GGML_WEBGPU_DEBUG
 #    define WEBGPU_LOG_DEBUG(msg)  std::cout << msg << std::endl
 #    define WEBGPU_DEBUG_BUF_ELEMS 512
@@ -47,7 +59,6 @@
         double cpu_total_time_##id =                                                                      \
             std::chrono::duration(cpu_total_end_##id - cpu_total_start_##id).count(); \
         (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;
-
 // fine-grained timing (not included in totals)
 #    define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
 
@@ -64,56 +75,34 @@
 #endif  // GGML_WEBGPU_CPU_PROFILE
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
-#    define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       24
+#    define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       32
 #    define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16  // e.g. enough for two timestamps
 #endif
 
 /* Constants */
 
-// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
-#define WEBGPU_MAX_WG_SIZE 288
-
-#define WEBGPU_MUL_MAT_WG_SIZE               256
-#define WEBGPU_NUM_PARAM_BUFS                32u
-#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     8u
+#define WEBGPU_NUM_PARAM_BUFS                96u
+#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE     32u
 #define WEBGPU_WAIT_ANY_TIMEOUT_MS           0
-// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool
-#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE
+// Maximum number of in-flight submissions per-thread, to avoid exhausting the
+// parameter buffer pool
+#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD  (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)
 #define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
-#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       32
 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
-#define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4
+#define WEBGPU_STORAGE_BUF_BINDING_MULT      4    // a storage buffer binding size must be a multiple of 4
 
-// For operations which process a row in parallel, this seems like a reasonable default
+// For operations which process a row in parallel, this seems like a reasonable
+// default
 #define WEBGPU_ROW_SPLIT_WG_SIZE 64
 
-// Matrix multiplication parameters
-
-// Register tiling parameters
-#define WEBGPU_MUL_MAT_TILE_M    8
-#define WEBGPU_MUL_MAT_TILE_N    8
-#define WEBGPU_MUL_MAT_WG_SIZE_M 8
-#define WEBGPU_MUL_MAT_WG_SIZE_N 8
-#define WEBGPU_MUL_MAT_TILE_K    32
-
-// Subgroup matrix parameters
-// The number of subgroups in the M dimension
-#define WEBGPU_MUL_MAT_SUBGROUP_M        2
-// The number of subgroups in the N dimension
-#define WEBGPU_MUL_MAT_SUBGROUP_N        2
-// The number of subgroup matrices each subgroup accumulates over
-#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
-#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
-
-// Matrix-vector multiplication parameters
-#define WEBGPU_MUL_MAT_VEC_WG_SIZE        256
-// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
-#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
-#define WEBGPU_MUL_MAT_VEC_TILE_K         256
+// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to
+// implementations so this can be removed; it is currently only needed for get_rows
+#define WEBGPU_MAX_WG_SIZE 288
 
 /* End Constants */
 
-// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
+// This is a "fake" base pointer, since WebGPU buffers do not have pointers to
+// their locations.
 static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000;  // NOLINT
 
 // Always returns the base offset of a tensor, regardless of views.
@@ -133,47 +122,70 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
                                       wgpu::BufferUsage usage,
                                       const char *      label);
 
-struct webgpu_pool_bufs {
-    wgpu::Buffer host_buf;
-    wgpu::Buffer dev_buf;
-};
-
-// The futures to wait on for a single queue submission
-struct webgpu_submission_futures {
-    std::vector futures;
-};
-
 // Holds a pool of parameter buffers for WebGPU operations
 struct webgpu_buf_pool {
-    std::vector free;
-
-    std::mutex mutex;
+    std::vector<wgpu::Buffer> free;
 
+    // The pool must be synchronized because
+    // 1. The memset pool is shared globally by every ggml buffer,
+    // since allocating a pool per ggml buffer would consume too much memory.
+    // 2. For the per-thread buffer pools in webgpu_context,
+    // buffers are allocated and freed in Dawn callbacks,
+    // which can run on a different thread than the calling thread.
+    std::mutex              mutex;
     std::condition_variable cv;
+    size_t                  cur_pool_size;
+    size_t                  max_pool_size;
+    wgpu::Device            device;
+    wgpu::BufferUsage       dev_buf_usage;
+    size_t                  buf_size;
+    bool                    should_grow;
 
     void init(wgpu::Device      device,
               int               num_bufs,
               size_t            buf_size,
               wgpu::BufferUsage dev_buf_usage,
-              wgpu::BufferUsage host_buf_usage) {
+              bool              should_grow   = false,
+              size_t            max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) {
+        this->max_pool_size = max_pool_size;
+        this->cur_pool_size = num_bufs;
+        this->device        = device;
+        this->dev_buf_usage = dev_buf_usage;
+        this->buf_size      = buf_size;
+        this->should_grow   = should_grow;
         for (int i = 0; i < num_bufs; i++) {
-            wgpu::Buffer host_buf;
             wgpu::Buffer dev_buf;
-            ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
             ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
-            free.push_back({ host_buf, dev_buf });
+            free.push_back(dev_buf);
         }
     }
 
-    webgpu_pool_bufs alloc_bufs() {
+    wgpu::Buffer alloc_bufs() {
         std::unique_lock lock(mutex);
+        if (!free.empty()) {
+            wgpu::Buffer buf = free.back();
+            free.pop_back();
+            return buf;
+        }
+
+        // Try growing the pool if no free buffers
+        if (free.empty() && cur_pool_size < max_pool_size && should_grow) {
+            cur_pool_size++;
+            wgpu::Buffer dev_buf;
+            ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
+
+            if (!dev_buf) {
+                GGML_ABORT("webgpu_buf_pool: failed to allocate buffers");
+            }
+            return dev_buf;
+        }
         cv.wait(lock, [this] { return !free.empty(); });
-        webgpu_pool_bufs bufs = free.back();
+        wgpu::Buffer buf = free.back();
         free.pop_back();
-        return bufs;
+        return buf;
     }
 
-    void free_bufs(std::vector bufs) {
+    void free_bufs(std::vector<wgpu::Buffer> bufs) {
         std::lock_guard lock(mutex);
         free.insert(free.end(), bufs.begin(), bufs.end());
         cv.notify_all();
@@ -181,12 +193,15 @@ struct webgpu_buf_pool {
 
     void cleanup() {
         std::lock_guard lock(mutex);
-        for (auto & bufs : free) {
-            bufs.host_buf.Destroy();
-            bufs.dev_buf.Destroy();
+        for (auto & buf : free) {
+            if (buf) {
+                buf.Destroy();
+            }
         }
         free.clear();
     }
+
+    ~webgpu_buf_pool() { this->cleanup(); }
 };
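+
+// Minimal usage sketch (hypothetical sizes and usage flags, not the values used
+// by the backend): a growable pool that starts with 4 buffers and may expand to
+// 8 before alloc_bufs() has to block.
+//
+//   webgpu_buf_pool pool;
+//   pool.init(device, /*num_bufs=*/4, WEBGPU_PARAMS_BUF_SIZE_BYTES,
+//             wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopyDst,
+//             /*should_grow=*/true, /*max_pool_size=*/8);
+//   wgpu::Buffer buf = pool.alloc_bufs();  // reuses a free buffer, grows, or waits
+//   pool.free_bufs({ buf });               // returns it and wakes any waiter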
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
@@ -248,121 +263,49 @@ struct webgpu_gpu_profile_buf_pool {
         }
         free.clear();
     }
+
+    ~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
 };
 #endif
 
-struct webgpu_pipeline {
-    wgpu::ComputePipeline pipeline;
-    std::string           name;
-    void *                context = nullptr;
-};
-
 struct webgpu_command {
-    wgpu::CommandBuffer             commands;
-    webgpu_pool_bufs                params_bufs;
-    std::optional set_rows_error_bufs;
+    uint32_t                  num_kernels;
+    wgpu::CommandBuffer       commands;
+    std::vector<wgpu::Buffer> params_bufs;
 #ifdef GGML_WEBGPU_GPU_PROFILE
     webgpu_gpu_profile_bufs timestamp_query_bufs;
     std::string             pipeline_name;
 #endif
 };
 
-struct flash_attn_pipeline_key {
-    int      q_type;
-    int      kv_type;
-    int      dst_type;
-    uint32_t head_dim_qk;
-    uint32_t head_dim_v;
-    bool     kv_direct;
-    bool     has_mask;
-    bool     has_sinks;
-    bool     uses_logit_softcap;
+struct webgpu_capabilities {
+    wgpu::Limits limits;
+    bool         supports_subgroup_matrix = false;
 
-    bool operator==(const flash_attn_pipeline_key & other) const {
-        return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
-               head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
-               has_mask == other.has_mask && has_sinks == other.has_sinks &&
-               uses_logit_softcap == other.uses_logit_softcap;
-    }
+    uint32_t sg_mat_m = 0;
+    uint32_t sg_mat_n = 0;
+    uint32_t sg_mat_k = 0;
+
+    uint32_t subgroup_size     = 0;
+    uint32_t max_subgroup_size = 0;
+    size_t   memset_bytes_per_thread;
 };
 
-// Same hash combine function as in boost
-template  inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
-    seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
-}
-
-struct flash_attn_pipeline_key_hash {
-    size_t operator()(const flash_attn_pipeline_key & key) const {
-        size_t seed = 0;
-        ggml_webgpu_hash_combine(seed, key.q_type);
-        ggml_webgpu_hash_combine(seed, key.kv_type);
-        ggml_webgpu_hash_combine(seed, key.dst_type);
-        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
-        ggml_webgpu_hash_combine(seed, key.head_dim_v);
-        ggml_webgpu_hash_combine(seed, key.kv_direct);
-        ggml_webgpu_hash_combine(seed, key.has_mask);
-        ggml_webgpu_hash_combine(seed, key.has_sinks);
-        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
-        return seed;
-    }
-};
-
-// All the base objects needed to run operations on a WebGPU device
-struct webgpu_context_struct {
+// Stores global webgpu members
+struct webgpu_global_context_struct {
     wgpu::Instance instance;
     wgpu::Adapter  adapter;
     wgpu::Device   device;
     wgpu::Queue    queue;
-    wgpu::Limits   limits;
-
-    uint32_t max_subgroup_size;
-
-    bool     supports_subgroup_matrix = false;
-    uint32_t sg_mat_m;
-    uint32_t sg_mat_n;
-    uint32_t sg_mat_k;
 
+    webgpu_capabilities  capabilities;
+    // Shared buffer to move data from device to host
+    wgpu::Buffer         get_tensor_staging_buf;
+    // Global mutex for the pipeline cache and staging buffer; will be refactored to exclude pipeline caches.
     std::recursive_mutex mutex;
-    std::atomic_uint     inflight_threads = 0;
 
-    webgpu_buf_pool param_buf_pool;
-    webgpu_buf_pool set_rows_error_buf_pool;
-
-    pre_wgsl::Preprocessor p;
-
-    std::map memset_pipelines;                                 // variant or type index
-
-    std::map>> mul_mat_pipelines;  // src0_type, src1_type, vectorized
-    std::map>>
-        mul_mat_vec_pipelines;                                                       // src0_type, src1_type, vectorized
-
-    std::unordered_map flash_attn_pipelines;
-
-    std::map> set_rows_pipelines;                 // dst_type, vectorized
-    std::map> get_rows_pipelines;                 // src_type, vectorized
-
-    std::map> cpy_pipelines;                      // src_type, dst_type
-    std::map> add_pipelines;                      // type, inplace
-    std::map> sub_pipelines;                      // type, inplace
-    std::map> mul_pipelines;                      // type, inplace
-    std::map> div_pipelines;                      // type, inplace
-
-    std::map                               rms_norm_pipelines;  // inplace
-    std::map>> rope_pipelines;      // type, ff, inplace
-    std::map>> glu_pipelines;       // glu_op, type, split
-    std::map                               scale_pipelines;     // inplace
-    std::map>> soft_max_pipelines;  // mask_type, has_sink, inplace
-    std::map>> unary_pipelines;     // unary_op, type, inplace
-
-    size_t memset_bytes_per_thread;
-
-    // Staging buffer for reading data from the GPU
-    wgpu::Buffer get_tensor_staging_buf;
-
-#ifdef GGML_WEBGPU_DEBUG
-    wgpu::Buffer debug_host_buf;
-    wgpu::Buffer debug_dev_buf;
-#endif
+    webgpu_buf_pool                memset_buf_pool;
+    std::map memset_pipelines;  // variant or type index
 
 #ifdef GGML_WEBGPU_CPU_PROFILE
     // Profiling: labeled CPU time in ms (total)
@@ -377,59 +320,98 @@ struct webgpu_context_struct {
     // Profiling: pool of timestamp query buffers (one per operation)
     webgpu_gpu_profile_buf_pool             timestamp_query_buf_pool;
 #endif
+
+#ifdef GGML_WEBGPU_DEBUG
+    wgpu::Buffer debug_host_buf;
+    wgpu::Buffer debug_dev_buf;
+#endif
+
+    ~webgpu_global_context_struct() {
+        if (this->get_tensor_staging_buf) {
+            this->get_tensor_staging_buf.Destroy();
+            this->get_tensor_staging_buf = nullptr;
+        }
+#ifdef GGML_WEBGPU_DEBUG
+        if (this->debug_host_buf) {
+            this->debug_host_buf.Destroy();
+            this->debug_host_buf = nullptr;
+        }
+        if (this->debug_dev_buf) {
+            this->debug_dev_buf.Destroy();
+            this->debug_dev_buf = nullptr;
+        }
+#endif
+    }
+};
+
+typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
+
+struct webgpu_submission {
+    wgpu::FutureWaitInfo submit_done;
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    std::vector<wgpu::FutureWaitInfo> profile_futures;
+#endif
+};
+
+// All the base objects needed to run operations on a WebGPU device
+struct webgpu_context_struct {
+    // Points to global instances owned by ggml_backend_webgpu_reg_context
+    webgpu_global_context global_ctx;
+
+    std::unique_ptr shader_lib;
+
+    webgpu_buf_pool param_buf_pool;
+    wgpu::Buffer    set_rows_dev_error_buf;
+    wgpu::Buffer    set_rows_host_error_buf;
+
+    std::map> cpy_pipelines;                      // src_type, dst_type
+
+    std::map                               rms_norm_pipelines;  // inplace
+    std::map>> rope_pipelines;      // type, ff, inplace
+    std::map>> glu_pipelines;       // glu_op, type, split
+
+    std::map>> soft_max_pipelines;  // mask_type, has_sink, inplace
+
+    size_t memset_bytes_per_thread;
 };
 
 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
 
+// Metadata required for the ggml backend registration/discovery interface
 struct ggml_backend_webgpu_reg_context {
-    webgpu_context webgpu_ctx;
-    size_t         device_count;
-    const char *   name;
+    // Since the Instance is a global entrypoint into the WebGPU API, it lives here
+    webgpu_global_context webgpu_global_ctx;
+    size_t                device_count;
+    const char *          name;
 };
 
+// Per-device struct for the global logical device interface
 struct ggml_backend_webgpu_device_context {
-    webgpu_context webgpu_ctx;
-    std::string    device_name;
-    std::string    device_desc;
+    webgpu_global_context webgpu_global_ctx;
+    std::string           device_name;
+    std::string           device_desc;
 };
 
+// Per-thread data required to actually run WebGPU operations in a backend instance
 struct ggml_backend_webgpu_context {
     webgpu_context webgpu_ctx;
     std::string    name;
 };
 
+// Per-thread data related to buffers
 struct ggml_backend_webgpu_buffer_context {
-    webgpu_context webgpu_ctx;
-    wgpu::Buffer   buffer;
-    std::string    label;
+    wgpu::Buffer          buffer;
+    std::string           label;
+    webgpu_global_context global_ctx;
 
-    ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
-        webgpu_ctx(std::move(ctx)),
+    ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) :
         buffer(std::move(buf)),
-        label(std::move(lbl)) {}
+        label(std::move(lbl)),
+        global_ctx(std::move(global_ctx_)) {}
 };
 
 /* WebGPU object initializations */
 
-// Process a WGSL shader string, replacing tokens of the form {{KEY}} with
-// the corresponding values provided in `repls`.
-static std::string ggml_webgpu_process_shader_repls(const char *                               src,
-                                                    const std::map & repls) {
-    if (!src) {
-        return std::string();
-    }
-    std::string s = src;
-    for (const auto & kv : repls) {
-        std::string token = "{{" + kv.first + "}}";
-        size_t      pos   = 0;
-        while ((pos = s.find(token, pos)) != std::string::npos) {
-            s.replace(pos, token.length(), kv.second);
-            pos += kv.second.length();
-        }
-    }
-    return s;
-}
-
 static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device &                           device,
                                                    const char *                             shader_code,
                                                    const char *                             label,
@@ -473,44 +455,113 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
 
 /** WebGPU Actions */
 
-// Wait for the queue to finish processing all submitted work
-static void ggml_backend_webgpu_wait(webgpu_context &                         ctx,
-                                     std::vector & futures,
-                                     bool                                     block = true) {
-    // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
-    // inflight_max may be 0, meaning that we must wait on all futures.
-    uint64_t timeout_ms       = block ? UINT64_MAX : 0;
-    uint32_t inflight_threads = ctx->inflight_threads;
-    uint32_t inflight_max     = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
-    while (futures.size() >= inflight_max && futures.size() > 0) {
-        ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
-        futures.erase(futures.begin());
+static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) {
+    switch (status) {
+        case wgpu::WaitStatus::Success:
+            return true;
+        case wgpu::WaitStatus::TimedOut:
+            if (allow_timeout) {
+                return false;
+            }
+            GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n");
+            return false;
+        case wgpu::WaitStatus::Error:
+            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+            return false;
+        default:
+            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
+            return false;
     }
-    size_t i = 0;
-    while (i < futures.size()) {
-        auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
-        switch (waitStatus) {
-            case wgpu::WaitStatus::Success:
-                futures.erase(futures.begin() + i);
-                break;
-            case wgpu::WaitStatus::TimedOut:
-                i++;
-                break;
-            case wgpu::WaitStatus::Error:
-                GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
-                break;
-            default:
-                GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
-                break;
+}
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+static void ggml_backend_webgpu_erase_completed_futures(std::vector<wgpu::FutureWaitInfo> & futures) {
+    futures.erase(std::remove_if(futures.begin(), futures.end(),
+                                 [](const wgpu::FutureWaitInfo & info) { return info.completed; }),
+                  futures.end());
+}
+
+static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context &             ctx,
+                                                     std::vector<wgpu::FutureWaitInfo> & futures,
+                                                     bool                                block) {
+    if (futures.empty()) {
+        return;
+    }
+
+    uint64_t timeout_ms = block ? UINT64_MAX : 0;
+    if (block) {
+        while (!futures.empty()) {
+            auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
+            if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
+                ggml_backend_webgpu_erase_completed_futures(futures);
+            }
+        }
+    } else {
+        auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
+        if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) {
+            ggml_backend_webgpu_erase_completed_futures(futures);
+        }
+    }
+}
+#endif
+
+// Wait for the queue to finish processing all submitted work
+static void ggml_backend_webgpu_wait(webgpu_global_context &          ctx,
+                                     std::vector<webgpu_submission> & subs,
+                                     bool                             block = true) {
+    // If we have too many in-flight submissions, wait on the oldest one first.
+    if (subs.empty()) {
+        return;
+    }
+    while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
+        auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX);
+        if (ggml_backend_webgpu_handle_wait_status(waitStatus)) {
+#ifdef GGML_WEBGPU_GPU_PROFILE
+            ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true);
+#endif
+            subs.erase(subs.begin());
+        }
+    }
+
+    if (subs.empty()) {
+        return;
+    }
+
+    if (block) {
+        for (auto & sub : subs) {
+            while (!sub.submit_done.completed) {
+                auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX);
+                ggml_backend_webgpu_handle_wait_status(waitStatus);
+            }
+#ifdef GGML_WEBGPU_GPU_PROFILE
+            ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true);
+#endif
+        }
+        subs.clear();
+    } else {
+        // Poll each submit future once and remove completed submissions.
+        for (auto sub = subs.begin(); sub != subs.end();) {
+            auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0);
+            ggml_backend_webgpu_handle_wait_status(waitStatus, true);
+#ifdef GGML_WEBGPU_GPU_PROFILE
+            ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false);
+            if (sub->submit_done.completed && sub->profile_futures.empty()) {
+#else
+            if (sub->submit_done.completed) {
+#endif
+                sub = subs.erase(sub);
+            } else {
+                ++sub;
+            }
         }
     }
 }
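+// Illustrative caller-side pattern (names are hypothetical): each backend thread
+// keeps a std::vector<webgpu_submission> that is appended to on every flush and
+// trimmed opportunistically, e.g.
+//   subs.push_back(ggml_backend_webgpu_submit(global_ctx, commands, param_buf_pool));
+//   ggml_backend_webgpu_wait(global_ctx, subs, /*block=*/false);  // poll completed work
+//   ...
+//   ggml_backend_webgpu_wait(global_ctx, subs);                   // block at a sync point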
 
-static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
-                                           wgpu::Buffer &   buffer,
-                                           wgpu::MapMode    mode,
-                                           size_t           offset,
-                                           size_t           size) {
+static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx,
+                                           wgpu::Buffer &          buffer,
+                                           wgpu::MapMode           mode,
+                                           size_t                  offset,
+                                           size_t                  size) {
     ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
                                           [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
                                               if (status != wgpu::MapAsyncStatus::Success) {
@@ -525,7 +576,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
 // This function adds debugging information to shaders, as WebGPU does not support printing directly.
 // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
 // debug statements in the shader, and then call this function after encoding the commands and submitting them.
-static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
+static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
     wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
     encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
     wgpu::CommandBuffer commands = encoder.Finish();
@@ -537,53 +588,32 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
 }
 #endif
 
-static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector commands) {
+static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context &       ctx,
+                                                    std::vector<webgpu_command> & commands,
+                                                    webgpu_buf_pool &             param_buf_pool) {
     std::vector<wgpu::CommandBuffer> command_buffers;
-    std::vector    params_bufs;
-    std::vector    set_rows_error_bufs;
+    std::vector<wgpu::Buffer>        params_bufs;
+    webgpu_submission                submission;
 #ifdef GGML_WEBGPU_GPU_PROFILE
     std::vector> pipeline_name_and_ts_bufs;
 #endif
 
     for (const auto & command : commands) {
         command_buffers.push_back(command.commands);
-        params_bufs.push_back(command.params_bufs);
-        if (command.set_rows_error_bufs) {
-            set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
-        }
+        params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end());
     }
     ctx->queue.Submit(command_buffers.size(), command_buffers.data());
 
-    std::vector futures;
-
     wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
         wgpu::CallbackMode::AllowSpontaneous,
-        [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+        [&param_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
             if (status != wgpu::QueueWorkDoneStatus::Success) {
                 GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
             }
             // Free the staged buffers
-            ctx->param_buf_pool.free_bufs({ params_bufs });
+            param_buf_pool.free_bufs(params_bufs);
         });
-    futures.push_back({ p_f });
-
-    for (const auto & bufs : set_rows_error_bufs) {
-        wgpu::Future f = bufs.host_buf.MapAsync(
-            wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
-            [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
-                if (status != wgpu::MapAsyncStatus::Success) {
-                    GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
-                } else {
-                    const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
-                    if (*error_data) {
-                        GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
-                    }
-                    // We can't unmap in here due to WebGPU reentrancy limitations.
-                    ctx->set_rows_error_buf_pool.free_bufs({ bufs });
-                }
-            });
-        futures.push_back({ f });
-    }
+    submission.submit_done = { p_f };
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
     for (const auto & command : commands) {
@@ -600,52 +630,54 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx,
                     // WebGPU timestamps are in ns; convert to ms
                     double           elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
                     ctx->shader_gpu_time_ms[label] += elapsed_ms;
-                    // We can't unmap in here due to WebGPU reentrancy limitations.
-                    ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
                 }
+                // We can't unmap in here due to WebGPU reentrancy limitations.
+                ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
             });
-        futures.push_back({ f });
+        submission.profile_futures.push_back({ f });
     }
 #endif
-    return { futures };
+    return submission;
 }
 
-static webgpu_command ggml_backend_webgpu_build(webgpu_context &                  ctx,
-                                                webgpu_pipeline &                 pipeline,
-                                                std::vector             params,
-                                                std::vector bind_group_entries,
-                                                uint32_t                          wg_x,
-                                                uint32_t                          wg_y                = 1,
-                                                std::optional   set_rows_error_bufs = std::nullopt) {
-    webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
+static webgpu_command ggml_backend_webgpu_build_multi(
+    webgpu_global_context &                                ctx,
+    webgpu_buf_pool &                                      param_buf_pool,
+    const std::vector<webgpu_pipeline> &                     pipelines,
+    const std::vector<std::vector<uint32_t>> &               params_list,
+    const std::vector<std::vector<wgpu::BindGroupEntry>> &   bind_group_entries_list,
+    const std::vector<std::pair<uint32_t, uint32_t>> &       workgroups_list) {
+    GGML_ASSERT(pipelines.size() == params_list.size());
+    GGML_ASSERT(pipelines.size() == bind_group_entries_list.size());
+    GGML_ASSERT(pipelines.size() == workgroups_list.size());
 
-    ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
-    uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
-    for (size_t i = 0; i < params.size(); i++) {
-        _params[i] = params[i];
-    };
+    std::vector<wgpu::Buffer>    params_bufs_list;
+    std::vector<wgpu::BindGroup> bind_groups;
 
-    params_bufs.host_buf.Unmap();
+    for (size_t i = 0; i < pipelines.size(); i++) {
+        wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs();
 
-    uint32_t params_bufs_binding_num = bind_group_entries.size();
-    bind_group_entries.push_back({ .binding = params_bufs_binding_num,
-                                   .buffer  = params_bufs.dev_buf,
-                                   .offset  = 0,
-                                   .size    = params_bufs.dev_buf.GetSize() });
+        std::vector<wgpu::BindGroupEntry> entries            = bind_group_entries_list[i];
+        uint32_t                          params_binding_num = entries.size();
+        entries.push_back(
+            { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() });
 
-    wgpu::BindGroupDescriptor bind_group_desc;
-    bind_group_desc.layout     = pipeline.pipeline.GetBindGroupLayout(0);
-    bind_group_desc.entryCount = bind_group_entries.size();
-    bind_group_desc.entries    = bind_group_entries.data();
-    bind_group_desc.label      = pipeline.name.c_str();
-    wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
+        wgpu::BindGroupDescriptor bind_group_desc;
+        bind_group_desc.layout     = pipelines[i].pipeline.GetBindGroupLayout(0);
+        bind_group_desc.entryCount = entries.size();
+        bind_group_desc.entries    = entries.data();
+        bind_group_desc.label      = pipelines[i].name.c_str();
+        bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc));
+
+        params_bufs_list.push_back(params_bufs);
+    }
 
     wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
-    encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
+    for (size_t i = 0; i < params_bufs_list.size(); i++) {
+        ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
+    }
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
-    // --- Profiling: GPU timestamp queries ---
-    // Allocate a timestamp query buffer (2 timestamps: start/end)
     webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
     if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
         ts_bufs.host_buf.Unmap();
@@ -659,50 +691,63 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context &
 #else
     wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
 #endif
-    pass.SetPipeline(pipeline.pipeline);
-    pass.SetBindGroup(0, bind_group);
-    pass.DispatchWorkgroups(wg_x, wg_y, 1);
+    for (size_t i = 0; i < pipelines.size(); i++) {
+        pass.SetPipeline(pipelines[i].pipeline);
+        pass.SetBindGroup(0, bind_groups[i]);
+        pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1);
+    }
     pass.End();
 
 #ifdef GGML_WEBGPU_GPU_PROFILE
-    // Resolve the query set into the device buffer
     encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
     encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
 #endif
 
-    // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
-    if (set_rows_error_bufs) {
-        encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
-                                   set_rows_error_bufs->host_buf.GetSize());
-    }
-
     wgpu::CommandBuffer commands = encoder.Finish();
     webgpu_command      result   = {};
     result.commands              = commands;
-    result.params_bufs           = params_bufs;
-    result.set_rows_error_bufs   = set_rows_error_bufs;
+    result.params_bufs           = params_bufs_list;
+    result.num_kernels           = pipelines.size();
 #ifdef GGML_WEBGPU_GPU_PROFILE
     result.timestamp_query_bufs = ts_bufs;
-    result.pipeline_name        = pipeline.name;
+    // TODO: handle multiple pipeline names
+    result.pipeline_name        = pipelines.front().name;
 #endif
     return result;
 }
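+// Illustrative call (hypothetical pipelines, params and dispatch sizes): two
+// kernels batched into one command buffer and one compute pass, e.g.
+//   webgpu_command cmd = ggml_backend_webgpu_build_multi(
+//       global_ctx, param_buf_pool, { pipeline_a, pipeline_b }, { params_a, params_b },
+//       { entries_a, entries_b }, { { wg_ax, wg_ay }, { wg_bx, wg_by } });
+// All dispatches share one encoder, so they are submitted (and, with
+// GGML_WEBGPU_GPU_PROFILE, timestamped) together.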
 
-static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
-                                              wgpu::Buffer &   buf,
-                                              uint32_t         value,
-                                              size_t           offset,
-                                              size_t           size) {
+static webgpu_command ggml_backend_webgpu_build(webgpu_global_context &           ctx,
+                                                webgpu_buf_pool &                 param_buf_pool,
+                                                webgpu_pipeline &                 pipeline,
+                                                std::vector<uint32_t>             params,
+                                                std::vector<wgpu::BindGroupEntry> bind_group_entries,
+                                                uint32_t                          wg_x,
+                                                uint32_t                          wg_y = 1) {
+    return ggml_backend_webgpu_build_multi(ctx, param_buf_pool,
+                                           {
+                                               pipeline
+    },
+                                           { std::move(params) }, { std::move(bind_group_entries) },
+                                           { { wg_x, wg_y } });
+}
+
+static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
+                                              wgpu::Buffer &          buf,
+                                              uint32_t                value,
+                                              size_t                  offset,
+                                              size_t                  size) {
     std::vector             params  = { (uint32_t) offset, (uint32_t) size, value };
     std::vector entries = {
         { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
     };
-    size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread;
+    size_t   bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread;
     uint32_t wg_x         = CEIL_DIV(size + 3, bytes_per_wg);
 
-    webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x);
-    std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }) };
-    ggml_backend_webgpu_wait(ctx, futures);
+    webgpu_command command =
+        ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
+    std::vector<webgpu_command>    commands = { command };
+    std::vector<webgpu_submission> sub      = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) };
+    ggml_backend_webgpu_wait(ctx, sub);
 }
 
 /** End WebGPU Actions */
@@ -714,7 +759,6 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
     return ctx->name.c_str();
 }
 
-// TODO: implement proper cleanup
 static void ggml_backend_webgpu_free(ggml_backend_t backend) {
     ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
@@ -722,19 +766,19 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
 #ifdef GGML_WEBGPU_CPU_PROFILE
     std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
     double total_cpu = 0.0;
-    for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
         total_cpu += kv.second;
     }
     std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
     std::cout << "ggml_webgpu: cpu breakdown:\n";
-    for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) {
         double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
         std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
     }
-    if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) {
+    if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) {
         std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
     }
-    for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) {
         double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
         std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
     }
@@ -743,14 +787,15 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
 #ifdef GGML_WEBGPU_GPU_PROFILE
     std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
     double total_gpu = 0.0;
-    for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
         total_gpu += kv.second;
     }
     std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
     std::cout << "\nggml_webgpu: gpu breakdown:\n";
-    for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
+    for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) {
         double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
-        std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
+        std::cout << "ggml_webgpu:  " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2)
+                  << pct << "%)\n";
     }
 #endif
 
@@ -758,9 +803,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
     std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
 #endif
 
-#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE)
-    GGML_UNUSED(ctx);
-#endif
+    delete ctx;
+    delete backend;
 }
 
 static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
@@ -774,12 +818,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
 
 static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
-    return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
 }
 
 static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
-    return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1);
 }
 
 static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
@@ -792,6 +836,30 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
            (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
 }
 
+// Returns true if two tensors share the same buffer and their byte ranges overlap.
+static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
+    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
+           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
+}
+
+struct binary_overlap_flags {
+    bool inplace;  // src0 == dst
+    bool overlap;      // src1 overlaps dst
+    bool src_overlap;  // src0 overlaps src1
+};
+
+static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
+                                                              ggml_tensor * src1,
+                                                              ggml_tensor * dst) {
+    binary_overlap_flags flags = {};
+    flags.inplace              = ggml_webgpu_tensor_equal(src0, dst);
+    flags.overlap              = ggml_webgpu_tensor_overlap(src1, dst);
+    flags.src_overlap          = ggml_webgpu_tensor_overlap(src0, src1);
+
+    return flags;
+}
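+
+// Example (hypothetical layout): two views into the same wgpu::Buffer, one
+// covering bytes [0, 1024) and one covering [512, 1536), overlap because
+// 0 < 512 + 1024 and 512 < 0 + 1024. The resulting flags are presumably used by
+// the binary-op builders to pick a shader variant that is safe for aliased
+// inputs and outputs.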
+
 static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
@@ -820,22 +888,85 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
     };
 
     uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
+                                     params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx);
+
+    auto * decisions = static_cast(pipeline.context.get());
+
+    const uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+    std::vector<uint32_t> params = {
+        ne,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        // Strides (in elements)
+        (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+        (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+        // Shapes
+        (uint32_t) src->ne[0],
+        (uint32_t) src->ne[1],
+        (uint32_t) src->ne[2],
+        (uint32_t) src->ne[3],
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2],
+        (uint32_t) dst->ne[3],
+        // Pad sizes
+        (uint32_t) ggml_get_op_params_i32(dst, 0),
+        (uint32_t) ggml_get_op_params_i32(dst, 1),
+        (uint32_t) ggml_get_op_params_i32(dst, 2),
+        (uint32_t) ggml_get_op_params_i32(dst, 3),
+        (uint32_t) ggml_get_op_params_i32(dst, 4),
+        (uint32_t) ggml_get_op_params_i32(dst, 5),
+        (uint32_t) ggml_get_op_params_i32(dst, 6),
+        (uint32_t) ggml_get_op_params_i32(dst, 7),
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+    };
+
+    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static std::optional ggml_webgpu_set_rows(webgpu_context & ctx,
                                                           ggml_tensor *    src,
                                                           ggml_tensor *    idx,
                                                           ggml_tensor *    dst) {
-    // For set rows specifically, we need to check if src and idx are empty tensors.
+    // For set rows specifically, we need to check if src and idx are empty
+    // tensors.
     if (ggml_is_empty(src) || ggml_is_empty(idx)) {
         return std::nullopt;
     }
 
-    webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
-    if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
-        error_bufs.host_buf.Unmap();
-    }
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = idx,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx);
+
+    auto * decisions = static_cast(pipeline.context.get());
 
     std::vector params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
@@ -865,44 +996,67 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx,
         { .binding = 2,
          .buffer  = ggml_webgpu_tensor_buf(dst),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) },
-        { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    int             vectorized = src->ne[0] % 4 == 0;
-    webgpu_pipeline pipeline   = ctx->set_rows_pipelines[0][vectorized];
-    uint32_t        threads;
-    if (vectorized) {
+    if (decisions->i64_idx) {
+        entries.push_back({ .binding = 3,
+                            .buffer  = ctx->set_rows_dev_error_buf,
+                            .offset  = 0,
+                            .size    = ctx->set_rows_dev_error_buf.GetSize() });
+    }
+
+    uint32_t threads;
+    if (decisions->vec4) {
         threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
     } else {
         threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
     }
+    uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1);
+}
 
-    uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE);
-
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
+// Workgroup size is a common constant
+static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
+    std::vector<wgpu::ConstantEntry> constants(1);
+    constants[0].key   = "wg_size";
+    constants[0].value = wg_size;
+    return constants;
 }
 
 static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
                                            ggml_tensor *    src,
                                            ggml_tensor *    idx,
                                            ggml_tensor *    dst) {
-    std::vector params = {
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        // Convert byte-strides to element-strides
-        (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
-        (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
-        (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
-        (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
-        (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // Shape of dst
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
-        // Shape of idx
-        (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = nullptr,
+        .dst         = dst,
+        .max_wg_size = WEBGPU_MAX_WG_SIZE,
     };
 
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast(pipeline.context.get());
+
+    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+                                     (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)),
+                                     (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
+                                     (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+                                     (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+                                     (uint32_t) dst->ne[0],
+                                     (uint32_t) dst->ne[1],
+                                     (uint32_t) dst->ne[2],
+                                     (uint32_t) dst->ne[3],
+                                     (uint32_t) (idx->ne[1]),
+                                     (uint32_t) (idx->ne[2]) };
+
     std::vector entries = {
         { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
@@ -918,36 +1072,97 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
     };
 
-    uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
+    uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
 
-    uint32_t        vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
-    webgpu_pipeline pipeline   = ctx->get_rows_pipelines[src->type][vectorized];
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
                                           ggml_tensor *    src0,
                                           ggml_tensor *    src1,
                                           ggml_tensor *    dst) {
+    // Determine if this is a mat-vec operation
+    bool is_vec = (dst->ne[1] == 1);
+
+    // Determine if we should use fast path
+    bool use_fast = false;
+    switch (src1->type) {
+        case GGML_TYPE_F16:
+            use_fast = (src0->type == GGML_TYPE_F16);
+            break;
+        case GGML_TYPE_F32:
+            // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K
+            switch (src0->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q8_1:
+                case GGML_TYPE_Q6_K:
+                    use_fast = true;
+                    break;
+                case GGML_TYPE_Q2_K:
+                case GGML_TYPE_Q3_K:
+                case GGML_TYPE_Q4_K:
+                case GGML_TYPE_Q5_K:
+                    // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat
+                    use_fast = !is_vec;
+                    break;
+                default:
+                    break;
+            }
+            break;
+        default:
+            break;
+    }
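+    // For example, F16 x F16 and F32 x Q4_0 take a fast path, while an F32 x Q4_K
+    // mat-vec (dst->ne[1] == 1) falls back to the legacy pipeline, since only the
+    // (semi) fast mat-mat kernel exists for that quantization.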
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0                     = src0,
+        .src1                     = src1,
+        .dst                      = dst,
+        .max_wg_size              = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix,
+        .sg_mat_m                 = ctx->global_ctx->capabilities.sg_mat_m,
+        .sg_mat_n                 = ctx->global_ctx->capabilities.sg_mat_n,
+        .sg_mat_k                 = ctx->global_ctx->capabilities.sg_mat_k,
+        .max_subgroup_size        = ctx->global_ctx->capabilities.max_subgroup_size,
+    };
+
+    // Get or create pipeline
+    webgpu_pipeline pipeline;
+
+    if (use_fast && is_vec) {
+        pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx);
+    } else if (use_fast) {
+        pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx);
+    } else {
+        pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx);
+    }
+
+    // Build params
+    std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        (uint32_t) dst->ne[0],                                  // number of rows in result (M, transposed)
-        (uint32_t) dst->ne[1],                                  // number of columns in result (N)
-        (uint32_t) src0->ne[0],                                 // number of columns in src0/src1 (K)
-        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 1
-        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 1
-        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 2
-        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 2
-        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),  // stride (elements/blocks) of src0 in dimension 3
-        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),  // stride (elements/blocks) of src1 in dimension 3
-        (uint32_t) src0->ne[2],                                 // batch size in dimension 2
-        (uint32_t) src0->ne[3],                                 // batch size in dimension 3
-        (uint32_t) (src1->ne[2] / src0->ne[2]),                 // broadcast in dimension 2
-        (uint32_t) (src1->ne[3] / src0->ne[3])                  // broadcast in dimension 3
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) src0->ne[0],
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+        (uint32_t) src0->ne[2],
+        (uint32_t) src0->ne[3],
+        (uint32_t) (src1->ne[2] / src0->ne[2]),
+        (uint32_t) (src1->ne[3] / src0->ne[3])
     };
 
+    // Build bind group entries
+    std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
@@ -963,69 +1178,51 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
          .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  },
     };
 
-    webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
+    // Calculate workgroup dimensions
+    uint32_t       wg_x           = 1;
+    uint32_t       wg_y           = 1;
+    const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
 
-    uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
-    uint32_t wg_y = 1;
+    if (use_fast && is_vec) {
+        auto * decisions = static_cast(pipeline.context.get());
 
-    bool use_fast = false;
-    switch (src1->type) {
-        case GGML_TYPE_F16:
-            use_fast = (src0->type == GGML_TYPE_F16);
-            break;
-        case GGML_TYPE_F32:
-            switch (src0->type) {
-                case GGML_TYPE_F32:
-                case GGML_TYPE_F16:
-                case GGML_TYPE_Q4_0:
-                    use_fast = true;
-                    break;
-                default:
-                    break;
-            }
-            break;
-        default:
-            break;
-    }
+        uint32_t batches       = dst->ne[2] * dst->ne[3];
+        uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
+        uint32_t total_wg      = output_groups * batches;
+        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
+    } else if (use_fast) {
+        auto * decisions = static_cast(pipeline.context.get());
 
-    if (use_fast) {
-        int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
-        if (dst->ne[1] == 1) {
-            // We don't support vectorized mul_mat_vec for quantized types
-            vectorized             = vectorized && (src0->type < 2);
-            pipeline               = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
-            uint32_t batches       = dst->ne[2] * dst->ne[3];
-            uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
-            uint32_t total_wg      = output_groups * batches;
-            wg_x                   = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
-            wg_y                   = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension);
+        // Fast-path tiled/subgroup calculations
+        uint32_t wg_m;
+        uint32_t wg_n;
+        if (decisions->use_subgroup_matrix) {
+            uint32_t wg_m_sg_tile =
+                decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m;
+            wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
+            uint32_t wg_n_sg_tile =
+                decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n;
+            wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
         } else {
-            pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
-            uint32_t wg_m;
-            uint32_t wg_n;
-#ifndef __EMSCRIPTEN__
-            if (ctx->supports_subgroup_matrix) {
-                // The total number of subgroups/workgroups needed per matrix.
-                uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
-                wg_m                  = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
-                uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
-                wg_n                  = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
-            } else {
-#endif
-                uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
-                uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
-                wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
-                wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);
-#ifndef __EMSCRIPTEN__
-            }
-#endif
-
-            wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+            uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m;
+            uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n;
+            wg_m              = CEIL_DIV(dst->ne[0], tile_m_s);
+            wg_n              = CEIL_DIV(dst->ne[1], tile_n_s);
         }
+        uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3];
+        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
+
+    } else {  // legacy
+        auto *   decisions = static_cast(pipeline.context.get());
+        uint32_t wg_size   = decisions->wg_size;
+        uint32_t total_wg  = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size);
+        compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y);
     }
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
+
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
 }
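
The workgroup-split logic above depends on the WebGPU limit maxComputeWorkgroupsPerDimension: when a kernel needs more workgroups than one dispatch dimension allows, the count has to be folded into an (x, y) grid. A minimal sketch of what a helper like compute_2d_workgroups presumably does (names and rounding details are assumptions, not this patch's implementation):

    #include <algorithm>
    #include <cstdint>

    // Hypothetical sketch: fold a 1D workgroup count into (x, y) so that
    // x <= max_per_dim and x * y >= total. The shader is then expected to
    // rebuild a linear index from the 2D workgroup id and ignore overshoot.
    static void compute_2d_workgroups_sketch(uint32_t total, uint32_t max_per_dim,
                                             uint32_t & wg_x, uint32_t & wg_y) {
        if (total == 0) {
            wg_x = wg_y = 1;
            return;
        }
        wg_x = std::min(total, max_per_dim);
        wg_y = (total + wg_x - 1) / wg_x;  // ceil-div, mirrors CEIL_DIV
    }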
 
+#ifndef __EMSCRIPTEN__
 static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                                              ggml_tensor *    Q,
                                              ggml_tensor *    K,
@@ -1109,105 +1306,97 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
-    bool kv_direct =
-        (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
-
-    flash_attn_pipeline_key key = {
-        .q_type             = Q->type,
-        .kv_type            = K->type,
-        .dst_type           = dst->type,
-        .head_dim_qk        = (uint32_t) Q->ne[0],
-        .head_dim_v         = (uint32_t) V->ne[0],
-        .kv_direct          = kv_direct,
-        .has_mask           = static_cast(has_mask),
-        .has_sinks          = static_cast(has_sinks),
-        .uses_logit_softcap = logit_softcap != 0.0f,
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0               = Q,
+        .src1               = K,
+        .src2               = V,
+        .src3               = mask,
+        .src4               = sinks,
+        .dst                = dst,
+        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+        .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
+        .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
+        .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
+        .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size,
     };
 
-    webgpu_pipeline                         pipeline;
-    ggml_webgpu_flash_attn_shader_decisions decisions = {};
+    webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
 
-    auto it = ctx->flash_attn_pipelines.find(key);
-    if (it != ctx->flash_attn_pipelines.end()) {
-        pipeline  = it->second;
-        decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
-    } else {
-        std::lock_guard lock(ctx->mutex);
-        it = ctx->flash_attn_pipelines.find(key);
-        if (it != ctx->flash_attn_pipelines.end()) {
-            pipeline  = it->second;
-            decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
-        } else {
-            ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type     = K->type,
-                                                                         .head_dim_qk = (uint32_t) Q->ne[0],
-                                                                         .head_dim_v  = (uint32_t) V->ne[0],
-                                                                         .kv_direct   = kv_direct,
-                                                                         .has_mask    = static_cast(has_mask),
-                                                                         .has_sinks   = static_cast(has_sinks),
-                                                                         .uses_logit_softcap = logit_softcap != 0.0f,
-                                                                         .sg_mat_m           = ctx->sg_mat_m,
-                                                                         .sg_mat_n           = ctx->sg_mat_n,
-                                                                         .sg_mat_k           = ctx->sg_mat_k,
-                                                                         .wg_mem_limit_bytes =
-                                                                             ctx->limits.maxComputeWorkgroupStorageSize,
-                                                                         .max_subgroup_size = ctx->max_subgroup_size };
+    auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
 
-            ggml_webgpu_processed_shader processed =
-                ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
-            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
-            pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
-            ctx->flash_attn_pipelines.emplace(key, pipeline);
-            decisions = processed.decisions;
-        }
-    }
-
-    uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
+    uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
     uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
+#endif
 
 static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    uint32_t      ne       = (uint32_t) ggml_nelements(dst);
-    ggml_unary_op unary_op = ggml_get_unary_op(dst);
-    uint32_t      inplace  = ggml_webgpu_tensor_equal(src, dst);
+    bool is_unary = dst->op == GGML_OP_UNARY;
+    bool inplace  = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL);
 
-    std::vector<uint32_t> params = {
-        ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
-        // Convert byte-strides to element-strides
-        (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
-        (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
-        (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
-        (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // Logical shapes
-        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
-        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = nullptr,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace     = inplace,
     };
 
-    switch (unary_op) {
-        case GGML_UNARY_OP_XIELU:
-            {
-                // Get float parameters and reinterpret their bit patterns as uint32_t
-                // for passing through the params buffer
-                float alpha_n = ggml_get_op_params_f32(dst, 1);
-                float alpha_p = ggml_get_op_params_f32(dst, 2);
-                float beta    = ggml_get_op_params_f32(dst, 3);
-                float eps     = ggml_get_op_params_f32(dst, 4);
-                params.push_back(*reinterpret_cast<uint32_t *>(&alpha_n));
-                params.push_back(*reinterpret_cast<uint32_t *>(&alpha_p));
-                params.push_back(*reinterpret_cast<uint32_t *>(&beta));
-                params.push_back(*reinterpret_cast<uint32_t *>(&eps));
+    webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx);
+
+    auto * decisions = static_cast(pipeline.context.get());
+
+    uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+    std::vector<uint32_t> params = { ne,
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     (uint32_t) (src->nb[0] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+                                     (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+                                     (uint32_t) src->ne[0],
+                                     (uint32_t) src->ne[1],
+                                     (uint32_t) src->ne[2] };
+
+    ggml_tensor * effective_src = src;
+    if (is_unary) {
+        ggml_unary_op unary_op = ggml_get_unary_op(dst);
+        switch (unary_op) {
+            case GGML_UNARY_OP_XIELU:
+                {
+                    // Get float parameters and reinterpret their bit patterns as uint32_t
+                    // for passing through the params buffer
+                    float alpha_n = ggml_get_op_params_f32(dst, 1);
+                    float alpha_p = ggml_get_op_params_f32(dst, 2);
+                    float beta    = ggml_get_op_params_f32(dst, 3);
+                    float eps     = ggml_get_op_params_f32(dst, 4);
+                    params.push_back(*reinterpret_cast<uint32_t *>(&alpha_n));
+                    params.push_back(*reinterpret_cast<uint32_t *>(&alpha_p));
+                    params.push_back(*reinterpret_cast<uint32_t *>(&beta));
+                    params.push_back(*reinterpret_cast<uint32_t *>(&eps));
+                    break;
+                }
+            default:
                 break;
-            }
-        default:
-            break;
+        }
+    } else if (dst->op == GGML_OP_CLAMP) {
+        float clamp_min = ggml_get_op_params_f32(dst, 0);
+        float clamp_max = ggml_get_op_params_f32(dst, 1);
+        params.push_back(*reinterpret_cast<uint32_t *>(&clamp_min));
+        params.push_back(*reinterpret_cast<uint32_t *>(&clamp_max));
+    } else if (dst->op == GGML_OP_FILL) {
+        float fill_val = ggml_get_op_params_f32(dst, 0);
+        params.push_back(*reinterpret_cast<uint32_t *>(&fill_val));
+        effective_src = dst;  // fill simply fills dst
     }
 
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
-         .buffer  = ggml_webgpu_tensor_buf(src),
-         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
-         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
+         .buffer  = ggml_webgpu_tensor_buf(effective_src),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, effective_src),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, effective_src) },
     };
     if (!inplace) {
         entries.push_back({ .binding = 1,
@@ -1216,21 +1405,54 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x);
+    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
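
Because the params buffer is a vector of uint32_t, float operands (the XIELU coefficients, clamp bounds, fill value) are passed by copying their bit patterns and letting the shader side reinterpret them back to f32. A standalone sketch of that packing (std::memcpy is used here purely for illustration; the helper name is not from this patch):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Copy a float's raw bits into a uint32_t so it can travel through a
    // uint32_t-only params buffer.
    static uint32_t pack_f32_bits(float v) {
        uint32_t bits;
        std::memcpy(&bits, &v, sizeof(bits));
        return bits;
    }

    int main() {
        std::vector<uint32_t> params;
        params.push_back(pack_f32_bits(-1.0f));  // e.g. a clamp_min value
        params.push_back(pack_f32_bits(+1.0f));  // e.g. a clamp_max value
        return 0;
    }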
 
-static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
-                                            ggml_tensor *     src0,
-                                            ggml_tensor *     src1,
-                                            ggml_tensor *     dst,
-                                            webgpu_pipeline & pipeline,
-                                            bool              inplace) {
+static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
+                                            ggml_tensor *    src0,
+                                            ggml_tensor *    src1,
+                                            ggml_tensor *    dst) {
+    binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .src1        = src1,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace     = flags.inplace,
+        .overlap     = flags.overlap,
+        .src_overlap = flags.src_overlap,
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
+
+    auto * decisions = static_cast(pipeline.context.get());
+
+    uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+    size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
+    size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
+
+    uint32_t offset_merged_src0 = 0;
+    uint32_t offset_merged_src1 = 0;
+    if (flags.src_overlap) {
+        size_t min_off     = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
+        offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
+        offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
+    }
+
     std::vector<uint32_t> params = {
-        (uint32_t) ggml_nelements(dst),
+        ne,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        offset_merged_src0,
+        offset_merged_src1,
+        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
         (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
         (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
         (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
@@ -1244,6 +1466,79 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
         (uint32_t) src1->ne[3],
     };
 
+    std::vector<wgpu::BindGroupEntry> entries;
+
+    if (flags.src_overlap) {
+        size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
+        size_t merged_end    = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
+                                        src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
+        entries.push_back({
+            .binding = 0,
+            .buffer  = ggml_webgpu_tensor_buf(src0),
+            .offset  = merged_offset,
+            .size    = merged_end - merged_offset,
+        });
+        entries.push_back({
+            .binding = 1,
+            .buffer  = ggml_webgpu_tensor_buf(dst),
+            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+            .size    = ggml_webgpu_tensor_binding_size(ctx, dst),
+        });
+    } else {
+        entries.push_back({
+            .binding = 0,
+            .buffer  = ggml_webgpu_tensor_buf(src0),
+            .offset  = src0_webgpu_tensor_align_offset,
+            .size    = ggml_webgpu_tensor_binding_size(ctx, src0),
+        });
+        entries.push_back({
+            .binding = 1,
+            .buffer  = ggml_webgpu_tensor_buf(src1),
+            .offset  = src1_webgpu_tensor_align_offset,
+            .size    = ggml_webgpu_tensor_binding_size(ctx, src1),
+        });
+        if (!flags.inplace && !flags.overlap) {
+            entries.push_back({
+                .binding = 2,
+                .buffer  = ggml_webgpu_tensor_buf(dst),
+                .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                .size    = ggml_webgpu_tensor_binding_size(ctx, dst),
+            });
+        }
+    }
+
+    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
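
When src0 and src1 alias overlapping ranges of the same buffer, the code above binds a single range spanning both tensors and passes each tensor's start as an element offset from the merged base. A small sketch of that interval arithmetic (the struct and helper name are illustrative only, not part of this patch):

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    struct merged_range {
        size_t   offset;      // byte offset of the merged binding
        size_t   size;        // byte size of the merged binding
        uint32_t elem_off_a;  // element offset of tensor A inside the binding
        uint32_t elem_off_b;  // element offset of tensor B inside the binding
    };

    // Merge two byte ranges of the same buffer into one binding and express
    // each original start as an element offset from the merged base.
    static merged_range merge_ranges(size_t off_a, size_t size_a,
                                     size_t off_b, size_t size_b,
                                     size_t elem_size) {
        size_t lo = std::min(off_a, off_b);
        size_t hi = std::max(off_a + size_a, off_b + size_b);
        return { lo, hi - lo,
                 (uint32_t) ((off_a - lo) / elem_size),
                 (uint32_t) ((off_b - lo) / elem_size) };
    }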
+
+static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
+                                         ggml_tensor *    src0,
+                                         ggml_tensor *    src1,
+                                         ggml_tensor *    dst) {
+    uint32_t ne  = (uint32_t) ggml_nelements(dst);
+    uint32_t dim = (uint32_t) dst->op_params[0];
+
+    std::vector<uint32_t> params = {
+        ne,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2],
+        (uint32_t) dst->ne[3],
+        dim,
+        (uint32_t) src0->ne[dim]
+    };
+
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src0),
@@ -1252,17 +1547,66 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
         { .binding = 1,
          .buffer  = ggml_webgpu_tensor_buf(src1),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
-         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) },
+        { .binding = 2,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  }
     };
-    if (!inplace) {
-        entries.push_back({ .binding = 2,
-                            .buffer  = ggml_webgpu_tensor_buf(dst),
-                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
-    }
 
-    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .src1        = src1,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast(pipeline.context.get());
+    uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) {
+    uint32_t ne = (uint32_t) ggml_nelements(dst);
+
+    std::vector<uint32_t> params = { ne,
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) /
+                                                 ggml_type_size(src0->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+                                     (uint32_t) (src0->ne[0]),
+                                     (uint32_t) (src0->ne[1]),
+                                     (uint32_t) (src0->ne[2]),
+                                     (uint32_t) (src0->ne[3]),
+                                     (uint32_t) (dst->ne[0]),
+                                     (uint32_t) (dst->ne[1]),
+                                     (uint32_t) (dst->ne[2]) };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src0),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst)  }
+    };
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast(pipeline.context.get());
+    uint32_t        wg_x      = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@@ -1297,7 +1641,8 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src));
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params,
+                                     entries, ggml_nrows(src));
 }
 
 static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
@@ -1312,7 +1657,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
     const int mode       = ((int32_t *) dst->op_params)[2];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
     memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
     memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
     memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
@@ -1384,7 +1734,7 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
 
     webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
     uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@@ -1436,12 +1786,24 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
 
     webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
     uint32_t        wg_x     = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
-    int inplace = ggml_webgpu_tensor_equal(src, dst);
+    bool inplace = ggml_webgpu_tensor_equal(src, dst);
 
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = nullptr,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .inplace     = inplace,
+    };
+
+    webgpu_pipeline pipeline  = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx);
+    auto *          decisions = static_cast(pipeline.context.get());
+
+    // params unchanged
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1459,12 +1821,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
         *(uint32_t *) &dst->op_params[1]  // bias
     };
 
+    // bindgroups unchanged
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
          .buffer  = ggml_webgpu_tensor_buf(src),
          .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
          .size    = ggml_webgpu_tensor_binding_size(ctx, src) }
     };
+
     if (!inplace) {
         entries.push_back({ .binding = 1,
                             .buffer  = ggml_webgpu_tensor_buf(dst),
@@ -1472,8 +1836,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src,
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
-    return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x);
+    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
 static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
@@ -1545,15 +1909,261 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
+                                     ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
                                      ggml_nrows(dst));
 }
 
+static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     (uint32_t) src->ne[0] };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+    };
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx);
+    uint32_t        wg_x     = ggml_nelements(dst);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    bool is_top_k = dst->op == GGML_OP_TOP_K;
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0               = src,
+        .src1               = nullptr,
+        .dst                = dst,
+        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+    };
+
+    webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx);
+    auto * argsort_decisions = static_cast(argsort_pipeline.context.get());
+
+    webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx);
+
+    const uint32_t src_ne0 = (uint32_t) src->ne[0];
+    const uint32_t nrows   = (uint32_t) ggml_nrows(src);
+    const uint32_t npr     = CEIL_DIV(src_ne0, argsort_decisions->wg_size);
+    const uint32_t block_size =
+        is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
+    uint32_t out_ne0 = src_ne0;
+    if (is_top_k) {
+        if (npr > 1) {
+            const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
+            out_ne0                  = (npr - 1) * block_size + std::min(last_tile, block_size);
+        } else {
+            out_ne0 = block_size;
+        }
+    }
+
+    uint32_t merge_len    = block_size;
+    uint32_t merge_passes = 0;
+    while (merge_len < out_ne0) {
+        merge_len <<= 1;
+        merge_passes++;
+    }
+
+    const bool start_in_tmp = (merge_passes % 2) == 1;
+
+    const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
+    const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t);
+    const size_t tmp_offset =
+        ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
+    const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT);
+    const size_t dst_binding_size =
+        ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT);
+
+    const uint32_t offset_src  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type));
+    const uint32_t offset_dst  = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type));
+    const uint32_t offset_tmp  = 0;
+    const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type));
+    const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type));
+    const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type));
+    const uint32_t stride_idx1 = out_ne0;
+    const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1];
+    const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2];
+
+    std::vector<webgpu_pipeline>                   pipelines;
+    std::vector<std::vector<uint32_t>>             params_list;
+    std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
+    std::vector<std::pair<uint32_t, uint32_t>>     workgroups_list;
+
+    const uint32_t init_offset       = start_in_tmp ? offset_tmp : offset_dst;
+    const size_t   init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+    const size_t   init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size;
+
+    std::vector<uint32_t> init_params = {
+        offset_src,  init_offset, stride_src1, stride_src2,           stride_src3,           stride_idx1,
+        stride_idx2, stride_idx3, src_ne0,     (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0,
+        block_size,  npr,         nrows
+    };
+
+    const uint32_t                    total_wg_init = npr * nrows;
+    const uint32_t                    max_wg    = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
+    const uint32_t                    wg_x_init = std::min(total_wg_init, max_wg);
+    const uint32_t                    wg_y_init = CEIL_DIV(total_wg_init, wg_x_init);
+    std::vector<wgpu::BindGroupEntry> init_entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
+        { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size }
+    };
+
+    pipelines.push_back(argsort_pipeline);
+    params_list.push_back(std::move(init_params));
+    entries_list.push_back(std::move(init_entries));
+    workgroups_list.push_back({ wg_x_init, wg_y_init });
+
+    if (merge_passes == 0) {
+        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
+                                               entries_list, workgroups_list);
+    }
+
+    bool     in_is_tmp = start_in_tmp;
+    uint32_t len       = block_size;
+    while (len < out_ne0) {
+        const uint32_t nm = CEIL_DIV(out_ne0, 2 * len);
+
+        const bool     out_is_tmp  = !in_is_tmp;
+        const uint32_t offset_in   = in_is_tmp ? offset_tmp : offset_dst;
+        const uint32_t offset_out  = out_is_tmp ? offset_tmp : offset_dst;
+        const size_t   align_in    = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+        const size_t   align_out   = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst);
+        const size_t   size_in     = in_is_tmp ? tmp_binding_size : dst_binding_size;
+        const size_t   size_out    = out_is_tmp ? tmp_binding_size : dst_binding_size;
+        const uint32_t top_k_out   = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0;
+        const uint32_t stride_out1 = top_k_out;
+        const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1];
+        const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2];
+
+        std::vector<uint32_t> merge_params = { offset_src,
+                                               offset_in,
+                                               offset_out,
+                                               stride_src1,
+                                               stride_src2,
+                                               stride_src3,
+                                               stride_idx1,
+                                               stride_idx2,
+                                               stride_idx3,
+                                               stride_out1,
+                                               stride_out2,
+                                               stride_out3,
+                                               out_ne0,
+                                               (uint32_t) src->ne[1],
+                                               (uint32_t) src->ne[2],
+                                               top_k_out,
+                                               len,
+                                               nm,
+                                               nrows };
+
+        std::vector<wgpu::BindGroupEntry> merge_entries = {
+            { .binding = 0,
+             .buffer  = ggml_webgpu_tensor_buf(src),
+             .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+             .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
+            { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in },
+            { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out }
+        };
+
+        const uint32_t total_wg_merge = nm * nrows;
+        const uint32_t wg_x_merge     = std::min(total_wg_merge, max_wg);
+        const uint32_t wg_y_merge     = CEIL_DIV(total_wg_merge, wg_x_merge);
+        workgroups_list.push_back({ wg_x_merge, wg_y_merge });
+        pipelines.push_back(argsort_merge_pipeline);
+        params_list.push_back(std::move(merge_params));
+        entries_list.push_back(std::move(merge_entries));
+
+        len <<= 1;
+        in_is_tmp = !in_is_tmp;
+    }
+
+    return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list,
+                                           workgroups_list);
+}
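
To make the ping-pong bookkeeping above concrete: blocks of block_size sorted elements are merged pairwise, the run length doubling each pass, and start_in_tmp is chosen so that the final pass writes into dst. A tiny worked example (values picked purely for illustration):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // e.g. a row of 4096 elements, initially sorted in blocks of 512
        uint32_t out_ne0 = 4096, block_size = 512;
        uint32_t merge_len = block_size, merge_passes = 0;
        while (merge_len < out_ne0) {  // 512 -> 1024 -> 2048 -> 4096
            merge_len <<= 1;
            merge_passes++;            // 3 passes total
        }
        // odd pass count: the block sort writes to tmp, so pass 3 ends in dst
        bool start_in_tmp = (merge_passes % 2) == 1;
        std::printf("passes=%u start_in_tmp=%d\n", merge_passes, start_in_tmp);
        return 0;
    }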
+
+static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     (uint32_t) src->ne[0] };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+    };
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src,
+        .src1        = nullptr,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx);
+    uint32_t        wg_x     = ggml_nrows(dst);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
+static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
+    bool                  total_sum = dst->op == GGML_OP_SUM;
+    std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
+                                     (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+                                     total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
+                                     total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
+                                     total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
+                                     total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0],
+                                     total_sum ? 1 : (uint32_t) src->ne[1],
+                                     total_sum ? 1 : (uint32_t) src->ne[2] };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(src),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, src),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, src) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(dst),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) }
+    };
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx);
+
+    uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
 // Returns the encoded command, or std::nullopt if the operation is a no-op
 static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
         return std::nullopt;
     }
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return std::nullopt;
+    }
     WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
 
     ggml_tensor * src0 = node->src[0];
@@ -1578,27 +2188,20 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_MUL_MAT:
             return ggml_webgpu_mul_mat(ctx, src0, src1, node);
         case GGML_OP_FLASH_ATTN_EXT:
+#ifndef __EMSCRIPTEN__
             return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
+#else
+            return std::nullopt;
+#endif
         case GGML_OP_ADD:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
-            }
         case GGML_OP_SUB:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
-            }
         case GGML_OP_MUL:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
-            }
         case GGML_OP_DIV:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
-            }
+            return ggml_webgpu_binary_op(ctx, src0, src1, node);
+        case GGML_OP_CONCAT:
+            return ggml_webgpu_concat(ctx, src0, src1, node);
+        case GGML_OP_REPEAT:
+            return ggml_webgpu_repeat(ctx, src0, node);
         case GGML_OP_RMS_NORM:
             return ggml_webgpu_rms_norm(ctx, src0, node);
         case GGML_OP_ROPE:
@@ -1610,7 +2213,27 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_SOFT_MAX:
             return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
         case GGML_OP_UNARY:
+        case GGML_OP_CLAMP:
+        case GGML_OP_FILL:
+        case GGML_OP_LOG:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
             return ggml_webgpu_unary_op(ctx, src0, node);
+        case GGML_OP_PAD:
+            return ggml_webgpu_pad(ctx, src0, node);
+        case GGML_OP_ARGMAX:
+            return ggml_webgpu_argmax(ctx, src0, node);
+        case GGML_OP_ARGSORT:
+        case GGML_OP_TOP_K:
+            // we reuse the same argsort implementation for top_k
+            return ggml_webgpu_argsort(ctx, src0, node);
+        case GGML_OP_CUMSUM:
+            return ggml_webgpu_cumsum(ctx, src0, node);
+        case GGML_OP_SUM:
+        case GGML_OP_SUM_ROWS:
+            return ggml_webgpu_sum_rows(ctx, src0, node);
         default:
             return std::nullopt;
     }
@@ -1619,39 +2242,57 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx,
 static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
 
-    ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
+    ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context;
     webgpu_context                ctx         = backend_ctx->webgpu_ctx;
 
     WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
 
-    ctx->inflight_threads++;
+    std::vector<webgpu_command>    commands;
+    std::vector subs;
+    uint32_t                       num_batched_kernels = 0;
+    bool                           contains_set_rows   = false;
 
-    std::vector<webgpu_command>            commands;
-    std::vector<webgpu_submission_futures> futures;
     for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
+            contains_set_rows = true;
+        }
         if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
             commands.push_back(*cmd);
+            num_batched_kernels += cmd.value().num_kernels;
         }
-        // compute the batch size based on the number of inflight threads
-        uint32_t inflight_threads = ctx->inflight_threads;
-        uint32_t batch_size       = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
-                                             WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
-        if (commands.size() >= batch_size) {
-            futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
+
+        if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
+            num_batched_kernels = 0;
+            subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
             // Process events and check for completed submissions
-            ctx->instance.ProcessEvents();
-            ggml_backend_webgpu_wait(ctx, futures, false);
+            ctx->global_ctx->instance.ProcessEvents();
+            ggml_backend_webgpu_wait(ctx->global_ctx, subs, false);
             commands.clear();
         }
     }
     if (!commands.empty()) {
-        webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
-        futures.push_back(new_futures);
+        subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool));
+        commands.clear();
     }
 
-    ggml_backend_webgpu_wait(ctx, futures);
-    ctx->inflight_threads--;
-    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
+    // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking.
+    if (contains_set_rows) {
+        wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder();
+        encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0,
+                                   ctx->set_rows_host_error_buf.GetSize());
+        wgpu::CommandBuffer set_rows_commands = encoder.Finish();
+        ctx->global_ctx->queue.Submit(1, &set_rows_commands);
+        ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0,
+                                       ctx->set_rows_host_error_buf.GetSize());
+        const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange();
+        if (*error_data) {
+            GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
+        }
+        ctx->set_rows_host_error_buf.Unmap();
+    }
+
+    ggml_backend_webgpu_wait(ctx->global_ctx, subs);
+    WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx);
     return GGML_STATUS_SUCCESS;
 }
 
@@ -1678,7 +2319,10 @@ static ggml_backend_i ggml_backend_webgpu_i = {
 
 static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
-    ctx->buffer.Destroy();
+    if (ctx != nullptr && ctx->buffer != nullptr) {
+        ctx->buffer.Destroy();
+        delete ctx;
+    }
 }
 
 // Returns the "fake" base pointer.
@@ -1693,7 +2337,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
                                                      size_t                offset,
                                                      size_t                size) {
     if (size == 0) {
-        WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
+        WEBGPU_LOG_DEBUG(
+            "ggml_backend_webgpu_buffer_memset_tensor: size is zero, "
+            "nothing to do.");
         return;
     }
 
@@ -1708,8 +2354,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
 
     // This is a trick to set all bytes of a u32 to the same 1 byte value.
     uint32_t val32 = (uint32_t) value * 0x01010101;
-    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
-    WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx);
+    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size);
+    WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx);
 }
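
The 0x01010101 multiplication used above is the usual byte-broadcast trick: it replicates one byte into all four bytes of a u32, which is the granularity the GPU-side memset operates at. A one-line check:

    #include <cassert>
    #include <cstdint>

    int main() {
        uint8_t  value = 0xAB;
        uint32_t val32 = (uint32_t) value * 0x01010101u;  // broadcast byte to all four lanes
        assert(val32 == 0xABABABABu);
        return 0;
    }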
 
 static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
@@ -1718,15 +2364,14 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
                                                   size_t                offset,
                                                   size_t                size) {
     WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
-    ggml_backend_webgpu_buffer_context * buf_ctx    = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    webgpu_context                       webgpu_ctx = buf_ctx->webgpu_ctx;
+    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
 
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                               << ", " << offset << ", " << size << ")");
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
-    webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
+    buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
 
     if (size % 4 != 0) {
         // If size is not a multiple of 4, we need to memset the remaining bytes
@@ -1739,21 +2384,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
             ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
         }
         // memset the remaining bytes
-        ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
-                                          remaining_size);
+        ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32,
+                                          total_offset + (size - remaining_size), remaining_size);
     } else {
         // wait for WriteBuffer to complete
-        webgpu_ctx->instance.WaitAny(
-            webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
+        buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone(
+                                                  wgpu::CallbackMode::AllowSpontaneous,
                                                   [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
                                                       if (status != wgpu::QueueWorkDoneStatus::Success) {
                                                           GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
                                                                          std::string(message).c_str());
                                                       }
                                                   }),
-            UINT64_MAX);
+                                              UINT64_MAX);
     }
-    WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx);
+    WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx);
 }
 
 static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
@@ -1765,53 +2410,56 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
                                                               << ", " << offset << ", " << size << ")");
-    webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
-    wgpu::Device   device     = webgpu_ctx->device;
+    wgpu::Device device = buf_ctx->global_ctx->device;
 
     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
 
     size_t final_size = size;
     if (size % 4 != 0) {
-        // If size is not a multiple of 4, we need to round it up to the next multiple of 4
+        // If size is not a multiple of 4, we need to round it up to the next
+        // multiple of 4
         final_size = size + (4 - (size % 4));
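+        // e.g. size = 10 -> final_size = 12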
     }
 
-    std::lock_guard lock(webgpu_ctx->mutex);
+    std::lock_guard lock(buf_ctx->global_ctx->mutex);
 
-    if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
+    if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr ||
+        buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) {
         // Create a new staging buffer if it doesn't exist or is too small
-        if (webgpu_ctx->get_tensor_staging_buf) {
-            webgpu_ctx->get_tensor_staging_buf.Destroy();
+        if (buf_ctx->global_ctx->get_tensor_staging_buf) {
+            buf_ctx->global_ctx->get_tensor_staging_buf.Destroy();
         }
-        ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
+        ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size,
                                   wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
     }
 
     // Copy the data from the buffer to the staging buffer
     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
-    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
+    encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0,
+                               final_size);
     wgpu::CommandBuffer commands = encoder.Finish();
 
     // Submit the command buffer to the queue
-    webgpu_ctx->queue.Submit(1, &commands);
+    buf_ctx->global_ctx->queue.Submit(1, &commands);
 
     // Map the staging buffer to read the data
-    ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
+    ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf,
+                                   wgpu::MapMode::Read, 0, final_size);
     // Must specify size here since the staging buffer might be larger than the tensor size
-    const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
+    const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
 
     // Copy the data from the mapped range to the output buffer
     std::memcpy(data, mapped_range, size);
-    webgpu_ctx->get_tensor_staging_buf.Unmap();
-    WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx);
+    buf_ctx->global_ctx->get_tensor_staging_buf.Unmap();
+    WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx);
 }
 
 static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
     WEBGPU_CPU_PROFILE_TOTAL_START(clear);
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
-    WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx);
+    ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size);
+    WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx);
 }
 
 static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
@@ -1823,7 +2471,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
     /* .get_tensor      = */ ggml_backend_webgpu_buffer_get_tensor,
     /* .cpy_tensor      = */ NULL,  // TODO: optional, implement this
     /* .clear           = */ ggml_backend_webgpu_buffer_clear,
-    /* .reset           = */ NULL,  // TODO: optional, think it coordinates with .init_tensor
+    /* .reset           = */ NULL,  // TODO: optional, think it coordinates with
+                                    // .init_tensor
 };
 
 /* End GGML Backend Buffer Interface */
@@ -1841,28 +2490,57 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b
     int                     buffer_id = buffer_count++;
     std::string             buf_name  = "tensor_buf" + std::to_string(buffer_id);
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
-    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
 
-    wgpu::Buffer buf;
-    ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
+    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+    wgpu::Buffer                         buf;
+    ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
                               wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
                               buf_name.c_str());
 
     ggml_backend_webgpu_buffer_context * buf_ctx =
-        new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
+        new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx);
 
     return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
 }
 
 static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
-    return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
+    ggml_backend_webgpu_device_context * dev_ctx =
+        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+    return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
 }
 
-// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
+// maxBufferSize might be larger, but you can't bind more than
+// maxStorageBufferBindingSize to a single binding.
 static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    ggml_backend_webgpu_device_context * dev_ctx =
+        static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
+    return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize;
+}
+
+static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
+                                                             const ggml_tensor *        tensor) {
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
-    return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
+    size_t                               res = ggml_nbytes(tensor);
+    switch (tensor->op) {
+        case GGML_OP_ARGSORT:
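+            // reserve twice the tensor size plus one alignment step, presumably so a scratch region can be bound at an aligned offset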
+            res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
+                               WEBGPU_STORAGE_BUF_BINDING_MULT);
+            break;
+        case GGML_OP_TOP_K:
+            {
+                const ggml_tensor * src0 = tensor->src[0];
+                if (src0) {
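+                    // one i32 per src0 element, doubled plus one alignment step as in ARGSORT above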
+                    const size_t full = sizeof(int32_t) * ggml_nelements(src0);
+                    res               = ROUNDUP_POW2(
+                        full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment,
+                        WEBGPU_STORAGE_BUF_BINDING_MULT);
+                }
+            }
+            break;
+        default:
+            break;
+    }
+    return res;
 }
 
 /* End GGML Backend Buffer Type Interface */
@@ -1883,7 +2561,7 @@ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
     // TODO: for now, return maxBufferSize as both free and total memory
     // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates.
-    uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize;
+    uint64_t                             max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize;
     // If we're on a 32-bit system, clamp to UINTPTR_MAX
 #if UINTPTR_MAX < UINT64_MAX
     uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX);
@@ -1918,329 +2596,64 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
     return reinterpret_cast<ggml_guid_t>((void *) guid_str);
 }
 
-// Workgroup size is a common constant
-static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
-    std::vector<wgpu::ConstantEntry> constants(1);
-    constants[0].key   = "wg_size";
-    constants[0].value = wg_size;
-    return constants;
-}
-
-static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
+static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
     // we use the maximum workgroup size for the memset pipeline
-    size_t max_threads                  = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
+    size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
     // Size the bytes_per_thread so that the largest buffer size can be handled
-    webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads);
+    ctx->capabilities.memset_bytes_per_thread =
+        CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads);
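+    // so max_threads * bytes_per_thread covers the largest bindable buffer in a single dispatch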
     std::vector<wgpu::ConstantEntry> constants(2);
-    constants[0].key                = "wg_size";
-    constants[0].value              = WEBGPU_MAX_WG_SIZE;
-    constants[1].key                = "bytes_per_thread";
-    constants[1].value              = webgpu_ctx->memset_bytes_per_thread;
-    webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants);
-}
-
-static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
-    // Q4/Q5/Q8 classic quantizations
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
-
-    // K-quantizations
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
-
-    // IQ quantizations (2-, 3-, 4-bit variants)
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
-
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
-
-    // 1-bit and 4-bit IQ variants
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
-
-    std::string proc_mul_mat_f32_f32;
-    std::string proc_mul_mat_f32_f32_vec;
-    std::string proc_mul_mat_f16_f32;
-    std::string proc_mul_mat_f16_f32_vec;
-    std::string proc_mul_mat_f16_f16;
-    std::string proc_mul_mat_f16_f16_vec;
-    std::string proc_mul_mat_q4_0_f32;
-    std::string proc_mul_mat_q4_0_f32_vec;
-
-    std::vector<wgpu::ConstantEntry> mul_mat_constants;
-#ifndef __EMSCRIPTEN__
-    if (webgpu_ctx->supports_subgroup_matrix) {
-        std::map<std::string, std::string> sg_matrix_repls;
-        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
-        sg_matrix_repls["WEBGPU_TILE_K"]            = std::to_string(WEBGPU_MUL_MAT_TILE_K);
-        sg_matrix_repls["WEBGPU_SUBGROUP_M"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
-        sg_matrix_repls["WEBGPU_SUBGROUP_N"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
-        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
-        sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
-        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_m);
-        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_n);
-        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_k);
-
-        proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
-        proc_mul_mat_f32_f32_vec =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
-        proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
-        proc_mul_mat_f16_f32_vec =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
-        proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
-        proc_mul_mat_f16_f16_vec =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
-        proc_mul_mat_q4_0_f32 =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
-        proc_mul_mat_q4_0_f32_vec =
-            ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
-    } else {
-#endif
-        mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
-        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
-        mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
-
-        std::map<std::string, std::string> reg_repls;
-        reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
-        reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
-
-        proc_mul_mat_f32_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
-        proc_mul_mat_f32_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
-        proc_mul_mat_f16_f32      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
-        proc_mul_mat_f16_f32_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
-        proc_mul_mat_f16_f16      = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
-        proc_mul_mat_f16_f16_vec  = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
-        proc_mul_mat_q4_0_f32     = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
-        proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
-#ifndef __EMSCRIPTEN__
-    }
-#endif
-
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
-    webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
-
-    std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
-    mul_mat_vec_constants[0].key   = "WORKGROUP_SIZE";
-    mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
-    mul_mat_vec_constants[1].key   = "TILE_K";
-    mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
-    mul_mat_vec_constants[2].key   = "OUTPUTS_PER_WG";
-    mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
-
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
-    webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
-}
-
-static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
-    webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
-    webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
-}
-
-static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
-
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
-
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
-
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
-    webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
+    constants[0].key         = "wg_size";
+    constants[0].value       = WEBGPU_MAX_WG_SIZE;
+    constants[1].key         = "bytes_per_thread";
+    constants[1].value       = ctx->capabilities.memset_bytes_per_thread;
+    ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
 }
 
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+    webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
     webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
-}
-
-static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants);
-    webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants);
-    webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
-    webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants);
-    webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants);
-    webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
-    webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants);
-    webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants);
-    webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
-    webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants);
-    webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants);
-    webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
-    webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
 
     webgpu_ctx->rms_norm_pipelines[0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants);
-    webgpu_ctx->rms_norm_pipelines[1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants);
+    webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
 }
 
 static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
 
     webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
+    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
     webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
+    webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
 
     webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
+    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
     webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
-    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
+    webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
 }
 
 static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
@@ -2248,242 +2661,59 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
 
     // REGLU
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
 
     // GEGLU
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
 
     // SWIGLU
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
 
     // SWIGLU_OAI
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
 
     // GEGLU_ERF
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
 
     // GEGLU_QUICK
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
     webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
-    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
-}
-
-static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    // ABS
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants);
-
-    // SGN
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants);
-
-    // NEG
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f16, "neg_inplace_f16", constants);
-
-    // STEP
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants);
-
-    // TANH
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants);
-
-    // ELU
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants);
-
-    // RELU
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants);
-
-    // SIGMOID
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants);
-
-    // GELU
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants);
-
-    // GELU_QUICK
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants);
-
-    // SILU
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants);
-
-    // HARDSWISH
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants);
-
-    // HARDSIGMOID
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants);
-
-    // EXP
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants);
-
-    // GELU_ERF
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants);
-
-    // XIELU
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants);
-
-    // CEIL
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants);
-    webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants);
-}
-
-static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->scale_pipelines[0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants);
-    webgpu_ctx->scale_pipelines[1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
+    webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
 }
 
 static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
@@ -2491,56 +2721,239 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
 
     // f32 (no mask)
     webgpu_ctx->soft_max_pipelines[2][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
-    webgpu_ctx->soft_max_pipelines[2][0][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
-    webgpu_ctx->soft_max_pipelines[2][1][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
+    webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
+    webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
     webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
 
     // f32 mask (mask_type = 0)
-    webgpu_ctx->soft_max_pipelines[0][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
+    webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
     webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
     webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
-    webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
+    webgpu_ctx->soft_max_pipelines[0][1][1] =
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
+                                    "soft_max_f32_mask_f32_sink_inplace", constants);
 
     // f16 mask (mask_type = 1)
-    webgpu_ctx->soft_max_pipelines[1][0][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
+    webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
     webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
     webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
-    webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline(
-        webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
+        webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
+    webgpu_ctx->soft_max_pipelines[1][1][1] =
+        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
+                                    "soft_max_f32_mask_f16_sink_inplace", constants);
 }
 
-// TODO: move most initialization logic here
-static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
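+// Acquires the process-wide WebGPU adapter and device for the registry-owned global context:
+// queries limits and features (f16 is required, subgroup-matrix support is optional), creates the
+// device, and sets up the shared memset pipeline, buffer pools, and queue.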
+static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
+    wgpu::RequestAdapterOptions options = {};
+
+#ifndef __EMSCRIPTEN__
+    // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
+    const char * const          adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
+    wgpu::DawnTogglesDescriptor adapterTogglesDesc;
+    adapterTogglesDesc.enabledToggles     = adapterEnabledToggles;
+    adapterTogglesDesc.enabledToggleCount = 2;
+    options.nextInChain                   = &adapterTogglesDesc;
+#endif
+
+    ctx->webgpu_global_ctx->instance.WaitAny(
+        ctx->webgpu_global_ctx->instance.RequestAdapter(
+            &options, wgpu::CallbackMode::AllowSpontaneous,
+            [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
+                if (status != wgpu::RequestAdapterStatus::Success) {
+                    GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
+                    return;
+                }
+                ctx->webgpu_global_ctx->adapter = std::move(adapter);
+            }),
+        UINT64_MAX);
+    GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr);
+
+    ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits);
+
+    wgpu::AdapterInfo info{};
+#ifndef __EMSCRIPTEN__
+    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
+    if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+        info.nextInChain = &subgroup_matrix_configs;
+    }
+#endif
+    ctx->webgpu_global_ctx->adapter.GetInfo(&info);
+    wgpu::SupportedFeatures features;
+    ctx->webgpu_global_ctx->adapter.GetFeatures(&features);
+    // we require f16 support
+    GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
+
+#ifndef __EMSCRIPTEN__
+    // Only support square f16 matrices of size 8 or 16 for now
+    bool valid_subgroup_matrix_config = false;
+    if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
+        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
+            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
+            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
+                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
+                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
+                ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M;
+                ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N;
+                ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K;
+                valid_subgroup_matrix_config                  = true;
+                break;
+            }
+        }
+    }
+    ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;
+#endif
+
+    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
+    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
+    ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
+    // Initialize device
+    std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
+
+#ifndef __EMSCRIPTEN__
+    required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
+    if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
+        required_features.push_back(wgpu::FeatureName::Subgroups);
+        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
+    }
+#endif
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    required_features.push_back(wgpu::FeatureName::TimestampQuery);
+#endif
+
+    wgpu::DeviceDescriptor dev_desc;
+    dev_desc.requiredLimits       = &ctx->webgpu_global_ctx->capabilities.limits;
+    dev_desc.requiredFeatures     = required_features.data();
+    dev_desc.requiredFeatureCount = required_features.size();
+    dev_desc.SetDeviceLostCallback(
+        wgpu::CallbackMode::AllowSpontaneous,
+        [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
+            if (reason == wgpu::DeviceLostReason::Destroyed) {
+                return;
+            }
+            GGML_UNUSED(device);
+            GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
+                           std::string(message).c_str());
+        });
+    dev_desc.SetUncapturedErrorCallback(
+        [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
+            GGML_UNUSED(device);
+            GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
+                       std::string(message).c_str());
+        });
+
+#ifndef __EMSCRIPTEN__
+    // Enable Dawn-specific toggles to increase native performance
+    // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
+    //       only for native performance?
+    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
+                                                   "disable_polyfills_on_integer_div_and_mod" };
+    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
+    wgpu::DawnTogglesDescriptor deviceTogglesDesc;
+    deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;
+    deviceTogglesDesc.enabledToggleCount  = 4;
+    deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;
+    deviceTogglesDesc.disabledToggleCount = 1;
+
+    dev_desc.nextInChain = &deviceTogglesDesc;
+#endif
+
+    ctx->webgpu_global_ctx->instance.WaitAny(
+        ctx->webgpu_global_ctx->adapter.RequestDevice(
+            &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
+            [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
+                if (status != wgpu::RequestDeviceStatus::Success) {
+                    GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str());
+                    return;
+                }
+                ctx->webgpu_global_ctx->device = std::move(device);
+            }),
+        UINT64_MAX);
+    GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr);
+
+    ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx);
+    ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES,
+                                                 wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+                                                 wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
+    ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue();
+
+#ifdef GGML_WEBGPU_GPU_PROFILE
+    // Initialize buffer pool for timestamp queries, used for profiling
+    ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(
+        ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
+        wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
+        wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
+#endif
+
+    GGML_LOG_INFO(
+        "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
+        "device_desc: %s\n",
+        info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
+        std::string(info.device).c_str(), std::string(info.description).c_str());
+    return true;
+}
+
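+// Builds a per-backend context: it shares the global adapter/device but owns its own shader
+// library, parameter buffer pool, SET_ROWS error buffers, and the pipelines initialized below.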
+static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
+    ggml_backend_webgpu_device_context * dev_ctx    = (ggml_backend_webgpu_device_context *) dev->context;
+    webgpu_context                       webgpu_ctx = std::make_shared<webgpu_context_struct>();
+    webgpu_ctx->global_ctx                          = dev_ctx->webgpu_global_ctx;
+    webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device);
+    webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
+                                    wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
+                                    wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true);
+    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf,
+                              WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf");
+    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf,
+                              WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
+                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
+
+    ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
+    ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
+    ggml_webgpu_init_rope_pipeline(webgpu_ctx);
+    ggml_webgpu_init_glu_pipeline(webgpu_ctx);
+    ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
+#ifdef GGML_WEBGPU_DEBUG
+    // Initialize debug buffers
+    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,
+                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
+    ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf,
+                              WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
+                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
+#endif
+    return webgpu_ctx;
+}
+
+static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
 
-    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
+    WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()");
 
-    ggml_backend_webgpu_device_context * dev_ctx    = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
-    webgpu_context                       webgpu_ctx = dev_ctx->webgpu_ctx;
+    ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
 
-    static ggml_backend_webgpu_context backend_ctx;
-    backend_ctx.name       = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
-    backend_ctx.webgpu_ctx = webgpu_ctx;
+    auto * backend_ctx      = new ggml_backend_webgpu_context();
+    backend_ctx->name       = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
+    backend_ctx->webgpu_ctx = initialize_webgpu_context(dev);
 
     // See GGML Backend Interface section
-    static ggml_backend backend = {
+    auto * backend = new ggml_backend();
+    *backend       = {
         /* .guid      = */ ggml_backend_webgpu_guid(),
         /* .interface = */ ggml_backend_webgpu_i,
         /* .device    = */ dev,
-        /* .context   = */ &backend_ctx,
+        /* .context   = */ backend_ctx,
     };
-    return &backend;
+    return backend;
 }
 
 static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -2549,15 +2962,16 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm
     static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
         /* .iface = */ {
                         /* .get_name         = */ ggml_backend_webgpu_buffer_type_get_name,
-                        /* .alloc_buffer     = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
-                        /* .get_alignment    = */ ggml_backend_webgpu_buffer_type_get_alignment,
-                        /* .get_max_size     = */ ggml_backend_webgpu_buffer_type_get_max_size,
-                        /* .get_alloc_size   = */ NULL,  // defaults to ggml_nbytes
-            /* .is_host          = */ NULL,  // defaults to false
+                        /* .alloc_buffer     = */
+            ggml_backend_webgpu_buffer_type_alloc_buffer,                                    /* .get_alignment    = */
+            ggml_backend_webgpu_buffer_type_get_alignment,                                   /* .get_max_size     = */
+            ggml_backend_webgpu_buffer_type_get_max_size,                                    /* .get_alloc_size   = */
+            ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host          = */ NULL,  // defaults to false
         },
         /* .device  = */
         dev,
-        /* .context = */ NULL,
+        /* .context = */
+        NULL
     };
 
     return &ggml_backend_webgpu_buffer_type;
@@ -2598,16 +3012,16 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) {
 static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
 
-    webgpu_context webgpu_ctx = ctx->webgpu_ctx;
-
     ggml_tensor * src0 = op->src[0];
     ggml_tensor * src1 = op->src[1];
     ggml_tensor * src2 = op->src[2];
 
     // on smaller devices (or CI), tensors may be larger than the max storage buffer size
-    if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
-        (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
-        (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
+    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
+        (src0 != nullptr &&
+         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
+        (src1 != nullptr &&
+         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
         return false;
     }
 
@@ -2624,23 +3038,30 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-            // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
-            // see https://github.com/ggml-org/llama.cpp/pull/16857
             supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
-                          (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+                          (src1->type == op->type);
+            break;
+        case GGML_OP_CONCAT:
+            supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
+            break;
+        case GGML_OP_REPEAT:
+            supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16);
             break;
         case GGML_OP_CPY:
         case GGML_OP_CONT:
-            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
-                          (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+            supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                           (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
+                          (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
             break;
         case GGML_OP_SET_ROWS:
-            supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
+            supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
+                           (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
             break;
         case GGML_OP_GET_ROWS:
-            if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
-                ggml_webgpu_supported_qtype(src0->type)) {
+            if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) {
                 supports_op = (op->type == GGML_TYPE_F32);
+            } else if (src0->type == GGML_TYPE_I32) {
+                supports_op = op->type == GGML_TYPE_I32;
             }
             break;
         case GGML_OP_MUL_MAT:
@@ -2684,17 +3105,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             }
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                if (!webgpu_ctx->supports_subgroup_matrix) {
+#ifndef __EMSCRIPTEN__
+                if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) {
                     break;
                 }
                 // Head dimensions must fit in workgroup memory with minimum tile sizes
-                size_t     limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
+                size_t     limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
                 const bool has_mask    = op->src[3] != nullptr;
-                const bool kv_direct   = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
+                const bool kv_direct   = src1->type == GGML_TYPE_F16 &&
+                                       (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 &&
                                        (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
                 const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
-                    webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
-                    has_mask, kv_direct);
+                    ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n,
+                    (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct);
                 if (min_bytes > limit_bytes) {
                     break;
                 }
@@ -2703,6 +3126,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                               (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
                                src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
                               src2->type == src1->type && op->type == GGML_TYPE_F32;
+#endif
                 break;
             }
         case GGML_OP_RMS_NORM:
@@ -2753,9 +3177,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                     case GGML_UNARY_OP_HARDSIGMOID:
                     case GGML_UNARY_OP_EXP:
                     case GGML_UNARY_OP_GELU_ERF:
-                    case GGML_UNARY_OP_XIELU:
+                    case GGML_UNARY_OP_SOFTPLUS:
+                    case GGML_UNARY_OP_EXPM1:
+                    case GGML_UNARY_OP_FLOOR:
                     case GGML_UNARY_OP_CEIL:
-                        supports_op = supports_op =
+                    case GGML_UNARY_OP_ROUND:
+                    case GGML_UNARY_OP_TRUNC:
+                    case GGML_UNARY_OP_XIELU:
+                        supports_op =
                             (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
                         break;
                     default:
@@ -2763,14 +3192,56 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                 }
             }
             break;
-
+        case GGML_OP_CLAMP:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_FILL:
+            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+            break;
+        case GGML_OP_LOG:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_SQR:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_SQRT:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_SIN:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_COS:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
+            break;
+        case GGML_OP_PAD:
+            supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+            break;
+        case GGML_OP_ARGMAX:
+            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32;
+            break;
+        case GGML_OP_ARGSORT:
+            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
+            break;
+        case GGML_OP_TOP_K:
+            supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0);
+            break;
+        case GGML_OP_CUMSUM:
+            supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type;
+            break;
+        case GGML_OP_SUM:
+        case GGML_OP_SUM_ROWS:
+            supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0);
+            break;
         default:
             break;
     }
-    if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
-        (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
-        (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
-        (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
+    if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize ||
+        (src0 != nullptr &&
+         ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
+        (src1 != nullptr &&
+         ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) ||
+        (src2 != nullptr &&
+         ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) {
         supports_op = false;
         WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
     }
@@ -2795,7 +3266,7 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
     /* .get_memory           = */ ggml_backend_webgpu_device_get_memory,
     /* .get_type             = */ ggml_backend_webgpu_device_get_type,
     /* .get_props            = */ ggml_backend_webgpu_device_get_props,
-    /* .init_backend         = */ ggml_backend_webgpu_device_init,
+    /* .init_backend         = */ ggml_backend_webgpu_backend_init,
     /* .get_buffer_type      = */ ggml_backend_webgpu_device_get_buffer_type,
     /* .get_host_buffer_type = */ NULL,
     /* .buffer_from_host_ptr = */ NULL,
@@ -2821,8 +3292,6 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
     return ctx->device_count;
 }
 
-// TODO: Does this need to be thread safe? Is it only called once?
-// TODO: move most logic to device_init function so backend can be freed/initialized properly
 // Only one device is supported for now
 static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
     GGML_ASSERT(index == 0);
@@ -2832,191 +3301,12 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
 
     ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
 
-    webgpu_context ctx = reg_ctx->webgpu_ctx;
-
-    wgpu::RequestAdapterOptions options = {};
-
-#ifndef __EMSCRIPTEN__
-    // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
-    const char * const          adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
-    wgpu::DawnTogglesDescriptor adapterTogglesDesc;
-    adapterTogglesDesc.enabledToggles     = adapterEnabledToggles;
-    adapterTogglesDesc.enabledToggleCount = 2;
-    options.nextInChain                   = &adapterTogglesDesc;
-#endif
-
-    ctx->instance.WaitAny(ctx->instance.RequestAdapter(
-                              &options, wgpu::CallbackMode::AllowSpontaneous,
-                              [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
-                                  if (status != wgpu::RequestAdapterStatus::Success) {
-                                      GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
-                                      return;
-                                  }
-                                  ctx->adapter = std::move(adapter);
-                              }),
-                          UINT64_MAX);
-    GGML_ASSERT(ctx->adapter != nullptr);
-
-    ctx->adapter.GetLimits(&ctx->limits);
-
-    wgpu::AdapterInfo info{};
-#ifndef __EMSCRIPTEN__
-    wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
-    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
-        info.nextInChain = &subgroup_matrix_configs;
-    }
-#endif
-    ctx->adapter.GetInfo(&info);
-
-    wgpu::SupportedFeatures features;
-    ctx->adapter.GetFeatures(&features);
-    // we require f16 support
-    GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
-
-#ifndef __EMSCRIPTEN__
-    // Only support square f16 matrices of size 8 or 16 for now
-    bool valid_subgroup_matrix_config = false;
-    if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
-        for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
-            const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
-            if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
-                config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
-                config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
-                ctx->sg_mat_m                = config.M;
-                ctx->sg_mat_n                = config.N;
-                ctx->sg_mat_k                = config.K;
-                valid_subgroup_matrix_config = true;
-                break;
-            }
-        }
-    }
-
-    ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
-#endif
-    // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
-    // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
-    ctx->max_subgroup_size = info.subgroupMaxSize;
-
-    // Initialize device
-    std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
-
-#ifndef __EMSCRIPTEN__
-    required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
-    if (ctx->supports_subgroup_matrix) {
-        required_features.push_back(wgpu::FeatureName::Subgroups);
-        required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
-    }
-#endif
-
-#ifdef GGML_WEBGPU_GPU_PROFILE
-    required_features.push_back(wgpu::FeatureName::TimestampQuery);
-#endif
-
-    wgpu::DeviceDescriptor dev_desc;
-    dev_desc.requiredLimits       = &ctx->limits;
-    dev_desc.requiredFeatures     = required_features.data();
-    dev_desc.requiredFeatureCount = required_features.size();
-    dev_desc.SetDeviceLostCallback(
-        wgpu::CallbackMode::AllowSpontaneous,
-        [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
-            GGML_UNUSED(device);
-            GGML_UNUSED(reason);
-            GGML_UNUSED(message);
-            //TODO: uncomment once proper free logic is in place
-            //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
-            //std::string(message).c_str());
-        });
-    dev_desc.SetUncapturedErrorCallback(
-        [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
-            GGML_UNUSED(device);
-            GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
-                       std::string(message).c_str());
-        });
-
-#ifndef __EMSCRIPTEN__
-    // Enable Dawn-specific toggles to increase native performance
-    // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
-    //       only for native performance?
-    const char * const deviceEnabledToggles[]  = { "skip_validation", "disable_robustness", "disable_workgroup_init",
-                                                   "disable_polyfills_on_integer_div_and_mod" };
-    const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
-    wgpu::DawnTogglesDescriptor deviceTogglesDesc;
-    deviceTogglesDesc.enabledToggles      = deviceEnabledToggles;
-    deviceTogglesDesc.enabledToggleCount  = 4;
-    deviceTogglesDesc.disabledToggles     = deviceDisabledToggles;
-    deviceTogglesDesc.disabledToggleCount = 1;
-
-    dev_desc.nextInChain = &deviceTogglesDesc;
-#endif
-
-    ctx->instance.WaitAny(ctx->adapter.RequestDevice(
-                              &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
-                              [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
-                                  if (status != wgpu::RequestDeviceStatus::Success) {
-                                      GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
-                                                     std::string(message).c_str());
-                                      return;
-                                  }
-                                  ctx->device = std::move(device);
-                              }),
-                          UINT64_MAX);
-    GGML_ASSERT(ctx->device != nullptr);
-
-    // Initialize (compute) queue
-    ctx->queue = ctx->device.GetQueue();
-
-    // Create buffer pool for shader parameters
-    ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
-                             wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
-                             wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
-
-#ifdef GGML_WEBGPU_GPU_PROFILE
-    // Initialize buffer pool for timestamp queries (profiling)
-    ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
-                                       WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
-                                       wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
-                                       wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
-#endif
-
-    ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
-                                      wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
-                                      wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
-
-    ggml_webgpu_init_memset_pipeline(ctx);
-    ggml_webgpu_init_mul_mat_pipeline(ctx);
-    ggml_webgpu_init_set_rows_pipeline(ctx);
-    ggml_webgpu_init_get_rows_pipeline(ctx);
-    ggml_webgpu_init_cpy_pipeline(ctx);
-    ggml_webgpu_init_add_pipeline(ctx);
-    ggml_webgpu_init_sub_pipeline(ctx);
-    ggml_webgpu_init_mul_pipeline(ctx);
-    ggml_webgpu_init_div_pipeline(ctx);
-    ggml_webgpu_init_rms_norm_pipeline(ctx);
-    ggml_webgpu_init_rope_pipeline(ctx);
-    ggml_webgpu_init_glu_pipeline(ctx);
-    ggml_webgpu_init_scale_pipeline(ctx);
-    ggml_webgpu_init_soft_max_pipeline(ctx);
-    ggml_webgpu_init_unary_pipeline(ctx);
-
-#ifdef GGML_WEBGPU_DEBUG
-    // Initialize debug buffers
-    ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
-                              wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
-    ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
-                              wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
-#endif
+    create_webgpu_device(reg_ctx);
 
     static ggml_backend_webgpu_device_context device_ctx;
-    device_ctx.webgpu_ctx  = ctx;
-    device_ctx.device_name = GGML_WEBGPU_NAME;
-    device_ctx.device_desc = info.description;
-
-    GGML_LOG_INFO(
-        "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
-        "device_desc: %s\n",
-        info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
-        std::string(info.device).c_str(), std::string(info.description).c_str());
-
+    device_ctx.device_name            = GGML_WEBGPU_NAME;
+    device_ctx.device_desc            = GGML_WEBGPU_NAME;
+    device_ctx.webgpu_global_ctx      = reg_ctx->webgpu_global_ctx;
     // See GGML Backend Device Interface section
     static ggml_backend_device device = {
         /* .iface   = */ ggml_backend_webgpu_device_i,
@@ -3024,7 +3314,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
         /* .context = */ &device_ctx,
     };
 
-    WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx);
+    WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx);
     return &device;
 }
 
@@ -3040,10 +3330,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
 ggml_backend_reg_t ggml_backend_webgpu_reg() {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
 
-    webgpu_context webgpu_ctx = std::make_shared();
-
     static ggml_backend_webgpu_reg_context ctx;
-    ctx.webgpu_ctx   = webgpu_ctx;
     ctx.name         = GGML_WEBGPU_NAME;
     ctx.device_count = 1;
 
@@ -3060,15 +3347,17 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
     instance_descriptor.nextInChain        = &instanceTogglesDesc;
 #endif
 
-    webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
+    wgpu::Instance inst             = wgpu::CreateInstance(&instance_descriptor);
+    ctx.webgpu_global_ctx           = webgpu_global_context(new webgpu_global_context_struct());
+    ctx.webgpu_global_ctx->instance = std::move(inst);
 
 #ifdef __EMSCRIPTEN__
-    if (webgpu_ctx->instance == nullptr) {
+    if (ctx.webgpu_global_ctx->instance == nullptr) {
         GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
         return nullptr;
     }
 #endif
-    GGML_ASSERT(webgpu_ctx->instance != nullptr);
+    GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr);
 
     static ggml_backend_reg reg = {
         /* .api_version = */ GGML_BACKEND_API_VERSION,
@@ -3081,7 +3370,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() {
 ggml_backend_t ggml_backend_webgpu_init(void) {
     ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
 
-    return ggml_backend_webgpu_device_init(dev, nullptr);
+    return ggml_backend_webgpu_backend_init(dev, nullptr);
 }
 
 GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
new file mode 100644
index 00000000..ca5bfcc4
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl
@@ -0,0 +1,72 @@
+@group(0) @binding(0)
+#ifdef VEC4
+var<storage, read_write> src: array<vec4<f32>>;
+#define VEC_SIZE 4
+#else
+var<storage, read_write> src: array<f32>;
+#define VEC_SIZE 1
+#endif
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<i32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+    ne0: u32,
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+const FLOAT_MIN: f32 = -1.0e9;
+
+struct Pair {
+    value: f32,
+    index: i32
+};
+
+var<workgroup> shared_max: array<Pair, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let row_idx = params.offset_src + wid.x * params.ne0;
+    var local_pair = Pair(FLOAT_MIN, -1);
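+    // Each invocation scans the row with a stride of WG_SIZE, tracking its running (value, index)
+    // maximum; the >= comparison lets the larger index win on ties.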
+#ifdef VEC4
+    for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) {
+        let vec_val = src[row_idx / VEC_SIZE + col];
+        for (var v = 0u; v < VEC_SIZE; v++) {
+            let val = vec_val[v];
+            if (val >= local_pair.value) {
+                local_pair = Pair(val, i32(col * VEC_SIZE + v));
+            }
+        }
+    }
+#else
+    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
+        if (src[row_idx + col] >= local_pair.value) {
+            local_pair = Pair(src[row_idx + col], i32(col));
+        }
+    }
+#endif
+    shared_max[lid.x] = local_pair;
+    workgroupBarrier();
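+    // Tree-reduce the per-invocation maxima in shared memory; on equal values the larger index wins.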
+    var offset: u32 = WG_SIZE >> 1;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            let a = shared_max[lid.x];
+            let b = shared_max[lid.x + offset];
+            if (b.value > a.value) {
+                shared_max[lid.x] = b;
+            } else if (b.value == a.value && b.index > a.index) {
+                shared_max[lid.x] = b;
+            }
+        }
+        workgroupBarrier();
+        offset >>= 1;
+    }
+    if (lid.x == 0u) {
+        dst[params.offset_dst + wid.x] = shared_max[0].index;
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl
new file mode 100644
index 00000000..46ed19fc
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl
@@ -0,0 +1,106 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<i32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // src/dst dimensions
+    src_ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    ne0: u32,
+    top_k: u32,
+
+    npr: u32,   // tiles per row
+    nrows: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+var<workgroup> shmem_idx: array<u32, WG_SIZE>;
+
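+// ORDER == 0 sorts ascending, otherwise descending; the padding sentinel and the swap comparisons
+// flip direction accordingly.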
+#if ORDER == 0
+#define EXTREME_VALUE 1e30
+#define SWAP_COMPARE_UP >
+#define SWAP_COMPARE_DOWN <
+#else
+#define EXTREME_VALUE -1e30
+#define SWAP_COMPARE_UP <
+#define SWAP_COMPARE_DOWN >
+#endif
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(num_workgroups) num_wg: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let linear = wid.x + wid.y * num_wg.x;
+    // guard against overprovisioned workgroups
+    if (linear >= params.npr * params.nrows) {
+        return;
+    }
+    let tile = linear % params.npr;
+    var row = linear / params.npr;
+    let i3 = row / (params.ne2 * params.ne1);
+    row = row % (params.ne2 * params.ne1);
+    let i2 = row / params.ne1;
+    let i1 = row % params.ne1;
+
+    let row_base = params.offset_src +
+        i1 * params.stride_src1 +
+        i2 * params.stride_src2 +
+        i3 * params.stride_src3;
+
+    let tile_base = tile * WG_SIZE;
+    let idx = tile_base + lid.x;
+    shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
+    workgroupBarrier();
+
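+    // Bitonic sort of one WG_SIZE-wide tile of indices; out-of-range slots hold the sentinel index
+    // src_ne0, which compares as EXTREME_VALUE so it sinks to the end of the tile.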
+    var k = 2u;
+    while (k <= WG_SIZE) {
+        var j = k >> 1;
+        while (j > 0) {
+            let ixj = lid.x ^ j;
+            if (ixj > lid.x) {
+                let dir_up = (lid.x & k) == 0;
+                let a_idx = shmem_idx[lid.x];
+                let b_idx = shmem_idx[ixj];
+                let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0);
+                let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0);
+                let should_swap = select(
+                    (a_val SWAP_COMPARE_DOWN b_val),
+                    (a_val SWAP_COMPARE_UP b_val),
+                    dir_up);
+                if (should_swap) {
+                    shmem_idx[lid.x] = b_idx;
+                    shmem_idx[ixj] = a_idx;
+                }
+            }
+            workgroupBarrier();
+            j >>= 1;
+        }
+        k <<= 1;
+    }
+
+    let out_idx = tile * params.top_k + lid.x;
+    if (out_idx < params.ne0 && lid.x < params.top_k) {
+        let row_dst = params.offset_dst +
+            i1 * params.stride_dst1 +
+            i2 * params.stride_dst2 +
+            i3 * params.stride_dst3;
+        dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
new file mode 100644
index 00000000..9a77f6ec
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl
@@ -0,0 +1,134 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> idx_in: array<i32>;
+
+@group(0) @binding(2)
+var<storage, read_write> idx_out: array<i32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_in: u32,  // in elements
+    offset_out: u32, // in elements
+
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_idx1: u32,
+    stride_idx2: u32,
+    stride_idx3: u32,
+
+    stride_out1: u32,
+    stride_out2: u32,
+    stride_out3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+
+    top_k: u32,
+
+    len: u32,
+    nm: u32,
+    nrows: u32
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
+    let a_val = src[row_base + u32(a_idx)];
+    let b_val = src[row_base + u32(b_idx)];
+#if ORDER == 0
+    return a_val <= b_val;
+#else
+    return a_val >= b_val;
+#endif
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(num_workgroups) num_wg: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+    let linear = wid.x + wid.y * num_wg.x;
+    // guard against overprovisioned workgroups
+    if (linear >= params.nm * params.nrows) {
+        return;
+    }
+
+    let start = (linear % params.nm) * params.len * 2;
+    let len0 = min(params.len, params.ne0 - start);
+    let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
+    let len1 = min(params.len, rem1);
+    let total = len0 + len1;
+    let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
+    let k0 = lid.x * chunk;
+    let k1 = min(min(k0 + chunk, total), params.top_k);
+    // guard against overprovisioned threads
+    if (k0 >= params.top_k || k0 >= total) {
+        return;
+    }
+
+    var row = linear / params.nm;
+    let i3 = row / (params.ne2 * params.ne1);
+    row = row % (params.ne2 * params.ne1);
+    let i2 = row / params.ne1;
+    let i1 = row % params.ne1;
+
+    let row_src = params.offset_src +
+        i1 * params.stride_src1 +
+        i2 * params.stride_src2 +
+        i3 * params.stride_src3;
+
+    let row_in = params.offset_in +
+        i1 * params.stride_idx1 +
+        i2 * params.stride_idx2 +
+        i3 * params.stride_idx3;
+
+    let row_out = params.offset_out +
+        i1 * params.stride_out1 +
+        i2 * params.stride_out2 +
+        i3 * params.stride_out3;
+
+
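+    // Partition the merge with a binary search: find how many elements of the left sorted run precede
+    // output position k0, then merge output positions [k0, k1) of the two runs sequentially.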
+    var low: u32 = select(0, k0 - len1, k0 > len1);
+    var high: u32 = min(k0, len0);
+
+    while (low < high) {
+        let mid = (low + high) >> 1;
+        let idx0 = idx_in[row_in + start + mid];
+        let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
+        if (take_left(idx0, idx1, row_src)) {
+            low = mid + 1;
+        } else {
+            high = mid;
+        }
+    }
+
+    var i = low;
+    var j = k0 - i;
+    var k = k0;
+    while (k < k1) {
+        var take_l = false;
+        if (i >= len0) {
+            take_l = false;
+        } else if (j >= len1) {
+            take_l = true;
+        } else {
+            let idx0 = idx_in[row_in + start + i];
+            let idx1 = idx_in[row_in + start + params.len + j];
+            take_l = take_left(idx0, idx1, row_src);
+        }
+
+        let out_idx = select(
+            idx_in[row_in + start + params.len + j],
+            idx_in[row_in + start + i],
+            take_l);
+        idx_out[row_out + start + k] = out_idx;
+        i = select(i, i + 1, take_l);
+        j = select(j + 1, j, take_l);
+        k += 1;
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl
deleted file mode 100644
index 1ce4d83f..00000000
--- a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl
+++ /dev/null
@@ -1,188 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "SHADER_NAME": "add_f32",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "+"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "add_f16",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "+"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "add_f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "+"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "add_f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "+"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "mul_f32",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "*"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "mul_f16",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "*"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "mul_f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "*"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "mul_f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "*"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "sub_f32",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "-"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "sub_f16",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "-"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "sub_f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "-"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "sub_f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "-"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "div_f32",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "/"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "div_f16",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "/"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "div_f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "/"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "div_f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "/"
-    },
-    "DECLS": ["INPLACE"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(NOT_INPLACE)
-
-fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
-    dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
-}
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-#enddecl(NOT_INPLACE)
-
-#decl(INPLACE)
-
-fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
-    src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
-}
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-#enddecl(INPLACE)
-
-#end(DECLS)
-
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-DECLS
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x < params.ne) {
-        update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
-    }
-}
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl
new file mode 100644
index 00000000..a748dc1b
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl
@@ -0,0 +1,141 @@
+enable f16;
+
+struct Params {
+    ne: u32,
+
+    // offsets in elements
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    offset_merged_src0: u32,
+    offset_merged_src1: u32,
+
+    stride_src0_0: u32,
+    stride_src0_1: u32,
+    stride_src0_2: u32,
+    stride_src0_3: u32,
+
+    stride_src1_0: u32,
+    stride_src1_1: u32,
+    stride_src1_2: u32,
+    stride_src1_3: u32,
+
+    a_ne0: u32,
+    a_ne1: u32,
+    a_ne2: u32,
+
+    b_ne0: u32,
+    b_ne1: u32,
+    b_ne2: u32,
+    b_ne3: u32,
+};
+
+fn src0_index(_i: u32) -> u32 {
+    var i = _i;
+    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    let a_i2 = i / (params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne1 * params.a_ne0);
+    let a_i1 = i / params.a_ne0;
+    let a_i0 = i % params.a_ne0;
+
+    return a_i0 * params.stride_src0_0 +
+           a_i1 * params.stride_src0_1 +
+           a_i2 * params.stride_src0_2 +
+           a_i3 * params.stride_src0_3;
+}
+
+fn src1_index(_i: u32) -> u32 {
+    var i = _i;
+    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    let a_i2 = i / (params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne1 * params.a_ne0);
+    let a_i1 = i / params.a_ne0;
+    let a_i0 = i % params.a_ne0;
+
+    // handle repetition of b
+    // index loops back to the beginning and repeats after elements are exhausted = modulo
+    let b_i0 = a_i0 % params.b_ne0;
+    let b_i1 = a_i1 % params.b_ne1;
+    let b_i2 = a_i2 % params.b_ne2;
+    let b_i3 = a_i3 % params.b_ne3;
+
+    // compute index for position in b's flat array
+    return b_i0 * params.stride_src1_0 +
+           b_i1 * params.stride_src1_1 +
+           b_i2 * params.stride_src1_2 +
+           b_i3 * params.stride_src1_3;
+}
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
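+// Binding layout depends on the variant: SRC_OVERLAP reads both operands from a single merged buffer,
+// INPLACE writes the result back into src0, OVERLAP writes it into src1, and otherwise a separate
+// dst buffer is bound.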
+#ifdef SRC_OVERLAP
+@group(0) @binding(0)
+var<storage, read_write> merged_src: array<DataType>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#else
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<DataType>;
+#if defined(INPLACE) || defined(OVERLAP)
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+#endif
+
+fn op(a: DataType, b: DataType) -> DataType {
+#ifdef OP_ADD
+    return a + b;
+#elif defined(OP_SUB)
+    return a - b;
+#elif defined(OP_MUL)
+    return a * b;
+#elif defined(OP_DIV)
+    return a / b;
+#endif
+}
+
+fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
+#ifdef SRC_OVERLAP
+    let result = op(merged_src[src0_i], merged_src[src1_i]);
+#else
+    let result = op(src0[src0_i], src1[src1_i]);
+#endif
+
+#ifdef INPLACE
+    src0[src0_i] = result;
+#elif defined(OVERLAP)
+    src1[src1_i] = result;
+#else
+    dst[dst_i] = result;
+#endif
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
+        let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
+        update(params.offset_dst + gid.x, src0_i, src1_i);
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl
deleted file mode 100644
index 4b254f46..00000000
--- a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl
+++ /dev/null
@@ -1,45 +0,0 @@
-struct Params {
-    ne: u32,
-
-    // offsets in elements
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-
-    stride_src1_0: u32,
-    stride_src1_1: u32,
-    stride_src1_2: u32,
-    stride_src1_3: u32,
-
-    a_ne0: u32,
-    a_ne1: u32,
-    a_ne2: u32,
-
-    b_ne0: u32,
-    b_ne1: u32,
-    b_ne2: u32,
-    b_ne3: u32,
-};
-
-fn src1_index(_i: u32) -> u32 {
-    var i = _i;
-    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
-    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
-    let a_i2 = i / (params.a_ne1 * params.a_ne0);
-    i = i % (params.a_ne1 * params.a_ne0);
-    let a_i1 = i / params.a_ne0;
-    let a_i0 = i % params.a_ne0;
-
-    // handle repetition of b
-    // index loops back to the beginning and repeats after elements are exhausted = modulo
-    let b_i0 = a_i0 % params.b_ne0;
-    let b_i1 = a_i1 % params.b_ne1;
-    let b_i2 = a_i2 % params.b_ne2;
-    let b_i3 = a_i3 % params.b_ne3;
-
-    // compute index for position in b's flat array
-    return b_i0 * params.stride_src1_0 +
-           b_i1 * params.stride_src1_1 +
-           b_i2 * params.stride_src1_2 +
-           b_i3 * params.stride_src1_3;
-}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
index 389c97bb..9a5b18eb 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl
@@ -1,5 +1,4 @@
-#decl(BYTE_HELPERS)
-
+#ifdef BYTE_HELPERS
 fn get_byte(value: u32, index: u32) -> u32 {
     return (value >> (index * 8)) & 0xFF;
 }
@@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
 fn get_byte_i32(value: u32, index: u32) -> i32 {
     return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24;
 }
+#endif
 
-#enddecl(BYTE_HELPERS)
-
-#decl(Q4_0_T)
+#ifdef Q4_0_T
 struct q4_0 {
     d: f16,
     qs: array
 };
-#enddecl(Q4_0_T)
+#endif
 
-#decl(Q4_1_T)
+#ifdef Q4_1_T
 struct q4_1 {
     d: f16,
     m: f16,
     qs: array
 };
-#enddecl(Q4_1_T)
+#endif
 
-#decl(Q5_0_T)
+#ifdef Q5_0_T
 struct q5_0 {
     d: f16,
     qh: array,
     qs: array
 };
-#enddecl(Q5_0_T)
+#endif
 
-#decl(Q5_1_T)
+#ifdef Q5_1_T
 struct q5_1 {
     d: f16,
     m: f16,
     qh: u32,
     qs: array
 };
-#enddecl(Q5_1_T)
+#endif
 
-#decl(Q8_0_T)
+#ifdef Q8_0_T
 struct q8_0 {
     d: f16,
     qs: array
 };
-#enddecl(Q8_0_T)
+#endif
 
-#decl(Q8_1_T)
+#ifdef Q8_1_T
 struct q8_1 {
     d: f16,
     m: f16,
     qs: array
 };
-#enddecl(Q8_1_T)
+#endif
 
-#decl(Q2_K_T)
-struct q2_k {
+#ifdef Q2_K_T
+struct q2_K {
     scales: array,
     qs: array,
     d: f16,
     dmin: f16
 };
-#enddecl(Q2_K_T)
+#endif
 
-#decl(Q3_K_T)
-struct q3_k {
+#ifdef Q3_K_T
+struct q3_K {
     hmask: array,
     qs: array,
     scales: array,
     d: f16
 };
-#enddecl(Q3_K_T)
-
-#decl(Q45_K_SCALE_MIN)
+#endif
 
+#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
 fn get_scale_min(is: u32, scales: array) -> vec2 {
     if (is < 4) {
         let sc_byte = get_byte(scales[is / 4], is % 4);
@@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array) -> vec2 {
         return vec2(f32(sc), f32(m));
     }
 }
-
-#enddecl(Q45_K_SCALE_MIN)
-
-#decl(Q4_K_T)
-struct q4_k {
+#endif
+#ifdef Q4_K_T
+struct q4_K {
     d: f16,
     dmin: f16,
     scales: array,
     qs: array
 };
-#enddecl(Q4_K_T)
+#endif
 
-#decl(Q5_K_T)
-struct q5_k {
+#ifdef Q5_K_T
+struct q5_K {
     d: f16,
     dmin: f16,
     scales: array,
     qh: array,
     qs: array
 };
-#enddecl(Q5_K_T)
+#endif
 
-#decl(Q6_K_T)
-struct q6_k {
+#ifdef Q6_K_T
+struct q6_K {
     ql: array,
     qh: array,
     scales: array,
     d: f16
 };
-#enddecl(Q6_K_T)
+#endif
 
-#decl(IQ2_XXS_T)
+#ifdef IQ2_XXS_T
 struct iq2_xxs {
     d: f16,
     qs: array
 };
-#enddecl(IQ2_XXS_T)
+#endif
 
-#decl(IQ2_XS_T)
+#ifdef IQ2_XS_T
 struct iq2_xs {
     d: f16,
     qs: array,
     scales: array
 };
-#enddecl(IQ2_XS_T)
+#endif
 
-#decl(IQ2_S_T)
+#ifdef IQ2_S_T
 struct iq2_s {
     d: f16,
     qs: array,
     qh: array,
     scales: array
 };
-#enddecl(IQ2_S_T)
+#endif
 
-#decl(IQ3_XSS_T)
+#ifdef IQ3_XXS_T
 struct iq3_xxs {
     d: f16,
     qs: array
 };
-#enddecl(IQ3_XSS_T)
+#endif
 
-#decl(IQ3_S_T)
+#ifdef IQ3_S_T
 struct iq3_s {
     d: f16,
     qs: array,
@@ -161,41 +156,41 @@ struct iq3_s {
     signs: array,
     scales: array
 };
-#enddecl(IQ3_S_T)
+#endif
 
-#decl(IQ1_S_T)
+#ifdef IQ1_S_T
 struct iq1_s {
     d: f16,
     qs: array,
     qh: array
 };
-#enddecl(IQ1_S_T)
+#endif
 
-#decl(IQ1_M_T)
+#ifdef IQ1_M_T
 struct iq1_m {
     qs: array,
     qh: array,
     scales: array
 };
-#enddecl(IQ1_M_T)
+#endif
 
-#decl(IQ4_NL_T)
+#ifdef IQ4_NL_T
 struct iq4_nl {
     d: f16,
     qs: array,
 };
-#enddecl(IQ4_NL_T)
+#endif
 
-#decl(IQ4_XS_T)
+#ifdef IQ4_XS_T
 struct iq4_xs {
     d: f16,
     scales_h: f16,
     scales_l: u32,
     qs: array
 };
-#enddecl(IQ4_XS_T)
+#endif
 
-#decl(IQ23_TABLES)
+#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
 const kmask_iq2xs : array = array(
     0x08040201u, // 1, 2, 4, 8
     0x80402010u  // 16, 32, 64, 128
@@ -211,9 +206,9 @@ const ksigns_iq2xs: array = array(
     0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
     0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
 );
-#enddecl(IQ23_TABLES)
+#endif
 
-#decl(IQ2_XXS_GRID)
+#ifdef IQ2_XXS_GRID
 const iq2xxs_grid = array(
     0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
     0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
@@ -280,9 +275,9 @@ const iq2xxs_grid = array(
     0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
     0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
 );
-#enddecl(IQ2_XXS_GRID)
+#endif
 
-#decl(IQ2_XS_GRID)
+#ifdef IQ2_XS_GRID
 const iq2xs_grid = array(
     0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
     0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -413,9 +408,9 @@ const iq2xs_grid = array(
     0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
     0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
 );
-#enddecl(IQ2_XS_GRID)
+#endif
 
-#decl(IQ2_S_GRID)
+#ifdef IQ2_S_GRID
 const iq2s_grid = array(
     0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
     0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
@@ -674,10 +669,9 @@ const iq2s_grid = array(
     0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
     0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
 );
-#enddecl(IQ2_S_GRID)
-
-#decl(IQ3_XSS_GRID)
+#endif
 
+#ifdef IQ3_XXS_GRID
 const iq3xxs_grid = array(
     0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
     0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
@@ -712,10 +706,9 @@ const iq3xxs_grid = array(
     0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
     0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
 );
-#enddecl(IQ3_XSS_GRID)
-
-#decl(IQ3_S_GRID)
+#endif
 
+#ifdef IQ3_S_GRID
 const iq3s_grid = array(
     0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
     0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
@@ -782,9 +775,9 @@ const iq3s_grid = array(
     0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
     0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
 );
-#enddecl(IQ3_S_GRID)
+#endif
 
-#decl(IQ1_GRID)
+#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)
 
 const IQ1_DELTA: f32 = 0.125;
 
@@ -919,12 +912,12 @@ const iq1_grid = array(
     0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
 );
 
-#enddecl(IQ1_GRID)
+#endif
 
-#decl(IQ4_GRID)
+#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)
 
 const kvalues_iq4nl = array(
     -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
 );
 
-#enddecl(IQ4_GRID)
+#endif
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl
new file mode 100644
index 00000000..a22d245d
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl
@@ -0,0 +1,75 @@
+struct Params {
+    ne: u32,
+
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+
+    stride_src0_0: u32,
+    stride_src0_1: u32,
+    stride_src0_2: u32,
+    stride_src0_3: u32,
+
+    stride_src1_0: u32,
+    stride_src1_1: u32,
+    stride_src1_2: u32,
+    stride_src1_3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+    ne3: u32,
+
+    dim: u32,
+    src0_nedim: u32
+};
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_I32
+#define DataType i32
+#endif
+
+@group(0) @binding(0)
+var src0: array;
+
+@group(0) @binding(1)
+var src1 : array;
+
+@group(0) @binding(2)
+var dst: array;
+
+@group(0) @binding(3)
+var params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3) {
+
+    if (gid.x < params.ne) {
+        var i = gid.x;
+        let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+        i = i % (params.ne2 * params.ne1 * params.ne0);
+        let i2 = i / (params.ne1 * params.ne0);
+        i = i % (params.ne1 * params.ne0);
+        let i1 = i / params.ne0;
+        let i0 = i % params.ne0;
+
+        var ni = array(i0, i1, i2, i3);
+
+        if (ni[params.dim] < params.src0_nedim) {
+            let src_i = ni[0] * params.stride_src0_0 +
+                             ni[1] * params.stride_src0_1 +
+                             ni[2] * params.stride_src0_2 +
+                             ni[3] * params.stride_src0_3;
+            dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];
+        } else {
+            ni[params.dim] -= params.src0_nedim;
+            let src_i = ni[0] * params.stride_src1_0 +
+                             ni[1] * params.stride_src1_1 +
+                             ni[2] * params.stride_src1_2 +
+                             ni[3] * params.stride_src1_3;
+            dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i];
+        }
+    }
+}
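
concat.wgsl routes each destination element to one of the two sources: it unravels the flat index, and if the coordinate along params.dim is still within src0's extent it reads from src0, otherwise it shifts that coordinate back by src0_nedim and reads from src1. A hedged Python sketch of that routing decision (shapes and names are illustrative):

```python
# Which source does destination element `flat_i` come from, and at which coordinates?
def concat_source(flat_i, dst_ne, dim, src0_nedim):
    i3, rem = divmod(flat_i, dst_ne[2] * dst_ne[1] * dst_ne[0])
    i2, rem = divmod(rem, dst_ne[1] * dst_ne[0])
    i1, i0 = divmod(rem, dst_ne[0])
    ni = [i0, i1, i2, i3]
    if ni[dim] < src0_nedim:
        return "src0", ni
    ni[dim] -= src0_nedim      # rebase the coordinate into src1's index space
    return "src1", ni

# concatenating two 2x3 tensors along the innermost dim (dim 0) gives a 2x6 result;
# element 4 of the first row comes from column 1 of src1
print(concat_source(4, (6, 2, 1, 1), 0, 3))  # -> ('src1', [1, 0, 0, 0])
```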
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
index db1aa349..b5e93b81 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl
@@ -7,6 +7,12 @@
       "DST_TYPE": "f32"
     }
   },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f32",
+      "DST_TYPE": "i32"
+    }
+  },
   {
     "REPLS": {
       "SRC_TYPE": "f32",
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl
new file mode 100644
index 00000000..e622552c
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl
@@ -0,0 +1,66 @@
+@group(0) @binding(0)
+var src: array;
+
+@group(0) @binding(1)
+var dst: array;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+    ne0: u32,
+};
+
+@group(0) @binding(2)
+var params: Params;
+
+var shared_sum: array;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3,
+        @builtin(local_invocation_id) lid: vec3) {
+    let row_idx = params.offset_src + wid.x * params.ne0;
+    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
+    var local_sum: f32 = 0.0;
+    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {
+        local_sum += src[row_idx + col];
+    }
+    shared_sum[lid.x] = local_sum;
+    workgroupBarrier();
+
+    // upsweep
+    var offset = 1u;
+    while (offset < WG_SIZE) {
+        let idx = (lid.x + 1) * offset * 2 - 1;
+        if (idx < WG_SIZE) {
+            shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset];
+        }
+        workgroupBarrier();
+        offset <<= 1;
+    }
+
+    // set last to 0 for exclusive sum
+    if (lid.x == 0) {
+        shared_sum[WG_SIZE - 1] = 0.0;
+    }
+    workgroupBarrier();
+
+    // downsweep
+    offset = WG_SIZE >> 1;
+    while (offset > 0) {
+        let idx = (lid.x + 1) * offset * 2 - 1;
+        if (idx < WG_SIZE) {
+            let t = shared_sum[idx - offset];
+            shared_sum[idx - offset] = shared_sum[idx];
+            shared_sum[idx] = shared_sum[idx] + t;
+        }
+        workgroupBarrier();
+        offset = offset >> 1;
+    }
+
+    // shared_sum[lid.x] now holds the exclusive prefix sum of the partial sums before this thread.
+    var running_sum = shared_sum[lid.x];
+    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) {
+        running_sum += src[row_idx + col];
+        dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum;
+    }
+}
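
cumsum.wgsl computes one row per workgroup: each thread first sums its contiguous slice of the row, the WG_SIZE partial sums are turned into exclusive prefix sums in shared memory with a work-efficient up-sweep/down-sweep (Blelloch-style) scan, and each thread then re-walks its slice to emit the inclusive running sums. A small Python model of just the shared-memory scan (assuming a power-of-two thread count, as WG_SIZE is):

```python
# Model of the up-sweep/down-sweep over the per-thread partial sums in cumsum.wgsl.
def exclusive_scan(partials):
    n = len(partials)                 # power of two, standing in for WG_SIZE
    s = list(partials)
    offset = 1
    while offset < n:                 # up-sweep: build a reduction tree in place
        for tid in range(n):
            idx = (tid + 1) * offset * 2 - 1
            if idx < n:
                s[idx] += s[idx - offset]
        offset <<= 1
    s[n - 1] = 0                      # clear the root so the scan is exclusive
    offset = n >> 1
    while offset > 0:                 # down-sweep: push prefixes back down the tree
        for tid in range(n):
            idx = (tid + 1) * offset * 2 - 1
            if idx < n:
                s[idx - offset], s[idx] = s[idx], s[idx] + s[idx - offset]
        offset >>= 1
    return s

print(exclusive_scan([3, 1, 4, 1]))   # -> [0, 3, 4, 8]
```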
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
index d61df5bb..8b5cfe71 100755
--- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
@@ -56,12 +56,46 @@ def expand_includes(shader, input_dir):
     return include_pattern.sub(replacer, shader)
 
 
-def write_shader(shader_name, shader_code, output_dir, outfile):
+def chunk_shader(shader_code, max_chunk_len=60000):
+    """Split shader_code into safe raw-string sized chunks."""
+    return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)]
+
+
+def raw_delim(shader_code):
+    """Pick a raw-string delimiter that does not appear in the shader."""
+    delim = "wgsl"
+    while f"){delim}\"" in shader_code:
+        delim += "_x"
+    return delim
+
+
+def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
+    shader_code = expand_includes(shader_code, input_dir)
+
     if output_dir:
         wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
         with open(wgsl_filename, "w", encoding="utf-8") as f_out:
             f_out.write(shader_code)
-    outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
+
+    delim = raw_delim(shader_code)
+    chunks = chunk_shader(shader_code)
+
+    if len(chunks) == 1:
+        outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n')
+    else:
+        for idx, chunk in enumerate(chunks):
+            outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n')
+        outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n')
+        outfile.write('    static const std::string s = []{\n')
+        outfile.write('        std::string tmp;\n')
+        outfile.write(f'        tmp.reserve({len(shader_code)});\n')
+        for idx in range(len(chunks)):
+            outfile.write(f'        tmp.append(wgsl_{shader_name}_part{idx});\n')
+        outfile.write('        return tmp;\n')
+        outfile.write('    }();\n')
+        outfile.write('    return s;\n')
+        outfile.write('}\n')
+        outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
 
 
 def generate_variants(fname, input_dir, output_dir, outfile):
@@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
     try:
         variants = ast.literal_eval(extract_block(text, "VARIANTS"))
     except ValueError:
-        write_shader(shader_base_name, text, output_dir, outfile)
+        write_shader(shader_base_name, text, output_dir, outfile, input_dir)
     else:
         try:
             decls_map = parse_decls(extract_block(text, "DECLS"))
@@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
                 output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
             else:
                 output_name = shader_base_name
-            write_shader(output_name, final_shader, output_dir, outfile)
+            write_shader(output_name, final_shader, output_dir, outfile, input_dir)
 
 
 def main():
@@ -137,7 +171,8 @@ def main():
         os.makedirs(args.output_dir, exist_ok=True)
 
     with open(args.output_file, "w", encoding="utf-8") as out:
-        out.write("// Auto-generated shader embedding\n\n")
+        out.write("// Auto-generated shader embedding\n")
+        out.write("#include \n\n")
         for fname in sorted(os.listdir(args.input_dir)):
             if fname.endswith(".wgsl"):
                 generate_variants(fname, args.input_dir, args.output_dir, out)
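
The chunking added to embed_wgsl.py works around string-literal length limits in some C++ compilers (notably MSVC): shaders longer than ~60 kB are emitted as several raw-string parts and stitched back into a static std::string whose c_str() backs the exported pointer. A quick illustration of the two helpers with a made-up shader (attributing the limit to MSVC is my reading, not stated in the patch):

```python
# Illustrative use of chunk_shader() / raw_delim(); the shader text is made up.
shader = "fn main() { /* ... */ }\n" * 5000   # ~120 kB of WGSL
chunks = chunk_shader(shader)                 # default max_chunk_len = 60000
delim = raw_delim(shader)                     # stays "wgsl" unless ')wgsl"' appears in the source
print(len(chunks), delim)                     # -> 2 wgsl
assert "".join(chunks) == shader              # chunking is lossless
```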
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
index de7c132a..b6822161 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
@@ -114,7 +114,7 @@ struct Params {
 #define PARAMS_BINDING 4
 #endif
 
-@group(0) @binding(DST_BINDING) var dst: array;
+@group(0) @binding(DST_BINDING) var dst: array>;
 @group(0) @binding(PARAMS_BINDING) var params: Params;
 
 // Just a very small float value.
@@ -160,14 +160,21 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
     return v;
 }
 
+fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 {
+    return (*buf)[scalar_index >> 2u];
+}
+
+fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 {
+    return (*buf)[scalar_index >> 2u];
+}
 
 @compute @workgroup_size(WG_SIZE)
 fn main(@builtin(workgroup_id) wg_id: vec3,
-        @builtin(local_invocation_id) local_id: vec3,
-        @builtin(subgroup_id) subgroup_id: u32,
-        @builtin(subgroup_size) subgroup_size: u32,
-        @builtin(num_subgroups) num_subgroups: u32,
-        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+    @builtin(local_invocation_id) local_id: vec3,
+    @builtin(subgroup_id) subgroup_id: u32,
+    @builtin(subgroup_size) subgroup_size: u32,
+    @builtin(num_subgroups) num_subgroups: u32,
+    @builtin(subgroup_invocation_id) sg_inv_id: u32) {
 
     // initialize row max for online softmax
     for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
@@ -231,9 +238,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
 
     for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
       // clear inter_shmem to ensure zero-initialized accumulators
-      for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
-          inter_shmem[elem_idx] = 0.0;
-      }
+        for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+            inter_shmem[elem_idx] = 0.0;
+        }
 
       // load k tile into shared memory
 #if defined(KV_Q4_0)
@@ -309,48 +316,77 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
 
       // accumulate q block * k block into registers across the entire KV tile
       // TODO: this loop seems to be the current largest bottleneck
-      for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
-          let inter_offset = kv_block * SG_MAT_N;
-          var acc: subgroup_matrix_result = subgroupMatrixLoad<
-              subgroup_matrix_result>(&inter_shmem, inter_offset, false, KV_TILE);
+      // this block exists to scope variable lifetimes, reducing register pressure
+      {
 #ifdef KV_DIRECT
-          let k_block_row = kv_tile + kv_block * SG_MAT_N;
-          let k_global_offset = k_head_offset + k_block_row * params.stride_k1;
+          let k_block_row = kv_tile + subgroup_id * SG_MAT_N;
+          var k_global_offset = k_head_offset + k_block_row * params.stride_k1;
 #else
-          let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK;
+          var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK;
 #endif
-          for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) {
-              // load q submatrix from shared memory
-              var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>(
-                  &q_shmem,
-                  head_dim_block,
-                  false,
-                  HEAD_DIM_QK
-              );
+          for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
+              let inter_offset = kv_block * SG_MAT_N;
+              var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE);
+
+              var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK);
 
-              // load k submatrix from device or shared memory
 #ifdef KV_DIRECT
-              var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>(
-                  &K,
-                  k_global_offset + head_dim_block,
-                  true,
-                  params.stride_k1
-              );
+              var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1);
 #else
-              var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>(
-                  &kv_shmem,
-                  k_block_offset + head_dim_block,
-                  true,
-                  HEAD_DIM_QK
-              );
+              var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK);
 #endif
-              acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc);
+
+              var t: u32 = 1u;
+              for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) {
+                  let h0 = t * SG_MAT_K;
+                  var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+                  var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1);
+#else
+                  var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK);
+#endif
+                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+                  q_cur = q0;
+                  k_cur = k0;
+
+                  let h1 = (t + 1u) * SG_MAT_K;
+                  var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+                  var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1);
+#else
+                  var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK);
+#endif
+                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+                  q_cur = q1g;
+                  k_cur = k1g;
+              }
+
+              // handle odd tail
+              if (t < HEAD_DIM_QK / SG_MAT_K) {
+                  let h = t * SG_MAT_K;
+                  var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK);
+#ifdef KV_DIRECT
+                  var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1);
+#else
+                  var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK);
+#endif
+                  acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+                  q_cur = qn;
+                  k_cur = kn;
+              }
+
+              acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc);
+
+#ifdef KV_DIRECT
+              k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1;
+#else
+              k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK;
+#endif
+              subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
           }
-
-          // store acc to shared memory for softmax (S matrix from paper)
-          subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
       }
 
+
 #ifdef MASK
       // load mask tile into shared memory for this KV block
       // TODO: optimize and skip if mask is -INF for the entire tile
@@ -495,7 +531,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
                   false,
                   HEAD_DIM_V
               );
-
               for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
                   let p_offset = kv_block * SG_MAT_N;
                   var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>(
@@ -527,11 +562,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
                   // O += P * V
                   o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
               }
-
               // store O back to shared memory
               subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
       }
-
       workgroupBarrier();
     }
 
@@ -566,26 +599,38 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
                 o_shmem[idx] = f16(val);
             }
     }
-
     workgroupBarrier();
 #endif
-
-    // write output back to global memory
     for (var q_tile_row = subgroup_id;
-         q_tile_row < Q_TILE;
-         q_tile_row += num_subgroups) {
-            let global_q_row = q_row_start + q_tile_row;
-            if (global_q_row >= params.seq_len_q) {
-                break;
-            }
+        q_tile_row < Q_TILE;
+        q_tile_row += num_subgroups) {
 
-            let exp_sum = exp_sum_shmem[q_tile_row];
-            let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0);
+        let global_q_row = q_row_start + q_tile_row;
+        if (global_q_row >= params.seq_len_q) { break; }
 
-            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
-                let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx];
-                let scaled = f32(o_val) * scale;
-                dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled;
-            }
+        let exp_sum = exp_sum_shmem[q_tile_row];
+        let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+
+        let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride;
+
+        for (var elem_base = sg_inv_id * 4u;
+            elem_base < HEAD_DIM_V;
+            elem_base += subgroup_size * 4u) {
+
+            let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+            let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+            let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+            let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+            let v = vec4(
+                f32(o_shmem[i0]) * scale,
+                f32(o_shmem[i1]) * scale,
+                f32(o_shmem[i2]) * scale,
+                f32(o_shmem[i3]) * scale
+            );
+
+            let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
+            dst[dst_vec_index] = v;
+        }
     }
 }
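
The rewritten Q·K accumulation loop in flash_attn.wgsl software-pipelines the subgroup-matrix loads: the first (q, k) tile is loaded before the loop, each unrolled-by-two iteration prefetches the next tiles while multiplying the previously loaded ones, an odd-tail branch covers a leftover tile, and one final multiply drains the pipeline. A Python sketch of that structure, where load() and mma() are placeholders for subgroupMatrixLoad and subgroupMatrixMultiplyAccumulate:

```python
# Sketch of the pipelined accumulation pattern: every tile is loaded once and
# multiplied exactly once, but loads for tile t+1 overlap the multiply of tile t.
def pipelined_qk(num_tiles, load, mma, acc):
    q_cur, k_cur = load(0)            # prime the pipeline with tile 0
    t = 1
    while t + 1 < num_tiles:          # unrolled by two
        q0, k0 = load(t)
        acc = mma(q_cur, k_cur, acc)  # compute on the previous tile while t loads
        q_cur, k_cur = q0, k0
        q1, k1 = load(t + 1)
        acc = mma(q_cur, k_cur, acc)
        q_cur, k_cur = q1, k1
        t += 2
    if t < num_tiles:                 # odd tail: one more prefetch + multiply
        qn, kn = load(t)
        acc = mma(q_cur, k_cur, acc)
        q_cur, k_cur = qn, kn
    return mma(q_cur, k_cur, acc)     # drain the last prefetched tile

# each of the 5 tiles is accumulated exactly once, in order
print(pipelined_qk(5, lambda t: (t, t), lambda q, k, a: a + [q], []))  # -> [0, 1, 2, 3, 4]
```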
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
similarity index 83%
rename from ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl
rename to ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
index f80ce1fc..b10800e3 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
@@ -1,222 +1,31 @@
-#define(VARIANTS)
+enable f16;
+#include "common_decls.tmpl"
 
-[
-  {
-    "SHADER_SUFFIX": "f32_vec",
-    "REPLS": {
-      "TYPE" : "vec4",
-      "DST_TYPE": "vec4",
-      "BLOCK_SIZE": 4
-    },
-    "DECLS": ["F32_VEC"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f32",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 1
-    },
-    "DECLS": ["F32"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "f16",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 1
-    },
-    "DECLS": ["F16"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "i32",
-      "DST_TYPE": "i32",
-      "BLOCK_SIZE": 1
-    },
-    "DECLS": ["I32"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q4_0",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q4_1",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q5_0",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q5_1",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q8_0",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q2_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q3_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q4_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q5_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "q6_k",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "iq2_xxs",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
-  },
-  {
-    "REPLS": {
-      "TYPE" : "iq2_xs",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq2_s",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq3_xxs",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq3_s",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq1_s",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq1_m",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq4_nl",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 32,
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
-  },
-  {
-    "REPLS": {
-      "TYPE": "iq4_xs",
-      "DST_TYPE": "f32",
-      "BLOCK_SIZE": 256,
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(F32_VEC)
+#ifdef F32_VEC
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
 }
-#enddecl(F32_VEC)
+#endif
 
-#decl(F32)
+#ifdef F32
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     dst[dst_base + offset] = src[src_base + offset];
 }
-#enddecl(F32)
+#endif
 
-#decl(F16)
+#ifdef F16
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     dst[dst_base + offset] = f32(src[src_base + offset]);
 }
-#enddecl(F16)
+#endif
 
-#decl(I32)
+#ifdef I32
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     dst[dst_base + offset] = src[src_base + offset];
 }
-#enddecl(I32)
+#endif
 
-#decl(Q4_0)
+#ifdef Q4_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_q4_0 = src[src_base + offset];
     let d = f32(block_q4_0.d);
@@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(Q4_0)
+#endif
 
-#decl(Q4_1)
+#ifdef Q4_1
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_q4_1 = src[src_base + offset];
     let d = f32(block_q4_1.d);
@@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(Q4_1)
+#endif
 
-#decl(Q5_0)
+#ifdef Q5_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_q5_0 = src[src_base + offset];
     let d = f32(block_q5_0.d);
@@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
+#endif
 
-#enddecl(Q5_0)
-
-#decl(Q5_1)
+#ifdef Q5_1
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_q5_1 = src[src_base + offset];
     let d = f32(block_q5_1.d);
@@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(Q5_1)
+#endif
 
-#decl(Q8_0)
+#ifdef Q8_0
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block_q8_0 = src[src_base + offset];
     let d = f32(block_q8_0.d);
@@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(Q8_0)
+#endif
 
-#decl(Q2_K)
+#ifdef Q2_K
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(Q2_K)
+#endif
 
-#decl(Q3_K)
+#ifdef Q3_K
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(Q3_K)
+#endif
 
-#decl(Q4_K)
+#ifdef Q4_K
 // 8 blocks of 32 elements each
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
@@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(Q4_K)
+#endif
 
-#decl(Q5_K)
+#ifdef Q5_K
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(Q5_K)
+#endif
 
-#decl(Q6_K)
+#ifdef Q6_K
 // 16 blocks of 16 elements each
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
@@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         sc_b_idx += 8;
     }
 }
+#endif
 
-#enddecl(Q6_K)
-
-#decl(IQ2_XXS)
+#ifdef IQ2_XXS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(IQ2_XXS)
+#endif
 
-#decl(IQ2_XS)
+#ifdef IQ2_XS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(IQ2_XS)
+#endif
 
-#decl(IQ2_S)
+#ifdef IQ2_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
+#endif
 
-#enddecl(IQ2_S)
-
-#decl(IQ3_XSS)
+#ifdef IQ3_XXS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(IQ3_XSS)
+#endif
 
-#decl(IQ3_S)
+#ifdef IQ3_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
-#enddecl(IQ3_S)
+#endif
 
-#decl(IQ1_S)
+#ifdef IQ1_S
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
+#endif
 
-#enddecl(IQ1_S)
-
-#decl(IQ1_M)
+#ifdef IQ1_M
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
 
@@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         }
     }
 }
+#endif
 
-#enddecl(IQ1_M)
-
-#decl(IQ4_NL)
+#ifdef IQ4_NL
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         dst_i++;
     }
 }
-#enddecl(IQ4_NL)
+#endif
 
-#decl(IQ4_XS)
+#ifdef IQ4_XS
 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
     let block = src[src_base + offset];
     let d = f32(block.d);
@@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
         dst_i += 16;
     }
 }
-#enddecl(IQ4_XS)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-DECLS
+#endif
 
 @group(0) @binding(0)
-var src: array<{{TYPE}}>;
+var src: array;
 
 @group(0) @binding(1)
 var idx: array;
 
 @group(0) @binding(2)
-var dst: array<{{DST_TYPE}}>;
+var dst: array;
 
 struct Params {
     offset_src: u32, // in elements
@@ -842,8 +638,7 @@ struct Params {
 @group(0) @binding(3)
 var params: Params;
 
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
+@compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3) {
     if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
         return;
@@ -866,9 +661,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) {
     let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
     let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
 
-    for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
+    for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
       copy_elements(i_src_row, i_dst_row, i);
     }
 }
 
-#end(SHADER)
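
get_rows.wgsl is a row gather: each invocation resolves one destination row, reads the source row id from the idx tensor, and copies that row block by block (BLOCK_SIZE elements per copy_elements call, so quantized rows are dequantized on the way). Reduced to two dimensions, the effect is simply:

```python
# 2D reduction of the gather performed by get_rows.wgsl: dst row r copies src row idx[r].
def get_rows(src_rows, idx):
    return [list(src_rows[i]) for i in idx]

src = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
print(get_rows(src, [2, 0, 2]))  # -> [[2.0, 2.1], [0.0, 0.1], [2.0, 2.1]]
```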
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
similarity index 84%
rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
index 0f8e6e5a..5b9f5b36 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl
@@ -1,195 +1,24 @@
-#define(VARIANTS)
+enable f16;
 
-[
-  {
-    "REPLS": {
-      "SRC0_TYPE" : "f32",
-      "SRC1_TYPE" : "f32",
-      "BLOCK_SIZE" : 1
-    },
-    "DECLS" : ["FLOAT"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f16",
-      "BLOCK_SIZE" : 1
-    },
-    "DECLS" : ["FLOAT"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "BLOCK_SIZE" : 1
-    },
-    "DECLS" : ["FLOAT"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q4_0",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q4_1",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q5_0",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q5_1",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q8_0",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32
-    },
-    "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q2_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q3_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q4_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q5_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "q6_k",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq2_xxs",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq2_xs",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq2_s",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq3_xxs",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq3_s",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq1_s",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq1_m",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq4_nl",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 32,
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
-  },
-  {
-    "REPLS": {
-      "SRC0_TYPE": "iq4_xs",
-      "SRC1_TYPE": "f32",
-      "BLOCK_SIZE": 256,
-    },
-    "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
-  }
-]
+#include "common_decls.tmpl"
 
-#end(VARIANTS)
+#ifdef FLOAT
+const BLOCK_SIZE = 1u;
 
-#define(DECLS)
+#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
+const BLOCK_SIZE = 32u;
 
-#decl(FLOAT)
+#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
+const BLOCK_SIZE = 256u;
+#endif
+
+#ifdef FLOAT
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
 }
-#enddecl(FLOAT)
+#endif
 
-#decl(Q4_0)
+#ifdef Q4_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_q4_0 = src0[src0_idx_base + offset];
     let d = f32(block_q4_0.d);
@@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
-#enddecl(Q4_0)
+#endif
 
-#decl(Q4_1)
+#ifdef Q4_1
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_q4_1 = src0[src0_idx_base + offset];
     let d = f32(block_q4_1.d);
@@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
-#enddecl(Q4_1)
+#endif
 
-#decl(Q5_0)
+#ifdef Q5_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_q5_0 = src0[src0_idx_base + offset];
     let d = f32(block_q5_0.d);
@@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
-#enddecl(Q5_0)
+#endif
 
-#decl(Q5_1)
+#ifdef Q5_1
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_q5_1 = src0[src0_idx_base + offset];
     let d = f32(block_q5_1.d);
@@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
-#enddecl(Q5_1)
+#endif
 
-#decl(Q8_0)
+#ifdef Q8_0
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_q8_0 = src0[src0_idx_base + offset];
     let d = f32(block_q8_0.d);
@@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
-#enddecl(Q8_0)
+#endif
 
-#decl(Q8_1)
+#ifdef Q8_1
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block_q8_1 = src0[src0_idx_base + offset];
     let d = f32(block_q8_1.d);
@@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
-#enddecl(Q8_1)
+#endif
 
-#decl(Q2_K)
+#ifdef Q2_K
 // 16 blocks of 16 elements each
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
@@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(Q2_K)
-
-#decl(Q3_K)
+#ifdef Q3_K
 // 16 blocks of 16 elements each
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
@@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(Q3_K)
-
-#decl(Q4_K)
+#ifdef Q4_K
 // 8 blocks of 32 elements each
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
@@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(Q4_K)
-
-#decl(Q5_K)
+#ifdef Q5_K
 // 8 blocks of 32 elements each
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
@@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(Q5_K)
-
-#decl(Q6_K)
+#ifdef Q6_K
 // 16 blocks of 16 elements each
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
@@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(Q6_K)
-
-#decl(IQ2_XXS)
+#ifdef IQ2_XXS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
     let d = f32(block.d);
@@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(IQ2_XXS)
-
-#decl(IQ2_XS)
+#ifdef IQ2_XS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
     let d = f32(block.d);
@@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(IQ2_XS)
-
-#decl(IQ2_S)
+#ifdef IQ2_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
     let d = f32(block.d);
@@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-
-#enddecl(IQ2_S)
-
-#decl(IQ3_XSS)
+#ifdef IQ3_XXS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
     let d = f32(block.d);
@@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(IQ3_XSS)
-
-#decl(IQ3_S)
+#ifdef IQ3_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
     let d = f32(block.d);
@@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
-#enddecl(IQ3_S)
+#endif
 
-#decl(IQ1_S)
+#ifdef IQ1_S
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
     let d = f32(block.d);
@@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(IQ1_S)
 
-#decl(IQ1_M)
+#ifdef IQ1_M
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
 
@@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(IQ1_M)
-
-#decl(IQ4_NL)
+#ifdef IQ4_NL
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
     let d = f32(block.d);
@@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
+#endif
 
-#enddecl(IQ4_NL)
-
-#decl(IQ4_XS)
+#ifdef IQ4_XS
 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     let block = src0[src0_idx_base + offset];
     let d = f32(block.d);
@@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
     }
     return sum;
 }
-
-#enddecl(IQ4_XS)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-DECLS
+#endif
 
 struct MulMatParams {
     offset_src0: u32, // in elements/blocks
@@ -864,26 +672,31 @@ struct MulMatParams {
     broadcast3: u32
 };
 
-@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns
-@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
+@group(0) @binding(0) var src0: array; // M rows, K columns
+@group(0) @binding(1) var src1: array; // K rows, N columns (transposed)
 @group(0) @binding(2) var dst: array; // M rows, N columns
 
 @group(0) @binding(3) var params: MulMatParams;
 
 @compute @workgroup_size(256)
-fn main(@builtin(global_invocation_id) global_id: vec3) {
+fn main(@builtin(local_invocation_id) local_id: vec3,
+        @builtin(workgroup_id) wg_id: vec3,
+        @builtin(num_workgroups) num_wg: vec3) {
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+    let global_idx = wg_linear * 256u + local_id.x;
+
     let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
-    if (global_id.x >= total) {
+    if (global_idx >= total) {
         return;
     }
 
     let dst2_stride = params.m * params.n;
     let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
 
-    let dst3_idx = global_id.x / dst3_stride;
+    let dst3_idx = global_idx / dst3_stride;
     let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension
     let src13_idx = dst3_idx; // src1 is not broadcast
-    let dst3_rem = global_id.x % dst3_stride;
+    let dst3_rem = global_idx % dst3_stride;
 
     let dst2_idx = dst3_rem / dst2_stride;
     let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension
@@ -898,10 +711,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) {
     let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
 
     var sum = 0.0;
-    for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
+    for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
         sum += multiply_add(src0_idx_base, src1_idx_base, i);
     }
     dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
 }
-
-#end(SHADER)
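
mul_mat.wgsl now derives a flat element index from workgroup_id and num_workgroups instead of using global_invocation_id directly, so the host can spread very large launches over two dispatch dimensions (WebGPU caps each dimension at 65535 workgroups by default) without changing the per-element indexing. Below is a sketch of the index recovery and of how a host might split such a dispatch; the splitting code is an assumption, since the corresponding C++ launch change is not part of this hunk:

```python
WG_SIZE = 256  # matches @workgroup_size(256) in the shader

def global_idx(wg_x, wg_y, local_x, num_wg_x):
    # same linearization as the shader: wg_linear * 256 + local thread id
    return (wg_y * num_wg_x + wg_x) * WG_SIZE + local_x

def dispatch_dims(total, max_wg_per_dim=65535):
    # hypothetical host-side split: spread the workgroups over x and y
    n_wg = (total + WG_SIZE - 1) // WG_SIZE
    nx = min(n_wg, max_wg_per_dim)
    ny = (n_wg + nx - 1) // nx
    return nx, ny

nx, ny = dispatch_dims(20_000_000)
print(nx, ny)                    # -> 65535 2
print(global_idx(0, 1, 0, nx))   # first element of the second grid row: 16776960
```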
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
index 109ff8d6..de60ebbc 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
@@ -1,58 +1,65 @@
-#decl(SHMEM_VEC)
+#ifdef VEC
+#define VEC_SIZE 4
+#define SHMEM_TYPE vec4
+#define DST_TYPE vec4
+#define SRC0_TYPE vec4
+#define SRC1_TYPE vec4
+
 fn store_shmem(val: vec4, idx: u32) {
     shmem[idx] = val.x;
     shmem[idx + 1] = val.y;
     shmem[idx + 2] = val.z;
     shmem[idx + 3] = val.w;
 }
-#enddecl(SHMEM_VEC)
+#endif // VEC
+
+#ifdef SCALAR
+#define VEC_SIZE 1
+#define SHMEM_TYPE f16
+#define DST_TYPE f32
+#define SRC0_TYPE SRC0_INNER_TYPE
+#define SRC1_TYPE SRC1_INNER_TYPE
 
-#decl(SHMEM_SCALAR)
 fn store_shmem(val: f16, idx: u32) {
     shmem[idx] = val;
 }
-#enddecl(SHMEM_SCALAR)
-
-#decl(INIT_SRC0_SHMEM_FLOAT)
+#endif // SCALAR
 
+#ifdef INIT_SRC0_SHMEM_FLOAT
 fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
-    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+    for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
         let tile_m = elem_idx / TILE_K;
         let tile_k = elem_idx % TILE_K;
         let global_m = offset_m + tile_m;
         let global_k = k_outer + tile_k;
         let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
         let src0_val = select( // taking a slight performance hit to avoid oob
-            {{SRC0_TYPE}}(0.0),
-            src0[src0_idx/{{VEC_SIZE}}],
+            SRC0_TYPE(0.0),
+            src0[src0_idx/VEC_SIZE],
             global_m < params.m && global_k < params.k);
-        store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
+        store_shmem(SHMEM_TYPE(src0_val), elem_idx);
     }
 }
+#endif // INIT_SRC0_SHMEM_FLOAT
 
-#enddecl(INIT_SRC0_SHMEM_FLOAT)
-
-#decl(INIT_SRC1_SHMEM)
-
+#ifdef INIT_SRC1_SHMEM_FLOAT
 fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
-    for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+    for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
         let tile_n = elem_idx / TILE_K;
         let tile_k = elem_idx % TILE_K;
         let global_n = offset_n + tile_n;
         let global_k = k_outer + tile_k;
         let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
         let src1_val = select(
-            {{SRC1_TYPE}}(0.0),
-            src1[src1_idx/{{VEC_SIZE}}],
+            SRC1_TYPE(0.0),
+            src1[src1_idx/VEC_SIZE],
             global_n < params.n && global_k < params.k);
-        store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
+        store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
     }
 }
+#endif // INIT_SRC1_SHMEM_FLOAT
 
-#enddecl(INIT_SRC1_SHMEM)
-
-#decl(INIT_SRC0_SHMEM_Q4_0)
-
+#ifdef INIT_SRC0_SHMEM_Q4_0
 const BLOCK_SIZE = 32u;
 // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
 override BLOCKS_K = TILE_K/BLOCK_SIZE;
@@ -93,5 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
         }
     }
 }
+#endif // INIT_SRC0_SHMEM_Q4_0
 
-#enddecl(INIT_SRC0_SHMEM_Q4_0)
+#ifdef INIT_SRC0_SHMEM_Q4_1
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 10u; // 1 scale + 1 min (m) + 8 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let d = src0[scale_idx];
+            let m = src0[scale_idx + 1u];
+
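+            // each packed f16 holds 4 quant bytes; every byte carries two weights:
+            // the low nibble lands in shmem slot i and the high nibble in slot i + 16 of the block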
+            for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                let q_0 = src0[scale_idx + 2u + block_offset + j];
+                let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+
+                let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte(q_packed, k);
+                    let q_lo = f16(q_byte & 0xF) * d + m;
+                    let q_hi = f16((q_byte >> 4) & 0xF) * d + m;
+                    shmem[shmem_idx + j * 2 + k] = q_lo;
+                    shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q4_1
+
+#ifdef INIT_SRC0_SHMEM_Q5_0
+// 32 weights per block; the low 4 bits of each weight are packed into qs: 32 * 4 = 128 bits = 8 f16s per block (the 5th bit of each weight lives in qh)
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+// TILE_K is defined as 32u, so BLOCKS_K is always 1
+override BLOCKS_K = TILE_K / BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread; each thread handles 4 f16s * 4 weights = 16 weights
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx    = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx   = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m   = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k  = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+
+            let d  = src0[scale_idx];
+            let qh0 = src0[scale_idx + 1u];
+            let qh1 = src0[scale_idx + 2u];
+            let qh_packed = bitcast<u32>(vec2<f16>(qh0, qh1));
+
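+            // qh carries the 5th (high) bit of each of the 32 weights; it is OR-ed into the
+            // low/high nibble taken from qs below to rebuild the full 5-bit quant before scaling by d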
+            for (var j = 0u; j < 2; j++) {
+                let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
+                let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
+
+                let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+
+                let j_adjusted = j + (block_offset / 2u);
+
+
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte(q_packed, k);
+
+                    let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                    let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+                    let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                    let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d;
+
+                    shmem[shmem_idx + j * 4u + k]        = q_lo; // store first weight
+                    shmem[shmem_idx + j * 4u + k + 16u]  = q_hi; // store second weight
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q5_0
+
+#ifdef INIT_SRC0_SHMEM_Q5_1
+// 32 weights per block; the low 4 bits of each weight are packed into qs: 32 * 4 = 128 bits = 8 f16s per block (the 5th bit of each weight lives in qh)
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+// TILE_K is defined as 32u, so BLOCKS_K is always 1
+override BLOCKS_K = TILE_K / BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 12u; // 1 scale + 1 mean + 2 qh + 8 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread; each thread handles 4 f16s * 4 weights = 16 weights
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx    = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx   = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m   = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k  = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx  = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+
+            let d  = src0[scale_idx];
+            let m = src0[scale_idx + 1u];
+            let qh0 = src0[scale_idx + 2u];
+            let qh1 = src0[scale_idx + 3u];
+            let qh_packed = bitcast<u32>(vec2<f16>(qh0, qh1));
+
+            for (var j = 0u; j < 2; j++) {
+
+                let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
+                let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
+
+                let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+
+                let j_adjusted = j + (block_offset / 2u);
+
+
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte(q_packed, k);
+
+                    let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                    let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m;
+                    let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                    let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m;
+
+                    shmem[shmem_idx + j * 4u + k]        = q_lo; // store first weight
+                    shmem[shmem_idx + j * 4u + k + 16u]  = q_hi; // store second weight
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q5_1
+
+#ifdef INIT_SRC0_SHMEM_Q8_0
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 17u; // 1 scale + 16 f16s of packed 8-bit weights
+const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let d = src0[scale_idx];
+
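+            // q8_0: each byte is a signed 8-bit quant, so the dequantized weight is simply q * d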
+            for (var j = 0u; j < F16_PER_THREAD; j+=2) {
+                let q_0 = src0[scale_idx + 1u + block_offset + j];
+                let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
+
+                let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte_i32(q_packed, k);
+
+                    let q_val = f16(q_byte) * d;
+                    shmem[shmem_idx + j * 2 + k] = q_val;
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q8_0
+
+#ifdef INIT_SRC0_SHMEM_Q8_1
+const BLOCK_SIZE = 32u;
+// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
+override BLOCKS_K = TILE_K/BLOCK_SIZE;
+const NQ = 16u;
+const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 16 f16s of packed 8-bit weights
+const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+
+        let tile_m = blck_idx / BLOCKS_K;
+        let global_m = offset_m + tile_m;
+        let block_k = blck_idx % BLOCKS_K;
+        let global_k = k_outer / BLOCK_SIZE + block_k;
+
+        if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
+            let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
+            let scale_idx = src0_idx * F16_PER_BLOCK;
+            let d = src0[scale_idx];
+            let m = src0[scale_idx + 1u];
+
+            for (var j = 0u; j < F16_PER_THREAD; j+=2) {
+                let q_0 = src0[scale_idx + 2u + block_offset + j];
+                let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+
+                let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+                for (var k = 0u; k < 4u; k++) {
+                    let q_byte = get_byte_i32(q_packed, k);
+
+                    let q_val = f16(q_byte) * d + m;
+                    shmem[shmem_idx + j * 2 + k] = q_val;
+                }
+            }
+        }
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q8_1
+
+#ifdef INIT_SRC0_SHMEM_Q2_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 42u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    // Use standard thread layout instead of lane/row_group
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let d = src0[scale_idx + 40u];
+        let dmin = src0[scale_idx + 41u];
+
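+        // q2_K: every 16-element group has one packed byte with a 4-bit scale (low nibble) and a
+        // 4-bit min (high nibble); the dequantized weight is dl * q - ml with q a 2-bit quant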
+        // Decode the element at position k_in_block
+        let block_of_32 = k_in_block / 32u;
+        let pos_in_32 = k_in_block % 32u;
+
+        let q_b_idx = (block_of_32 / 4u) * 32u;
+        let shift = (block_of_32 % 4u) * 2u;
+        let k = (pos_in_32 / 16u) * 16u;
+        let l = pos_in_32 % 16u;
+
+        let is = k_in_block / 16u;
+
+        let sc_0 = src0[scale_idx + 2u * (is / 4u)];
+        let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u];
+        let sc_packed = bitcast<u32>(vec2<f16>(sc_0, sc_1));
+        let sc = get_byte(sc_packed, is % 4u);
+
+        let dl = d * f16(sc & 0xFu);
+        let ml = dmin * f16(sc >> 4u);
+
+        let q_idx = q_b_idx + k + l;
+        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
+        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
+        let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+        let q_byte = get_byte(q_packed, q_idx % 4u);
+        let qs_val = (q_byte >> shift) & 3u;
+
+        let q_val = f16(qs_val) * dl - ml;
+        shmem[elem_idx] = q_val;
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q2_K
+
+#ifdef INIT_SRC0_SHMEM_Q3_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 55u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let d = src0[scale_idx + 54u];
+
+        // Load and unpack scales
+        let kmask1: u32 = 0x03030303u;
+        let kmask2: u32 = 0x0f0f0f0fu;
+
+        var scale_vals: array<u32, 4>;
+        for (var i: u32 = 0u; i < 4u; i++) {
+            let scale_0 = src0[scale_idx + 48u + (2u*i)];
+            let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u];
+            scale_vals[i] = bitcast<u32>(vec2<f16>(scale_0, scale_1));
+        }
+
+        var tmp: u32 = scale_vals[2];
+        scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u);
+        scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u);
+        scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u);
+        scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u);
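+        // the 12 scale bytes pack 16 6-bit scales (low 4 bits in the first 8 bytes, high 2 bits in
+        // the last 4); the kmask shuffle above rebuilds one byte per 16-element group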
+
+        // Load hmask and qs arrays
+        var hmask_vals: array<u32, 8>;
+        for (var i: u32 = 0u; i < 8u; i++) {
+            let hmask_0 = src0[scale_idx + (2u*i)];
+            let hmask_1 = src0[scale_idx + (2u*i) + 1u];
+            hmask_vals[i] = bitcast<u32>(vec2<f16>(hmask_0, hmask_1));
+        }
+
+        var qs_vals: array<u32, 16>;
+        for (var i: u32 = 0u; i < 16u; i++) {
+            let qs_0 = src0[scale_idx + 16u + (2u*i)];
+            let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u];
+            qs_vals[i] = bitcast<u32>(vec2<f16>(qs_0, qs_1));
+        }
+
+        let half = k_in_block / 128u;           // 0 or 1
+        let pos_in_half = k_in_block % 128u;    // 0-127
+        let shift_group = pos_in_half / 32u;    // 0-3
+        let pos_in_32 = pos_in_half % 32u;      // 0-31
+        let k_group = pos_in_32 / 16u;          // 0 or 1
+        let l = pos_in_32 % 16u;                // 0-15
+
+        let q_b_idx = half * 32u;               // 0 or 32
+        let shift = shift_group * 2u;           // 0, 2, 4, 6
+        let k = k_group * 16u;                  // 0 or 16
+        let is = k_in_block / 16u;              // 0-15
+
+        // m increments every 32 elements across entire 256 element block
+        let m_shift = k_in_block / 32u;         // 0-7
+        let m: u32 = 1u << m_shift;             // 1,2,4,8,16,32,64,128
+
+        let sc = get_byte(scale_vals[is / 4u], is % 4u);
+        let dl = d * (f16(sc) - 32.0);
+
+        let q_idx = q_b_idx + k + l;
+        let hm_idx = k + l;
+
+        let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u);
+        let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u);
+
+        let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
+        let qs_val = (q_byte >> shift) & 3u;
+
+        let q_val = (f16(qs_val) - f16(hm)) * dl;
+        shmem[elem_idx] = q_val;
+    }
+}
+
+#endif // INIT_SRC0_SHMEM_Q3_K
+
+#ifdef INIT_SRC0_SHMEM_Q4_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 72u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let d = src0[scale_idx];
+        let dmin = src0[scale_idx + 1u];
+
+        // Load packed scales
+        var scale_vals: array<u32, 3>;
+        for (var i: u32 = 0u; i < 3u; i++) {
+            let scale_0 = src0[scale_idx + 2u + (2u*i)];
+            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
+            scale_vals[i] = bitcast<u32>(vec2<f16>(scale_0, scale_1));
+        }
+
+        // Map k_in_block to loop structure:
+        // Outer loop over 64-element groups (alternating q_b_idx)
+        // Inner loop over 2 shifts per group
+        let group_of_64 = k_in_block / 64u;  // 0-3 (maps to q_b_idx)
+        let pos_in_64 = k_in_block % 64u;    // 0-63
+        let shift_group = pos_in_64 / 32u;   // 0 or 1
+        let l = pos_in_64 % 32u;             // 0-31
+
+        let q_b_idx = group_of_64 * 32u;     // 0, 32, 64, 96
+        let shift = shift_group * 4u;        // 0 or 4
+        let is = k_in_block / 32u;           // 0-7
+
+        var sc: u32;
+        var mn: u32;
+
+        if (is < 4u) {
+            let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
+            let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
+            sc = sc_byte & 63u;
+            mn = min_byte & 63u;
+        } else {
+            let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
+            let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
+            let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
+
+            sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
+            mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
+        }
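+        // K-quant 6-bit scales/mins: groups 0-3 are stored directly in the low 6 bits, while groups
+        // 4-7 split their bits across the upper 2 bits of the earlier bytes, as unpacked above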
+
+        let dl = d * f16(sc);
+        let ml = dmin * f16(mn);
+
+        let q_idx = q_b_idx + l;
+        let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)];
+        let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u];
+        let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+
+        let q_byte = get_byte(q_packed, q_idx % 4u);
+        let qs_val = (q_byte >> shift) & 0xFu;
+
+        let q_val = f16(qs_val) * dl - ml;
+        shmem[elem_idx] = q_val;
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q4_K
+
+#ifdef INIT_SRC0_SHMEM_Q5_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 88u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let d = src0[scale_idx];
+        let dmin = src0[scale_idx + 1u];
+
+        // Load packed scales
+        var scale_vals: array<u32, 3>;
+        for (var i: u32 = 0u; i < 3u; i++) {
+            let scale_0 = src0[scale_idx + 2u + (2u*i)];
+            let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u];
+            scale_vals[i] = bitcast<u32>(vec2<f16>(scale_0, scale_1));
+        }
+
+        // The original loop processes elements in groups of 64
+        // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4]
+        // But u increments EVERY 32 elements (after each l loop)
+        let group_of_64 = k_in_block / 64u;  // 0-3
+        let pos_in_64 = k_in_block % 64u;    // 0-63
+        let shift_group = pos_in_64 / 32u;   // 0 or 1
+        let l = pos_in_64 % 32u;             // 0-31
+
+        let q_b_idx = group_of_64 * 32u;     // 0, 32, 64, 96
+        let shift = shift_group * 4u;        // 0 or 4
+        let is = k_in_block / 32u;           // 0-7
+
+        // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128)
+        let u_shift = k_in_block / 32u;      // 0-7
+        let u: u32 = 1u << u_shift;
+
+        var sc: u32;
+        var mn: u32;
+
+        if (is < 4u) {
+            let sc_byte = get_byte(scale_vals[is / 4u], is % 4u);
+            let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u);
+            sc = sc_byte & 63u;
+            mn = min_byte & 63u;
+        } else {
+            let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u);
+            let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u);
+            let min_hi = get_byte(scale_vals[is / 4u], is % 4u);
+
+            sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u);
+            mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u);
+        }
+
+        let dl = d * f16(sc);
+        let ml = dmin * f16(mn);
+
+        let q_idx = q_b_idx + l;
+        let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)];
+        let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u];
+        let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+
+        let q_byte = get_byte(q_packed, q_idx % 4u);
+
+        let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)];
+        let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u];
+        let qh_packed = bitcast<u32>(vec2<f16>(qh_0, qh_1));
+
+        let qh_byte = get_byte(qh_packed, l % 4u);
+
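+        // the per-element bit from qh adds 16 to the 4-bit quant, giving the full 5-bit q5_K value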
+        let qs_val = (q_byte >> shift) & 0xFu;
+        let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
+
+        let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml;
+        shmem[elem_idx] = q_val;
+    }
+}
+
+#endif // INIT_SRC0_SHMEM_Q5_K
+
+#ifdef INIT_SRC0_SHMEM_Q6_K
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
+
+fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
+    for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) {
+        let tile_m = elem_idx / TILE_K;
+        let tile_k = elem_idx % TILE_K;
+
+        let global_m = offset_m + tile_m;
+        let global_k = k_outer + tile_k;
+
+        if (global_m >= params.m || global_k >= params.k) {
+            shmem[elem_idx] = f16(0.0);
+            continue;
+        }
+
+        let block_k = global_k / BLOCK_SIZE;
+        let k_in_block = global_k % BLOCK_SIZE;
+
+        let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
+        let scale_idx = src0_idx * F16_PER_BLOCK;
+
+        let half = k_in_block / 128u;
+        let pos_in_half = k_in_block % 128u;
+        let quarter = pos_in_half / 32u;
+        let l = pos_in_half % 32u;
+
+        let ql_b_idx = half * 64u;
+        let qh_b_idx = half * 32u;
+        let sc_b_idx = half * 8u;
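+        // q6_K splits the 256-element block into two 128-element halves; each half owns 64 ql bytes,
+        // 32 qh bytes and 8 scale bytes, hence the per-half base offsets above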
+
+        // Load only ql13 word needed
+        let ql13_flat = ql_b_idx + l;
+        let ql13_word = ql13_flat / 4u;
+        let ql13 = bitcast<u32>(vec2<f16>(
+            src0[scale_idx + 2u * ql13_word],
+            src0[scale_idx + 2u * ql13_word + 1u]
+        ));
+        let ql13_b = get_byte(ql13, ql13_flat % 4u);
+
+        // Load only ql24 word needed
+        let ql24_flat = ql_b_idx + l + 32u;
+        let ql24_word = ql24_flat / 4u;
+        let ql24 = bitcast<u32>(vec2<f16>(
+            src0[scale_idx + 2u * ql24_word],
+            src0[scale_idx + 2u * ql24_word + 1u]
+        ));
+        let ql24_b = get_byte(ql24, ql24_flat % 4u);
+
+        // Load only qh word needed
+        let qh_flat = qh_b_idx + l;
+        let qh_word = qh_flat / 4u;
+        let qh = bitcast<u32>(vec2<f16>(
+            src0[scale_idx + 64u + 2u * qh_word],
+            src0[scale_idx + 64u + 2u * qh_word + 1u]
+        ));
+        let qh_b = get_byte(qh, qh_flat % 4u);
+
+        let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0);
+        let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0);
+        let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0);
+        let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0);
+
+        // Load only the scale word needed
+        let is = l / 16u;
+        let sc_idx = sc_b_idx + is + quarter * 2u;
+        let sc_word = sc_idx / 4u;
+        let sc = bitcast<u32>(vec2<f16>(
+            src0[scale_idx + 96u + 2u * sc_word],
+            src0[scale_idx + 96u + 2u * sc_word + 1u]
+        ));
+        let sc_val = get_byte_i32(sc, sc_idx % 4u);
+
+        let d = src0[scale_idx + 104u];
+
+        var q_val: f16;
+        if (quarter == 0u) {
+            q_val = q1;
+        } else if (quarter == 1u) {
+            q_val = q2;
+        } else if (quarter == 2u) {
+            q_val = q3;
+        } else {
+            q_val = q4;
+        }
+
+        shmem[elem_idx] = d * f16(sc_val) * q_val;
+    }
+}
+#endif // INIT_SRC0_SHMEM_Q6_K
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
similarity index 53%
rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl
rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
index 6b1dd26c..b1da421a 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl
@@ -1,115 +1,19 @@
-#define(VARIANTS)
-[
-  {
-    "SHADER_SUFFIX": "f32_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE" : "vec4",
-      "SHMEM_TYPE" : "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f32",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE" : "vec4",
-      "SHMEM_TYPE" : "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE" : "vec4",
-      "SHMEM_TYPE" : "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f16",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE" : "vec4",
-      "SHMEM_TYPE" : "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
-  }
-]
+enable f16;
 
-#end(VARIANTS)
+#include "common_decls.tmpl"
+#include "mul_mat_decls.tmpl"
 
-#define(DECLS)
-
-#decl(VEC)
+#ifdef VEC
 fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
     return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
 }
-#enddecl(VEC)
+#endif
 
-#decl(SCALAR)
+#ifdef SCALAR
 fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
     return f32(acc[tm][tn]);
 }
-#enddecl(SCALAR)
-
-#end(DECLS)
-
-#define(SHADER)
-enable f16;
+#endif
 
 struct MulMatParams {
     offset_src0: u32,
@@ -130,14 +34,12 @@ struct MulMatParams {
     broadcast3: u32
 };
 
-@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns
-@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
-@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
+@group(0) @binding(0) var src0: array<SRC0_TYPE>; // M rows, K columns
+@group(0) @binding(1) var src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
+@group(0) @binding(2) var dst: array<DST_TYPE>; // M rows, N columns (transposed)
 
 @group(0) @binding(3) var params: MulMatParams;
 
-DECLS
-
 fn get_local_n(thread_id: u32) -> u32 {
     return thread_id / WORKGROUP_SIZE_M;
 }
@@ -145,23 +47,16 @@ fn get_local_m(thread_id: u32) -> u32 {
     return thread_id % WORKGROUP_SIZE_M;
 }
 
-// TILE_M must be multiple of 4 for vec4 loads
-const TILE_M = {{WEBGPU_TILE_M}}u;
-const TILE_N = {{WEBGPU_TILE_N}}u;
-
-override WORKGROUP_SIZE_M: u32;
-override WORKGROUP_SIZE_N: u32;
-override TILE_K: u32;
-
-override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
-override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
-override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
+const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
+const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
+const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
 
 var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
 
 @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
 fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
-        @builtin(local_invocation_id) local_id: vec3<u32>) {
+        @builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(num_workgroups) num_wg: vec3<u32>) {
 
     let thread_id = local_id.x;
     let local_m = get_local_m(thread_id);
@@ -171,9 +66,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
     let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
     let wg_per_matrix = wg_m_count * wg_n_count;
 
-    let batch_idx = wg_id.x / wg_per_matrix;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
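+    // the dispatch may be spread over the x and y dimensions to stay under per-dimension limits,
+    // so linearize the workgroup id before splitting it into batch and tile indices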
 
-    let wg_in_batch = wg_id.x % wg_per_matrix;
+    let batch_idx = wg_linear / wg_per_matrix;
+
+    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    if (batch_idx >= total_batches) {
+        return;
+    }
+
+    let wg_in_batch = wg_linear % wg_per_matrix;
     let wg_m = wg_in_batch % wg_m_count;
     let wg_n = wg_in_batch / wg_m_count;
 
@@ -233,15 +135,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
     for (var tn = 0u; tn < TILE_N; tn++) {
         let global_col = output_col_base + tn;
         if (global_col < params.n) {
-            for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
+            for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
                 let global_row = output_row_base + tm;
                 if (global_row < params.m) {
                     let dst_idx = dst_batch_offset + global_col * params.m + global_row;
-                    dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
+                    dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
                 }
             }
         }
     }
 }
-
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
similarity index 64%
rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
index 47c8ce36..9f9ef279 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl
@@ -1,100 +1,12 @@
-#define(VARIANTS)
-[
-  {
-    "SHADER_SUFFIX": "f32_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE" : "vec4",
-      "SHMEM_TYPE" : "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f32",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE" : "vec4",
-      "SHMEM_TYPE" : "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE" : "vec4",
-      "SHMEM_TYPE" : "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f16",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE" : "vec4",
-      "SHMEM_TYPE" : "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE" : "f32",
-      "SHMEM_TYPE" : "f16",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
-  }
-]
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
 
-#end(VARIANTS)
+#include "common_decls.tmpl"
+#include "mul_mat_decls.tmpl"
 
-#define(DECLS)
-
-#decl(VEC)
+#ifdef VEC
 fn store_dst(shmem_idx: u32, dst_idx: u32) {
     dst[dst_idx] = vec4(
         f32(shmem[shmem_idx]),
@@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) {
         f32(shmem[shmem_idx + 3])
     );
 }
-#enddecl(VEC)
+#endif
 
-#decl(SCALAR)
+#ifdef SCALAR
 fn store_dst(shmem_idx: u32, dst_idx: u32) {
     dst[dst_idx] = f32(shmem[shmem_idx]);
 }
-#enddecl(SCALAR)
-
-#end(DECLS)
-
-#define(SHADER)
-diagnostic(off, chromium.subgroup_matrix_uniformity);
-enable f16;
-enable subgroups;
-enable chromium_experimental_subgroup_matrix;
+#endif
 
 struct MulMatParams {
     offset_src0: u32,
@@ -138,36 +42,19 @@ struct MulMatParams {
     broadcast3: u32
 };
 
-@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns
-@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
-@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
+// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
+@group(0) @binding(0) var src0: array<SRC0_TYPE>; // M rows, K columns
+@group(0) @binding(1) var src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
+@group(0) @binding(2) var dst: array<DST_TYPE>; // M rows, N columns (transposed)
 
 @group(0) @binding(3) var params: MulMatParams;
 
-DECLS
-
-// Note: These are string interpolated at build time, cannot use override constants due to limitations in
-// current Dawn version type definitions/matrix load requirements for constant memory sizes.
-const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
-const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
-// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
-// runtime subgroup size is smaller.
-const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
-
-const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
-
-const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
-const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
-const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
-
-const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
-const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
-
-const TILE_K = {{WEBGPU_TILE_K}}u;
-
 const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
 const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
 
+// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
+// runtime subgroup size is smaller.
+const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
 const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
 const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
 const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
@@ -182,7 +69,8 @@ var shmem: array;
 @compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
 fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
         @builtin(local_invocation_id) local_id: vec3<u32>,
-        @builtin(subgroup_id) subgroup_id: u32) {
+        @builtin(subgroup_id) subgroup_id: u32,
+        @builtin(num_workgroups) num_wg: vec3<u32>) {
 
     let thread_id = local_id.x;
     let subgroup_m = subgroup_id % SUBGROUP_M;
@@ -192,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
     let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
     let wg_per_matrix = wg_m_count * wg_n_count;
 
-    let batch_idx = wg_id.x / wg_per_matrix;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
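+    // as in the register-tile shader, the dispatch may span x and y, so derive a linear workgroup
+    // index before decomposing it into batch and tile coordinates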
 
-    let wg_in_batch = wg_id.x % wg_per_matrix;
+    let batch_idx = wg_linear / wg_per_matrix;
+
+    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    if (batch_idx >= total_batches) {
+        return;
+    }
+
+    let wg_in_batch = wg_linear % wg_per_matrix;
     let wg_m = wg_in_batch % wg_m_count;
     let wg_n = wg_in_batch / wg_m_count;
 
@@ -285,7 +180,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
     let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
     let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
 
-    for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
+    for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
         let local_row = idx % WG_TILE_STRIDE;
         let local_col = idx / WG_TILE_STRIDE;
 
@@ -294,9 +189,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3,
 
         if (global_col < params.n && global_row < params.m) {
             let dst_idx = dst_batch_offset + global_col * params.m + global_row;
-            store_dst(idx, dst_idx/{{VEC_SIZE}});
+            store_dst(idx, dst_idx/VEC_SIZE);
         }
     }
 }
 
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
deleted file mode 100644
index ffbb6403..00000000
--- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl
+++ /dev/null
@@ -1,267 +0,0 @@
-#define(VARIANTS)
-[
-  {
-    "SHADER_SUFFIX": "f32_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE": "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f32_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f32",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE": "f32",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE": "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE": "f32",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16_vec",
-    "REPLS": {
-      "SRC0_TYPE" : "vec4",
-      "SRC1_TYPE" : "vec4",
-      "DST_TYPE": "vec4",
-      "VEC_SIZE" : 4,
-    },
-    "DECLS": ["VEC", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "f16_f16",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f16",
-      "DST_TYPE": "f32",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
-  },
-  {
-    "SHADER_SUFFIX": "q4_0_f32",
-    "REPLS": {
-      "SRC0_TYPE" : "f16",
-      "SRC1_TYPE" : "f32",
-      "DST_TYPE": "f32",
-      "VEC_SIZE" : 1,
-    },
-    "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(VEC)
-fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
-    return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
-}
-
-fn store_val(group_base: u32) -> vec4 {
-    return vec4(partial_sums[group_base],
-                     partial_sums[group_base + THREADS_PER_OUTPUT],
-                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],
-                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
-}
-#enddecl(VEC)
-
-#decl(SCALAR)
-fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
-    return f32(src0_val) * f32(src1_val);
-}
-
-fn store_val(group_base: u32) -> f32 {
-    return partial_sums[group_base];
-}
-#enddecl(SCALAR)
-
-#decl(MUL_ACC_FLOAT)
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
-        let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
-        let b = shared_vector[i / {{VEC_SIZE}}];
-        local_sum += inner_dot(a, b);
-    }
-    return local_sum;
-}
-
-#enddecl(MUL_ACC_FLOAT)
-
-#decl(MUL_ACC_Q4_0)
-
-const BLOCK_SIZE = 32;
-const NQ = 16u; // number of weights per thread
-const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
-const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
-const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
-
-fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
-    var local_sum = 0.0;
-    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
-        let blck_idx = i / BLOCK_SIZE;
-        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
-        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
-        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
-        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
-        let d = f32(src0[scale_idx]);
-        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
-            let q_0 = src0[scale_idx + 1 + block_offset + j];
-            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
-            let q_packed = bitcast(vec2(q_0, q_1));
-            for (var k: u32 = 0; k < 4; k++) {
-                let q_byte = get_byte(q_packed, k);
-                let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
-                let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
-                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
-                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
-            }
-        }
-    }
-    return local_sum;
-}
-
-#enddecl(MUL_ACC_Q4_0)
-
-#end(DECLS)
-
-#define(SHADER)
-enable f16;
-
-DECLS
-
-struct MulMatParams {
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-    m: u32,
-    n: u32,
-    k: u32,
-    stride_01: u32,
-    stride_11: u32,
-    stride_02: u32,
-    stride_12: u32,
-    stride_03: u32,
-    stride_13: u32,
-    bs02: u32,
-    bs03: u32,
-    broadcast2: u32,
-    broadcast3: u32
-};
-
-@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
-@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
-@group(0) @binding(2) var dst: array<{{DST_TYPE}}>;  // Result vector (transposed)
-
-@group(0) @binding(3) var params: MulMatParams;
-
-override WORKGROUP_SIZE: u32;
-override TILE_K: u32;
-override OUTPUTS_PER_WG: u32;
-override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
-
-// Shared memory for collaborative loading and reduction
-var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>;  // Cache vector tile
-var partial_sums: array;   // For reduction
-
-@compute @workgroup_size(WORKGROUP_SIZE)
-fn main(
-    @builtin(local_invocation_id) local_id: vec3,
-    @builtin(workgroup_id) wg_id: vec3,
-    @builtin(num_workgroups) num_wg: vec3) {
-    let thread_id = local_id.x;
-
-    // Handle batch dimensions
-    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
-    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
-    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
-    let batch_idx = wg_linear / output_groups;
-    if (batch_idx >= total_batches) {
-        return;
-    }
-
-    // Which of the outputs does this thread belong to?
-    let thread_group = thread_id / THREADS_PER_OUTPUT;
-    let thread_in_group = thread_id % THREADS_PER_OUTPUT;
-
-    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
-    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
-
-    let dst2_stride = params.m * params.n;
-    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
-    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
-    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
-    let src03_idx = dst3_idx / params.broadcast3;
-    let src13_idx = dst3_idx;
-    let src02_idx = dst2_idx / params.broadcast2;
-    let src12_idx = dst2_idx;
-
-    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
-    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
-    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
-
-    var local_sum = 0.0;
-
-    // Each thread processes multiple K elements and accumulates
-    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
-        let tile_size = min(TILE_K, params.k - k_tile);
-
-        // Cooperatively load vector tile into shared memory (all threads)
-        for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
-            shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
-        }
-
-        workgroupBarrier();
-
-        if (output_row < params.m) {
-            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
-        }
-
-        workgroupBarrier();
-    }
-
-    // Store partial sums and reduce within each partition
-    partial_sums[thread_id] = local_sum;
-    workgroupBarrier();
-    let group_base = thread_group * THREADS_PER_OUTPUT;
-    let thread_base = group_base + thread_in_group;
-    var offset = THREADS_PER_OUTPUT / 2;
-    while (offset > 0) {
-        if (thread_in_group < offset) {
-            partial_sums[thread_base] += partial_sums[thread_base + offset];
-        }
-        offset = offset / 2;
-        workgroupBarrier();
-    }
-
-    // Store back to global memory
-    if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
-        dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
-    }
-}
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
new file mode 100644
index 00000000..94f4bae1
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
@@ -0,0 +1,480 @@
+enable f16;
+
+#include "common_decls.tmpl"
+
+#ifdef VEC
+
+#define VEC_SIZE 4
+#define DST_TYPE vec4<f32>
+#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
+#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
+
+fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
+    return f32(dot(SRC1_TYPE(src0_val), src1_val));
+}
+
+fn store_val(group_base: u32) -> vec4<f32> {
+    return vec4(partial_sums[group_base],
+                     partial_sums[group_base + THREADS_PER_OUTPUT],
+                     partial_sums[group_base + THREADS_PER_OUTPUT * 2],
+                     partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
+}
+#endif
+
+#ifdef SCALAR
+
+#define VEC_SIZE 1
+#define DST_TYPE f32
+#define SRC0_TYPE SRC0_INNER_TYPE
+#define SRC1_TYPE SRC1_INNER_TYPE
+
+fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
+    return f32(src0_val) * f32(src1_val);
+}
+
+fn store_val(group_base: u32) -> f32 {
+    return partial_sums[group_base];
+}
+#endif
+
+#ifdef MUL_ACC_FLOAT
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {
+        let a = src0[(idx_base + k_outer + i) / VEC_SIZE];
+        let b = shared_vector[i / VEC_SIZE];
+        local_sum += inner_dot(a, b);
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q4_0
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 1 + block_offset + j];
+            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+                let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d;
+                let q_lo = (f32(q_byte & 0xF) - 8.0) * d;
+                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q4_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 10u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = f32(src0[scale_idx + 1u]);
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 2u + block_offset + j];
+            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+                let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
+                let q_lo = f32(q_byte & 0xF) * d + m;
+                local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q5_0
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 11u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let qh0 = src0[scale_idx + 1u];
+        let qh1 = src0[scale_idx + 2u];
+        let qh_packed = bitcast<u32>(vec2<f16>(qh0, qh1));
+
+        for (var j = 0u; j < 2; j++) {
+            let q_0 = src0[scale_idx + 3u + block_offset + (j*2)];
+            let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+
+            let j_adjusted = j + (block_offset / 2u);
+
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+
+                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
+                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
+
+                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+            }
+
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q5_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 12u;
+const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17]
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = src0[scale_idx + 1u];
+        let qh0 = src0[scale_idx + 2u];
+        let qh1 = src0[scale_idx + 3u];
+        let qh_packed = bitcast<u32>(vec2<f16>(qh0, qh1));
+
+        for (var j = 0u; j < 2; j++) {
+            let q_0 = src0[scale_idx + 4u + block_offset + (j*2)];
+            let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u];
+            let q_packed = bitcast<u32>(vec2<f16>(q_0, q_1));
+
+            let j_adjusted = j + (block_offset / 2u);
+
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte(q_packed, k);
+
+                let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10;
+                let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m);
+                let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10;
+                let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m);
+
+                local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k];
+                local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16];
+            }
+
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q8_0
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 17u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
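+// Q8_0 block layout (34 bytes = 17 f16): d (f16 scale) followed by 32 signed 8-bit weights; each weight dequantizes as q * d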
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 packs two consecutive weights; this thread's weights start at offset 2 * block_offset within the block
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 1 + block_offset + j];
+            let q_1 = src0[scale_idx + 1 + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte_i32(q_packed, k);
+                let q_val = f32(q_byte) * d;
+                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+
+#ifdef MUL_ACC_Q8_1
+
+const BLOCK_SIZE = 32;
+const NQ = 16u; // number of weights per thread
+const F16_PER_BLOCK = 18u;
+const WEIGHTS_PER_F16 = 2u;
+const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
+
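+// Q8_1 block layout (36 bytes = 18 f16): two f16 values (d and m here) followed by 32 signed 8-bit weights,
+// dequantized below as q * d + m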
+fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    var local_sum = 0.0;
+    for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) {
+        let blck_idx = i / BLOCK_SIZE;
+        let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
+        let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK;
+        // each f16 packs two consecutive weights; this thread's weights start at offset 2 * block_offset within the block
+        let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
+        let d = f32(src0[scale_idx]);
+        let m = src0[scale_idx + 1u];
+
+        for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+            let q_0 = src0[scale_idx + 2u + block_offset + j];
+            let q_1 = src0[scale_idx + 2u + block_offset + j + 1];
+            let q_packed = bitcast<u32>(vec2(q_0, q_1));
+            for (var k: u32 = 0; k < 4; k++) {
+                let q_byte = get_byte_i32(q_packed, k);
+                let q_val = f32(q_byte) * d + f32(m);
+                local_sum += q_val * shared_vector[shmem_idx + j * 2 + k];
+            }
+        }
+    }
+    return local_sum;
+}
+#endif
+
+#ifdef MUL_ACC_Q6_K
+
+const BLOCK_SIZE = 256u;
+const F16_PER_BLOCK = 105u;
+
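+// load 4 consecutive bytes of a block as a u32; byte_offset is rounded down to 4-byte alignment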
+fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 {
+    let aligned = byte_offset & ~3u;
+    let idx = bbase + aligned / 2u;
+    return bitcast<u32>(vec2(src0[idx], src0[idx + 1u]));
+}
+
+fn byte_of(v: u32, b: u32) -> u32 {
+    return (v >> (b * 8u)) & 0xFFu;
+}
+
+fn sbyte_of(v: u32, b: u32) -> i32 {
+    let raw = i32((v >> (b * 8u)) & 0xFFu);
+    return select(raw, raw - 256, raw >= 128);
+}
+
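+// Q6_K superblock layout (210 bytes = 105 f16): ql[128] low 4 bits, qh[64] upper 2 bits,
+// 16 signed 8-bit scales, then d (f16) at byte offset 208; each weight dequantizes as (q6 - 32) * scale * d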
+fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
+    let tid = tig / 2u;
+    let ix  = tig % 2u;
+    let ip  = tid / 8u;
+    let il  = tid % 8u;
+    let l0  = 4u * il;
+    let is  = 8u * ip + l0 / 16u;
+
+    let y_offset   = 128u * ip + l0;
+    let q_offset_l =  64u * ip + l0;
+    let q_offset_h =  32u * ip + l0;
+
+    let nb = tile_size / BLOCK_SIZE;
+    let k_block_start = k_outer / BLOCK_SIZE;
+
+    // Aligned scale byte position (is can be odd)
+    let sc_base_byte = 192u + (is & ~3u);
+    let sc_byte_pos  = is & 3u;
+
+    var local_sum = 0.0;
+
+    for (var i = ix; i < nb; i += 2u) {
+        let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK;
+
+        let d_raw = load_u32_at(bbase, 208u);
+        let d = f32(bitcast<vec2<f16>>(d_raw)[0]);
+
+        let ql1_u32  = load_u32_at(bbase, q_offset_l);
+        let ql2_u32  = load_u32_at(bbase, q_offset_l + 32u);
+        let qh_u32   = load_u32_at(bbase, 128u + q_offset_h);
+        let sc_u32_0 = load_u32_at(bbase, sc_base_byte);
+        let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u);
+
+        let sc0 = sbyte_of(sc_u32_0, sc_byte_pos);
+        let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
+        let sc4 = sbyte_of(sc_u32_1, sc_byte_pos);
+        let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u);
+
+        var sums = vec4(0.0, 0.0, 0.0, 0.0);
+
+        for (var l = 0u; l < 4u; l++) {
+            let y_base = i * BLOCK_SIZE + y_offset + l;
+            let yl0 = f32(shared_vector[y_base]);
+            let yl1 = f32(shared_vector[y_base + 32u]);
+            let yl2 = f32(shared_vector[y_base + 64u]);
+            let yl3 = f32(shared_vector[y_base + 96u]);
+
+            let q1b = byte_of(ql1_u32, l);
+            let q2b = byte_of(ql2_u32, l);
+            let qhb = byte_of(qh_u32,  l);
+
+            let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32);
+            let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32);
+            let dq2 = f32(i32((q1b >>   4u) | ((qhb & 0x30u)       )) - 32);
+            let dq3 = f32(i32((q2b >>   4u) | ((qhb & 0xC0u) >> 2u)) - 32);
+
+            sums[0] += yl0 * dq0;
+            sums[1] += yl1 * dq1;
+            sums[2] += yl2 * dq2;
+            sums[3] += yl3 * dq3;
+        }
+
+        local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) +
+                          sums[2] * f32(sc4) + sums[3] * f32(sc6));
+    }
+
+    return local_sum;
+}
+#endif
+
+struct MulMatParams {
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+    m: u32,
+    n: u32,
+    k: u32,
+    stride_01: u32,
+    stride_11: u32,
+    stride_02: u32,
+    stride_12: u32,
+    stride_03: u32,
+    stride_13: u32,
+    bs02: u32,
+    bs03: u32,
+    broadcast2: u32,
+    broadcast3: u32
+};
+
+// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
+@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
+@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
+
+@group(0) @binding(3) var<uniform> params: MulMatParams;
+
+const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
+
+// Shared memory for collaborative loading and reduction
+var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K / VEC_SIZE>;  // Cache vector tile
+var<workgroup> partial_sums: array<f32, WG_SIZE>;   // For reduction
+
+@compute @workgroup_size(WG_SIZE)
+fn main(
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(workgroup_id) wg_id: vec3<u32>,
+    @builtin(num_workgroups) num_wg: vec3<u32>) {
+    let thread_id = local_id.x;
+
+    // Handle batch dimensions
+    let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3;
+    let wg_linear = wg_id.y * num_wg.x + wg_id.x;
+    let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG;
+    let batch_idx = wg_linear / output_groups;
+    if (batch_idx >= total_batches) {
+        return;
+    }
+
+    // Which of the outputs does this thread belong to?
+    let thread_group = thread_id / THREADS_PER_OUTPUT;
+    let thread_in_group = thread_id % THREADS_PER_OUTPUT;
+
+    // Each workgroup computes OUTPUTS_PER_WG consecutive outputs
+    let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group;
+
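+    // Map the flat batch index onto src0/src1 batch indices, accounting for broadcasting along dims 2 and 3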
+    let dst2_stride = params.m * params.n;
+    let dst2_idx = batch_idx % (params.bs02 * params.broadcast2);
+    let dst3_stride = dst2_stride * params.bs02 * params.broadcast2;
+    let dst3_idx = batch_idx / (params.bs02 * params.broadcast2);
+    let src03_idx = dst3_idx / params.broadcast3;
+    let src13_idx = dst3_idx;
+    let src02_idx = dst2_idx / params.broadcast2;
+    let src12_idx = dst2_idx;
+
+    let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01;
+    let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12;
+    let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row;
+
+    var local_sum = 0.0;
+
+    // Each thread processes multiple K elements and accumulates
+    for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) {
+        let tile_size = min(TILE_K, params.k - k_tile);
+
+        // Cooperatively load vector tile into shared memory (all threads)
+        for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
+            shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
+        }
+
+        workgroupBarrier();
+
+        if (output_row < params.m) {
+            local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile);
+        }
+
+        workgroupBarrier();
+    }
+
+    // Store partial sums and reduce within each partition
+    partial_sums[thread_id] = local_sum;
+    workgroupBarrier();
+    let group_base = thread_group * THREADS_PER_OUTPUT;
+    let thread_base = group_base + thread_in_group;
+    var offset: u32 = THREADS_PER_OUTPUT / 2;
+    while (offset > 0) {
+        if (thread_in_group < offset) {
+            partial_sums[thread_base] += partial_sums[thread_base + offset];
+        }
+        offset = offset / 2;
+        workgroupBarrier();
+    }
+
+    // Store back to global memory
+    if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) {
+        dst[dst_idx / VEC_SIZE] = store_val(group_base);
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl
new file mode 100644
index 00000000..ea63b9a7
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl
@@ -0,0 +1,86 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+    ne: u32,            // total number of elements
+    offset_src: u32,    // in elements
+    offset_dst: u32,    // in elements
+
+    // Strides (in elements)
+    stride_src0: u32,
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    // Logical shapes
+    src_ne0: u32,
+    src_ne1: u32,
+    src_ne2: u32,
+    src_ne3: u32,
+
+    dst_ne0: u32,
+    dst_ne1: u32,
+    dst_ne2: u32,
+    dst_ne3: u32,
+
+    // Pad sizes (in elements)
+    lp0: u32,
+    rp0: u32,
+    lp1: u32,
+    rp1: u32,
+    lp2: u32,
+    rp2: u32,
+    lp3: u32,
+    rp3: u32,
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
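+// map a possibly negative index into [0, n) for circular padding (assumes idx >= -n)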
+fn wrap_around(idx: i32, n: u32) -> u32 {
+    return u32(idx + i32(n)) % n;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0;
+    let i3 = i / dst_plane;
+    i = i % dst_plane;
+    let i2 = i / (params.dst_ne1 * params.dst_ne0);
+    i = i % (params.dst_ne1 * params.dst_ne0);
+    let i1 = i / params.dst_ne0;
+    let i0 = i % params.dst_ne0;
+
+    var value: f32 = 0.0;
+
+#ifdef CIRCULAR
+    let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0);
+    let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1);
+    let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2);
+    let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3);
+    let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 +
+                           ci2 * params.stride_src2 + ci3 * params.stride_src3;
+    value = src[params.offset_src + circular_src_idx];
+#else
+    let is_src =
+        (i0 >= params.lp0 && i0 < params.dst_ne0 - params.rp0) &&
+        (i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) &&
+        (i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) &&
+        (i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3);
+    if (is_src) {
+        let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 +
+                      (i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3;
+        value = src[params.offset_src + src_idx];
+    }
+#endif
+
+    dst[params.offset_dst + gid.x] = value;
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl
new file mode 100644
index 00000000..6e2a1a8b
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl
@@ -0,0 +1,67 @@
+enable f16;
+
+struct Params {
+    ne: u32,
+
+    offset_src0: u32,
+    offset_dst: u32,
+
+    stride_src0_0: u32,
+    stride_src0_1: u32,
+    stride_src0_2: u32,
+    stride_src0_3: u32,
+
+    a_ne0: u32,
+    a_ne1: u32,
+    a_ne2: u32,
+    a_ne3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+};
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_I32
+#define DataType i32
+#endif
+#ifdef TYPE_I16
+// same size (16-bit) is sufficient for repeat
+#define DataType f16
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        var i = gid.x;
+        let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+        i = i % (params.ne2 * params.ne1 * params.ne0);
+        let i2 = i / (params.ne1 * params.ne0);
+        i = i % (params.ne1 * params.ne0);
+        let i1 = i / params.ne0;
+        let i0 = i % params.ne0;
+
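+        // wrap each dst coordinate back into the src shape, so the src tensor repeats along every dimension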
+        let a_i0 = i0 % params.a_ne0;
+        let a_i1 = i1 % params.a_ne1;
+        let a_i2 = i2 % params.a_ne2;
+        let a_i3 = i3 % params.a_ne3;
+
+        let a_index = a_i0 * params.stride_src0_0 +
+                           a_i1 * params.stride_src0_1 +
+                           a_i2 * params.stride_src0_2 +
+                           a_i3 * params.stride_src0_3;
+
+        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index];
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl
similarity index 78%
rename from ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
rename to ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl
index 040e80df..3b70a876 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl
@@ -1,21 +1,11 @@
-#define(VARIANTS)
+#ifdef INPLACE
+@group(0) @binding(1)
+var<uniform> params: Params;
 
-[
-  {
-    "SHADER_NAME": "scale_f32",
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "scale_f32_inplace",
-    "DECLS": ["INPLACE"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(NOT_INPLACE)
+fn store_scale(val: f32, offset: u32) {
+    src[offset] = val;
+}
+#else
 @group(0) @binding(1)
 var dst: array;
 
@@ -25,20 +15,7 @@ var params: Params;
 fn store_scale(val: f32, offset: u32) {
     dst[offset] = val;
 }
-#enddecl(NOT_INPLACE)
-
-#decl(INPLACE)
-@group(0) @binding(1)
-var params: Params;
-
-fn store_scale(val: f32, offset: u32) {
-    src[offset] = val;
-}
-#enddecl(INPLACE)
-
-#end(DECLS)
-
-#define(SHADER)
+#endif
 
 struct Params {
     offset_src: u32,
@@ -65,10 +42,7 @@ struct Params {
 @group(0) @binding(0)
 var src: array;
 
-DECLS
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
+@compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3) {
     if (gid.x >= params.ne) {
         return;
@@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3) {
 
     store_scale(src[i_src] * params.scale + params.bias, i_dst);
 }
-#end(SHADER)
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl
deleted file mode 100644
index fca3be6b..00000000
--- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl
+++ /dev/null
@@ -1,112 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "SHADER_SUFFIX": "f16_vec",
-    "REPLS": {
-      "TYPE" : "vec4",
-      "DST_TYPE": "vec4",
-      "VEC_SIZE": 4
-    }
-  },
-  {
-    "SHADER_SUFFIX": "f16",
-    "REPLS": {
-      "TYPE" : "f32",
-      "DST_TYPE": "f16",
-      "VEC_SIZE": 1
-    }
-  }
-]
-
-#end(VARIANTS)
-
-#define(SHADER)
-
-enable f16;
-
-@group(0) @binding(0)
-var src: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var idx: array;
-
-@group(0) @binding(2)
-var dst: array<{{DST_TYPE}}>;
-
-@group(0) @binding(3)
-var error: atomic;
-
-struct Params {
-    offset_src: u32, // in elements
-    offset_idx: u32, // in elements
-    offset_dst: u32, // in elements
-
-    // Strides (in elements)
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_idx0: u32,
-    stride_idx1: u32,
-    stride_idx2: u32,
-
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    // Shape of src
-    ne0: u32,
-    n_rows: u32,
-    ne2: u32,
-    ne3: u32,
-
-    // Shape of idx
-    idx1: u32,
-    idx2: u32,
-};
-
-@group(0) @binding(4)
-var params: Params;
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3) {
-    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
-        return;
-    }
-
-    // getting the row from gid
-    let elems_per_row = params.ne0 / {{VEC_SIZE}};
-    var i = gid.x / elems_per_row;
-
-    let i_src3 = i / (params.ne2 * params.n_rows);
-
-    i = i % (params.ne2 * params.n_rows);
-    let i_src2 = i / params.n_rows;
-    let i_src1 = i % params.n_rows;
-
-    let i_idx2 = i_src3 % params.idx2;
-    let i_idx1 = i_src2 % params.idx1;
-    let i_idx0 = i_src1;
-
-    let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
-
-    let idx_high_val = idx[idx_high];
-    let idx_low_val = idx[idx_high + 1];
-
-    if (idx_low_val != 0) {
-        // Upper bits of index are not zero, output will be incorrect
-        atomicStore(&error, 1);
-        return;
-    }
-
-    let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
-    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
-
-    let col_idx = (gid.x % elems_per_row);
-    dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
-}
-
-#end(SHADER)
-
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
index 3567713d..99e9192c 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl
@@ -1,16 +1,37 @@
 enable f16;
 
+#ifdef DST_F32
+#define DST_INNER_TYPE f32
+#else
+#define DST_INNER_TYPE f16
+#endif
+
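+// Vectorized variant: rows are copied in chunks of 4 elements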
+#ifdef VEC4
+#define SRC_TYPE vec4<f32>
+#define DST_TYPE vec4<DST_INNER_TYPE>
+#define VEC_SIZE 4
+#else
+#define SRC_TYPE f32
+#define DST_TYPE DST_INNER_TYPE
+#define VEC_SIZE 1
+#endif
+
 @group(0) @binding(0)
-var src: array;
+var<storage, read_write> src: array<SRC_TYPE>;
 
 @group(0) @binding(1)
 var idx: array;
 
 @group(0) @binding(2)
-var dst: array;
+var<storage, read_write> dst: array<DST_TYPE>;
 
+#ifdef I64_IDX
 @group(0) @binding(3)
 var error: atomic;
+#define PARAMS_BINDING 4
+#else
+#define PARAMS_BINDING 3
+#endif
 
 struct Params {
     offset_src: u32, // in elements
@@ -41,16 +62,19 @@ struct Params {
     idx2: u32,
 };
 
-@group(0) @binding(4)
+@group(0) @binding(PARAMS_BINDING)
 var params: Params;
 
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
+@compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3) {
-    if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
+    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {
         return;
     }
-    var i = gid.x;
+
+    // getting the row from gid
+    let elems_per_row = params.ne0 / VEC_SIZE;
+    var i = gid.x / elems_per_row;
+
     let i_src3 = i / (params.ne2 * params.n_rows);
 
     i = i % (params.ne2 * params.n_rows);
@@ -61,9 +85,10 @@ fn main(@builtin(global_invocation_id) gid: vec3) {
     let i_idx1 = i_src2 % params.idx1;
     let i_idx0 = i_src1;
 
+#ifdef I64_IDX
     let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;
 
-    let idx_high_val = idx[idx_high];
+    let idx_val = idx[idx_high];
     let idx_low_val = idx[idx_high + 1];
 
     if (idx_low_val != 0) {
@@ -71,11 +96,14 @@ fn main(@builtin(global_invocation_id) gid: vec3) {
         atomicStore(&error, 1);
         return;
     }
+#else
+    let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
+    let idx_val = idx[idx_i];
+#endif
 
-    let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
+    let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
     let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;
 
-    for (var i: u32 = 0; i < params.ne0; i++) {
-      dst[i_dst_row + i] = f16(src[i_src_row + i]);
-    }
+    let col_idx = (gid.x % elems_per_row);
+    dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
 }
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl
new file mode 100644
index 00000000..6ea2de9b
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl
@@ -0,0 +1,55 @@
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements)
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+var<workgroup> shared_sum: array<f32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+        @builtin(local_invocation_id) lid: vec3<u32>) {
+
+    var i = wid.x;
+    let i3 = i / (params.ne2 * params.ne1);
+    i = i % (params.ne2 * params.ne1);
+    let i2 = i / params.ne1;
+    let i1 = i % params.ne1;
+    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
+    var local_sum: f32 = 0.0;
+    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
+        local_sum += src[i_src_row + col];
+    }
+    shared_sum[lid.x] = local_sum;
+    workgroupBarrier();
+    // reduce within workgroup
+    var offset: u32 = WG_SIZE >> 1;
+    while (offset > 0) {
+        if (lid.x < offset) {
+            shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset];
+        }
+        workgroupBarrier();
+        offset >>= 1;
+    }
+
+    if (lid.x == 0) {
+        dst[params.offset_dst + wid.x] = shared_sum[0];
+    }
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
new file mode 100644
index 00000000..feaf6d0a
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl
@@ -0,0 +1,193 @@
+#ifdef TYPE_F16
+enable f16;
+#define TYPE f16
+#else
+#define TYPE f32
+#endif
+
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<TYPE>;
+
+#ifndef INPLACE
+@group(0) @binding(1)
+var<storage, read_write> dst: array<TYPE>;
+#define PARAMS_BINDING 2
+#else
+#define PARAMS_BINDING 1
+#endif
+
+struct Params {
+    ne: u32,            // total number of elements
+    offset_src: u32,    // in elements
+    offset_dst: u32,    // in elements
+
+    // Strides (in elements)
+    stride_src0: u32,
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    // Logical shapes
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+#ifdef CLAMP
+    clamp_min: f32,
+    clamp_max: f32,
+#endif
+#ifdef FILL
+    fill_val: f32,
+#endif
+#ifdef XIELU
+    alpha_n: f32,
+    alpha_p: f32,
+    beta: f32,
+    eps: f32,
+#endif
+
+};
+
+@group(0) @binding(PARAMS_BINDING)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+      return;
+    }
+    var i = gid.x;
+    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+    i = i % (params.ne2 * params.ne1 * params.ne0);
+    let i2 = i / (params.ne1 * params.ne0);
+    i = i % (params.ne1 * params.ne0);
+    let i1 = i / params.ne0;
+    let i0 = i % params.ne0;
+
+    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+                  i2 * params.stride_src2 + i3 * params.stride_src3;
+
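+    // exactly one of the following op branches is enabled via a compile-time define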
+#ifdef ABS
+    let res = abs(src[params.offset_src + src_idx]);
+#endif
+#ifdef SGN
+    let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
+                     src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef NEG
+    let res = -src[params.offset_src + src_idx];
+#endif
+#ifdef STEP
+    let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));
+#endif
+#ifdef TANH
+    let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913));
+#endif
+#ifdef RELU
+    let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef ELU
+    let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
+                     src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef HARDSIGMOID
+    let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
+#endif
+#ifdef SIGMOID
+    let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));
+#endif
+#ifdef SILU
+    let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
+#endif
+#ifdef EXP
+    let res = exp(src[params.offset_src + src_idx]);
+#endif
+#ifdef LOG
+    let res = TYPE(log(f32(src[params.offset_src + src_idx])));
+#endif
+#ifdef CLAMP
+    let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));
+#endif
+#ifdef FILL
+    let res = TYPE(params.fill_val);
+#endif
+#ifdef HARDSWISH
+    let res = src[params.offset_src + src_idx] *
+              min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
+#endif
+#ifdef GELU
+    let res = 0.5 * src[params.offset_src + src_idx] *
+              (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
+                               (src[params.offset_src + src_idx] +
+                                0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
+                               -9.010913, 9.010913)));
+#endif
+#ifdef GELU_QUICK
+    let res = src[params.offset_src + src_idx] * 0.5 *
+              (1.0 + tanh(clamp(0.79788456 *
+                               (src[params.offset_src + src_idx] +
+                                0.044715 * src[params.offset_src + src_idx] *
+                                    src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
+                               -9.010913, 9.010913)));
+#endif
+#ifdef GELU_ERF
+    let res = 0.5 * src[params.offset_src + src_idx] *
+              (1.0 + tanh(clamp(0.79788456 *
+                               (src[params.offset_src + src_idx] +
+                                0.044715 * src[params.offset_src + src_idx] *
+                                    src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
+                               -9.010913, 9.010913)));
+#endif
+#ifdef XIELU
+    let res =
+        select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
+                src[params.offset_src + src_idx]) *
+                   TYPE(params.alpha_n) +
+               TYPE(params.beta) * src[params.offset_src + src_idx],
+               TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
+                   src[params.offset_src + src_idx] +
+                   TYPE(params.beta) * src[params.offset_src + src_idx],
+               src[params.offset_src + src_idx] > 0.0);
+#endif
+#ifdef SOFTPLUS
+    let src_f32 = f32(src[params.offset_src + src_idx]);
+    let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
+#endif
+#ifdef EXPM1
+    let res = exp(src[params.offset_src + src_idx]) - 1.0;
+#endif
+#ifdef FLOOR
+    let res = floor(src[params.offset_src + src_idx]);
+#endif
+#ifdef CEIL
+    let res = ceil(src[params.offset_src + src_idx]);
+#endif
+#ifdef ROUND
+    let src_f32 = f32(src[params.offset_src + src_idx]);
+    let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);
+    let res = TYPE(result);
+#endif
+#ifdef TRUNC
+    let res = trunc(src[params.offset_src + src_idx]);
+#endif
+#ifdef SQR
+    let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx];
+#endif
+#ifdef SQRT
+    let res = sqrt(src[params.offset_src + src_idx]);
+#endif
+#ifdef SIN
+    let res_f32 = sin(f32(src[params.offset_src + src_idx]));
+    let res = TYPE(res_f32);
+#endif
+#ifdef COS
+    let res_f32 = cos(f32(src[params.offset_src + src_idx]));
+    let res = TYPE(res_f32);
+#endif
+
+#ifdef INPLACE
+    src[params.offset_src + src_idx] = res;
+#else
+    dst[params.offset_dst + gid.x] = res;
+#endif
+}
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl
deleted file mode 100644
index 25fe2854..00000000
--- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl
+++ /dev/null
@@ -1,483 +0,0 @@
-#define(REPL_TEMPLATES)
-
-{
-    "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);",
-    "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
-    "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);",
-    "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];",
-    "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
-    "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
-    "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);",
-    "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);",
-    "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
-    "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));",
-    "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));",
-    "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);",
-    "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
-    "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
-    "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
-    "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
-    "CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);"
-}
-
-#end(REPL_TEMPLATES)
-
-#define(VARIANTS)
-
-[
-    {
-      "SHADER_NAME": "abs_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "abs_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "abs_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "abs_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "sgn_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sgn_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sgn_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sgn_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "neg_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "neg_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "neg_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "neg_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "step_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "step_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "step_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "step_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "tanh_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "tanh_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "tanh_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "tanh_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "elu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "elu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "elu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "elu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "relu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "relu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "relu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "relu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "sigmoid_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sigmoid_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sigmoid_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "sigmoid_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "silu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "silu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "silu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "silu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "exp_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "exp_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "exp_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "exp_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "hardsigmoid_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardsigmoid_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardsigmoid_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardsigmoid_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "hardswish_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardswish_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardswish_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "hardswish_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "gelu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "gelu_quick_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_quick_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_quick_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "gelu_quick_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-
-    {
-      "SHADER_NAME": "xielu_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "xielu_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
-      "DECLS": ["NOT_INPLACE"]
-    },
-    {
-      "SHADER_NAME": "xielu_inplace_f32",
-      "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-      "SHADER_NAME": "xielu_inplace_f16",
-      "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
-      "DECLS": ["INPLACE"]
-    },
-    {
-        "SHADER_NAME": "gelu_erf_f32",
-        "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-        "DECLS": ["NOT_INPLACE"]
-    },
-    {
-        "SHADER_NAME": "gelu_erf_f16",
-        "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-        "DECLS": ["NOT_INPLACE"]
-    },
-    {
-        "SHADER_NAME": "gelu_erf_inplace_f32",
-        "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-        "DECLS": ["INPLACE"]
-    },
-    {
-        "SHADER_NAME": "gelu_erf_inplace_f16",
-        "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-        "DECLS": ["INPLACE"]
-    },
-
-    {
-        "SHADER_NAME": "ceil_f32",
-        "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-        "DECLS": ["NOT_INPLACE"]
-    },
-    {
-        "SHADER_NAME": "ceil_f16",
-        "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
-        "DECLS": ["NOT_INPLACE"]
-    },
-    {
-        "SHADER_NAME": "ceil_inplace_f32",
-        "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-        "DECLS": ["INPLACE"]
-    },
-    {
-        "SHADER_NAME": "ceil_inplace_f16",
-        "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
-        "DECLS": ["INPLACE"]
-    }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(INPLACE)
-
-@group(0) @binding(1)
-var params: Params;
-
-#enddecl(INPLACE)
-
-#decl(NOT_INPLACE)
-
-@group(0) @binding(1)
-var dst: array<{{TYPE}}>;
-
-@group(0) @binding(2)
-var params: Params;
-
-#enddecl(NOT_INPLACE)
-
-#end(DECLS)
-
-#define(SHADER)
-
-enable f16;
-
-fn update(dst_i: u32, src_i: u32) {
-    {{FUNC}}
-}
-
-@group(0) @binding(0)
-var src: array<{{TYPE}}>;
-
-DECLS
-
-struct Params {
-    ne: u32,            // total number of elements
-    offset_src: u32,    // in elements
-    offset_dst: u32,    // in elements
-
-    // Strides (in elements) — may be permuted
-    stride_src0: u32,
-    stride_src1: u32,
-    stride_src2: u32,
-    stride_src3: u32,
-
-    stride_dst0: u32,
-    stride_dst1: u32,
-    stride_dst2: u32,
-    stride_dst3: u32,
-
-    // Logical shapes
-    src_ne0: u32,
-    src_ne1: u32,
-    src_ne2: u32,
-
-    dst_ne0: u32,
-    dst_ne1: u32,
-    dst_ne2: u32,
-
-    {{EXT_PARAMS}}
-};
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3) {
-    if (gid.x >= params.ne) {
-      return;
-    }
-
-    var i = gid.x;
-    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
-    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
-    let i2 = i / (params.src_ne1 * params.src_ne0);
-    i = i % (params.src_ne1 * params.src_ne0);
-    let i1 = i / params.src_ne0;
-    let i0 = i % params.src_ne0;
-
-    var j = gid.x;
-    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
-    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
-    let j2 = j / (params.dst_ne1 * params.dst_ne0);
-    j = j % (params.dst_ne1 * params.dst_ne0);
-    let j1 = j / params.dst_ne0;
-    let j0 = j % params.dst_ne0;
-
-    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
-                  i2 * params.stride_src2 + i3 * params.stride_src3;
-
-    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
-                  j2 * params.stride_dst2 + j3 * params.stride_dst3;
-
-
-    update(params.offset_dst + dst_idx, params.offset_src + src_idx);
-}
-
-#end(SHADER)
-
diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
index edbeb8ee..9b6938ab 100644
--- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp
+++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
@@ -58,6 +58,10 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         bool ok = ggml_zdnn_compute_forward(ctx, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: unsupported op %s (%s)\n",
@@ -368,7 +372,8 @@ static size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_ty
 }
 
 static bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return true;
+    /* while it resides in host memory, additional transformation is needed */
+    return false;
 
     GGML_UNUSED(buft);
 }
diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt
index bdbfc743..9bdb4e83 100644
--- a/ggml/src/ggml-zendnn/CMakeLists.txt
+++ b/ggml/src/ggml-zendnn/CMakeLists.txt
@@ -1,12 +1,19 @@
 ggml_add_backend_library(ggml-zendnn
                          ggml-zendnn.cpp)
 
-# Get ZenDNN path
 if (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "")
     set(ZENDNN_ROOT "$ENV{ZENDNN_ROOT}")
 endif()
 
-# Check if path is still empty or OFF
+if (BUILD_SHARED_LIBS)
+    set(ZENDNN_SHARED_LIB ON)
+    set(ZENDNN_ARCHIVE_LIB OFF)
+else()
+    set(ZENDNN_SHARED_LIB OFF)
+    set(ZENDNN_ARCHIVE_LIB ON)
+endif()
+
+# Download and build ZenDNN if not provided
 if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
     message(STATUS "ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...")
     message(STATUS "This will take several minutes on first build...")
@@ -21,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
     ExternalProject_Add(
         zendnn
         GIT_REPOSITORY https://github.com/amd/ZenDNN.git
-        GIT_TAG zendnnl
+        GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf    # ZenDNN-2026-WW08
         PREFIX      ${ZENDNN_PREFIX}
         SOURCE_DIR  ${ZENDNN_SOURCE_DIR}
         BINARY_DIR  ${ZENDNN_BUILD_DIR}
@@ -32,7 +39,9 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
             -DZENDNNL_BUILD_DOXYGEN=OFF
             -DZENDNNL_BUILD_GTEST=OFF
             -DZENDNNL_BUILD_BENCHDNN=OFF
-            # Enable ALL matmul algorithm backends
+            -DZENDNNL_DEPENDS_FBGEMM=OFF
+            -DZENDNNL_LIB_BUILD_ARCHIVE=${ZENDNN_ARCHIVE_LIB}
+            -DZENDNNL_LIB_BUILD_SHARED=${ZENDNN_SHARED_LIB}
             -DZENDNNL_DEPENDS_AOCLDLP=ON
             -DZENDNNL_DEPENDS_ONEDNN=ON
             -DZENDNNL_DEPENDS_LIBXSMM=ON
@@ -45,47 +54,37 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
         LOG_INSTALL ON
     )
 
-    # Add dependency so ZenDNN builds before our library
     add_dependencies(ggml-zendnn zendnn)
-
-    # Set ZENDNN_ROOT to the installation directory
     set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR})
-
     message(STATUS "ZenDNN will be built to: ${ZENDNN_ROOT}")
 else()
     message(STATUS "Using custom ZenDNN installation at: ${ZENDNN_ROOT}")
 endif()
 
-# ZenDNN headers + libs
 target_include_directories(ggml-zendnn PRIVATE
     ${ZENDNN_ROOT}/zendnnl/include
-    ${ZENDNN_ROOT}/deps/aocldlp/include
-    ${ZENDNN_ROOT}/deps/aoclutils/include
     ${ZENDNN_ROOT}/deps/json/include
-    ${ZENDNN_ROOT}/deps/libxsmm/include
+    ${ZENDNN_ROOT}/deps/aoclutils/include
+    ${ZENDNN_ROOT}/deps/aocldlp/include
     ${ZENDNN_ROOT}/deps/onednn/include
-)
+    ${ZENDNN_ROOT}/deps/libxsmm/include)
 
-target_link_directories(ggml-zendnn PRIVATE
-    ${ZENDNN_ROOT}/zendnnl/lib
-    ${ZENDNN_ROOT}/deps/aocldlp/lib
-    ${ZENDNN_ROOT}/deps/aoclutils/lib
-    ${ZENDNN_ROOT}/deps/libxsmm/lib
-    ${ZENDNN_ROOT}/deps/onednn/lib
-)
+if (ZENDNN_SHARED_LIB)
+    target_link_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/lib)
+    target_link_libraries(ggml-zendnn PRIVATE zendnnl)
+elseif (ZENDNN_ARCHIVE_LIB)
+    target_link_libraries(ggml-zendnn PRIVATE
+        ${ZENDNN_ROOT}/zendnnl/lib/libzendnnl_archive.a
+        ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libaoclutils.a
+        ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libau_cpuid.a
+        ${ZENDNN_ROOT}/deps/aocldlp/lib/libaocl-dlp.a
+        ${ZENDNN_ROOT}/deps/onednn/${CMAKE_INSTALL_LIBDIR}/libdnnl.a
+        ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmm.a
+        ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmext.a
+        ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmnoblas.a)
+endif()
 
-target_link_libraries(ggml-zendnn PRIVATE
-    zendnnl_archive    # ZenDNN main
-    aocl-dlp           # AOCL libraries
-    aoclutils
-    au_cpuid
-    dnnl               # OneDNN
-    xsmm               # libxsmm small matrix math
-    xsmmext
-    xsmmnoblas
-    m
-    pthread
-)
+target_link_libraries(ggml-zendnn PRIVATE m pthread)
 
 if (GGML_OPENMP)
     target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX)
diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
index fd07f983..c8760304 100644
--- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp
+++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
@@ -2,7 +2,6 @@
 
 #include "ggml-backend-impl.h"
 #include "ggml-impl.h"
-#include "ggml-cpu.h"
 #include "zendnnl.hpp"
 
 #include 
@@ -42,13 +41,13 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
                                const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C,
                                int64_t ldc) {
 
-    zendnnl::lowoha::lowoha_params params;
+    zendnnl::lowoha::matmul::matmul_params params;
     params.dtypes.src = ggml_to_zendnn_type();
     params.dtypes.wei = ggml_to_zendnn_type();
     params.dtypes.dst = ggml_to_zendnn_type();
     params.num_threads = ctx->n_threads;
 
-    zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct(
+    zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
         'r', false, true,   // row-major, don't transpose B, transpose A (because it's column-major)
         n,                  // M: rows of B and C
         m,                  // N: cols of A^T and C
@@ -64,7 +63,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
         params              // params
     );
 
-    if (status != zendnnl::lowoha::status_t::success) {
+    if (status != zendnnl::error_handling::status_t::success) {
         GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast(status));
         return false;
     }
@@ -122,8 +121,8 @@ static void ggml_zendnn_compute_forward_mul_mat(
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    ggml_type         const vec_dot_type = ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
-    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(vec_dot_type)->from_float;
+    ggml_type         const vec_dot_type = src0->type;
+    ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;
 
     GGML_ASSERT(ne0 == ne01);
     GGML_ASSERT(ne1 == ne11);
@@ -211,6 +210,10 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 ggml_zendnn_compute_forward_mul_mat(ctx, node);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 09b8eb46..e5b83e14 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -718,6 +718,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .to_float                 = (ggml_to_float_t) dequantize_row_mxfp4,
         .from_float_ref           = (ggml_from_float_t)quantize_row_mxfp4_ref,
     },
+    [GGML_TYPE_NVFP4] = {
+        .type_name                = "nvfp4",
+        .blck_size                = QK_NVFP4,
+        .type_size                = sizeof(block_nvfp4),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_nvfp4,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_nvfp4_ref,
+    },
     [GGML_TYPE_Q2_K] = {
         .type_name                = "q2_K",
         .blck_size                = QK_K,
@@ -899,7 +907,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
 };
 
 const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
-    GGML_ASSERT(type < GGML_TYPE_COUNT);
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     return &type_traits[type];
 }
 
@@ -1030,6 +1039,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
     "SOLVE_TRI",
+    "GATED_DELTA_NET",
 
     "UNARY",
 
@@ -1047,7 +1057,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1139,6 +1149,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
     "A X = B, A triangular, solve X",
+    "gated_delta_net(q, k, v, g, beta, s)",
 
     "unary(x)",
 
@@ -1156,7 +1167,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1265,27 +1276,33 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
 }
 
 int64_t ggml_blck_size(enum ggml_type type) {
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     return type_traits[type].blck_size;
 }
 
 size_t ggml_type_size(enum ggml_type type) {
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     return type_traits[type].type_size;
 }
 
 size_t ggml_row_size(enum ggml_type type, int64_t ne) {
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     assert(ne % ggml_blck_size(type) == 0);
     return ggml_type_size(type)*ne/ggml_blck_size(type);
 }
 
-double ggml_type_sizef(enum ggml_type type) {
-    return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
-}
-
 const char * ggml_type_name(enum ggml_type type) {
-    return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
+    return type_traits[type].type_name;
 }
 
 bool ggml_is_quantized(enum ggml_type type) {
+    assert(type >= 0);
+    assert(type < GGML_TYPE_COUNT);
     return type_traits[type].is_quantized;
 }
 
@@ -1365,6 +1382,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
         case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
         case GGML_FTYPE_MOSTLY_MXFP4:         wtype = GGML_TYPE_MXFP4; break;
+        case GGML_FTYPE_MOSTLY_NVFP4:         wtype = GGML_TYPE_NVFP4; break;
         case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
         case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
         case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
@@ -1403,16 +1421,14 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
     }
     next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        if (tensor->ne[i] != 1) {
-            if (i > n) {
-                if (tensor->nb[i] != next_nb) {
-                    return false;
-                }
-                next_nb *= tensor->ne[i];
-            } else {
-                // this dimension does not need to be contiguous
-                next_nb = tensor->ne[i]*tensor->nb[i];
+        if (i > n) {
+            if (tensor->ne[i] != 1 && tensor->nb[i] != next_nb) {
+                return false;
             }
+            next_nb *= tensor->ne[i];
+        } else {
+            // this dimension does not need to be contiguous
+            next_nb = tensor->ne[i]*tensor->nb[i];
         }
     }
     return true;
@@ -1496,6 +1512,10 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
         (t0->nb[3] == t1->nb[3]);
 }
 
+bool ggml_is_view(const struct ggml_tensor * t) {
+    return ggml_impl_is_view(t);
+}
+
 // check if t1 can be represented as a repetition of t0
 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -1625,11 +1645,23 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml
     const size_t cur_end  = cur_offs + cur_size;
 
     // align to GGML_MEM_ALIGN
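+    // GGML_PAD below adds at most GGML_MEM_ALIGN - 1 bytes, so make sure the padding cannot overflow size_t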
+    GGML_ASSERT(size <= SIZE_MAX - (GGML_MEM_ALIGN - 1));
     size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
 
     char * const mem_buffer = ctx->mem_buffer;
     struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
 
+    // integer overflow checks
+    if (cur_end > SIZE_MAX - size_needed) {
+        GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu)\n", __func__, cur_end, size_needed);
+        return NULL;
+    }
+    if (cur_end + size_needed > SIZE_MAX - GGML_OBJECT_SIZE) {
+        GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu) + GGML_OBJECT_SIZE (%zu)\n", __func__,
+                cur_end, size_needed, (size_t) GGML_OBJECT_SIZE);
+        return NULL;
+    }
+
     if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
         GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
                 __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
@@ -1698,6 +1730,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         obj_alloc_size = data_size;
     }
 
+    GGML_ASSERT(GGML_TENSOR_SIZE <= SIZE_MAX - obj_alloc_size);
+
     struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
     GGML_ASSERT(obj_new);
 
@@ -3441,7 +3475,8 @@ struct ggml_tensor * ggml_cast(
 
     result->op     = GGML_OP_CPY;
     result->src[0] = a;
-    result->src[1] = result;
+    result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some
+                             //       backends for consistency with ggml_cpy_impl() above
 
     return result;
 }
@@ -4838,6 +4873,8 @@ struct ggml_tensor * ggml_pool_1d(
         a->ne[2],
         a->ne[3],
     };
+    GGML_ASSERT(ne[0] > 0);
+
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, s0, p0 };
@@ -4868,6 +4905,9 @@ struct ggml_tensor * ggml_pool_2d(
         a->ne[2],
         a->ne[3],
     };
+    GGML_ASSERT(ne[0] > 0);
+    GGML_ASSERT(ne[1] > 0);
+
     result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
@@ -5743,7 +5783,7 @@ static struct ggml_tensor * ggml_unary_impl(
         struct ggml_tensor  * a,
         enum ggml_unary_op    op,
         bool                  inplace) {
-    GGML_ASSERT(ggml_is_contiguous_1(a));
+    GGML_ASSERT(ggml_is_contiguous_rows(a));
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
@@ -6095,6 +6135,57 @@ struct ggml_tensor * ggml_solve_tri(
     return result;
 }
 
+// ggml_gated_delta_net
+
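+// q, k, v, g, beta are per-token activations and state is the recurrent state;
+// v is expected as [S_v, H, n_tokens, n_seqs] and state must hold S_v*S_v*H elements per sequence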
+struct ggml_tensor * ggml_gated_delta_net(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous_rows(q));
+    GGML_ASSERT(ggml_is_contiguous_rows(k));
+    GGML_ASSERT(ggml_is_contiguous_rows(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    GGML_ASSERT(q->type == GGML_TYPE_F32);
+    GGML_ASSERT(k->type == GGML_TYPE_F32);
+    GGML_ASSERT(v->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(beta->type == GGML_TYPE_F32);
+    GGML_ASSERT(state->type == GGML_TYPE_F32);
+
+    const int64_t S_v      = v->ne[0];
+    const int64_t H        = v->ne[1];
+    const int64_t n_tokens = v->ne[2];
+    const int64_t n_seqs   = v->ne[3];
+
+    // gate: scalar [1, H, T, B] or vector [S_v, H, T, B] (KDA)
+    GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
+    GGML_ASSERT(beta->ne[0] == 1);
+
+    GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
+
+    // concat output and new_state into a single tensor
+    // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
+    const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
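+    // total elements: S_v*H * n_tokens*n_seqs (output) + S_v*H * S_v*n_seqs (new state)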
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_GATED_DELTA_NET;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = g;
+    result->src[4] = beta;
+    result->src[5] = state;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {
@@ -6556,7 +6647,7 @@ static void ggml_compute_backward(
         case GGML_OP_DIAG_MASK_INF: {
             if (src0_needs_grads) {
                 /* ggml_diag_mask_inf_impl() shouldn't be here */
-                /* ref:  https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
+                /* ref:  https://github.com/ggml-org/llama.cpp/pull/4203#discussion_r1412377992 */
                 const int n_past = ((const int32_t *) tensor->op_params)[0];
                 ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
             }
@@ -6720,20 +6811,35 @@ static void ggml_compute_backward(
     GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
-static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
-    // check if already visited
-    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) {
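+    // when compute == true, every op node reachable from this node is flagged as required for computation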
+    if (node->op != GGML_OP_NONE && compute) {
+        node->flags |= GGML_TENSOR_FLAG_COMPUTE;
+    }
+
+    const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
     GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
-    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
-        // This is the first time we see this node in the current graph.
-        cgraph->visited_hash_set.keys[node_hash_pos] = node;
-        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
-        cgraph->use_counts[node_hash_pos] = 0;
-    } else {
+
+    if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
         // already visited
+
+        if (compute) {
+            // the node was already visited, but still propagate the compute flag to any sources that do not have it yet
+            for (int i = 0; i < GGML_MAX_SRC; ++i) {
+                struct ggml_tensor * src = node->src[i];
+                if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) {
+                    ggml_visit_parents_graph(cgraph, src, true);
+                }
+            }
+        }
+
         return node_hash_pos;
     }
 
+    // This is the first time we see this node in the current graph.
+    cgraph->visited_hash_set.keys[node_hash_pos] = node;
+    ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+    cgraph->use_counts[node_hash_pos] = 0;
+
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
@@ -6742,7 +6848,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
 
         struct ggml_tensor * src = node->src[k];
         if (src) {
-            size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+            const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute);
 
             // Update the use count for this operand.
             cgraph->use_counts[src_hash_pos]++;
@@ -6773,17 +6879,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
     return node_hash_pos;
 }
 
-static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) {
     if (!expand) {
         // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
         ggml_graph_clear(cgraph);
     }
 
-    const int n0 = cgraph->n_nodes;
+    const int n_old = cgraph->n_nodes;
 
-    ggml_visit_parents(cgraph, tensor);
+    ggml_visit_parents_graph(cgraph, tensor, compute);
 
-    const int n_new = cgraph->n_nodes - n0;
+    const int n_new = cgraph->n_nodes - n_old;
     GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
 
     if (n_new > 0) {
@@ -6792,8 +6898,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
     }
 }
 
+// build the forward graph for all tensors, marking only the subgraph of tensors[idx]
+// with GGML_TENSOR_FLAG_COMPUTE; the remaining tensors are added without the compute flag
+struct ggml_tensor * ggml_build_forward_select(
+        struct ggml_cgraph  * cgraph,
+        struct ggml_tensor ** tensors,
+        int                   n_tensors,
+        int                   idx) {
+    GGML_ASSERT(idx >= 0 && idx < n_tensors);
+
+    for (int i = 0; i < n_tensors; i++) {
+        ggml_build_forward_impl(cgraph, tensors[i], true, i == idx);
+    }
+
+    return tensors[idx];
+}
+
 void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
-    ggml_build_forward_impl(cgraph, tensor, true);
+    ggml_build_forward_impl(cgraph, tensor, true, true);
 }
 
 void ggml_build_backward_expand(
@@ -7224,6 +7344,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
             return false;
         }
 
+        // nodes that are not flagged for computation cannot be part of a fused subgraph
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            return false;
+        }
+
         if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
             continue;
         }
@@ -7305,7 +7429,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node,
             label);
 }
 
-void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) {
     char color[16];
 
     FILE * fp = ggml_fopen(filename, "w");
@@ -7326,7 +7450,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             snprintf(color, sizeof(color), "yellow");
         } else if (grad) {
-            if (ggml_graph_find(gf, node)) {
+            if (ggml_graph_find(cgraph, node)) {
                 snprintf(color, sizeof(color), "green");
             } else {
                 snprintf(color, sizeof(color), "lightblue");
@@ -7478,8 +7602,11 @@ void ggml_quantize_free(void) {
 
     iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
     iq2xs_free_impl(GGML_TYPE_IQ2_XS);
+    iq2xs_free_impl(GGML_TYPE_IQ2_S);
     iq2xs_free_impl(GGML_TYPE_IQ1_S);
+    iq2xs_free_impl(GGML_TYPE_IQ1_M);
     iq3xs_free_impl(256);
+    iq3xs_free_impl(512);
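+    // 256 and 512 are the grid sizes used by the iq3 quants (IQ3_XXS and IQ3_S)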
 
     ggml_critical_section_end();
 }
@@ -7523,6 +7650,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_MXFP4:   result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_NVFP4:   result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q4_K:    result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
index 53504399..cbeedf6c 100644
--- a/ggml/src/gguf.cpp
+++ b/ggml/src/gguf.cpp
@@ -15,6 +15,17 @@
 #include 
 #include 
 
+#define GGUF_MAX_STRING_LENGTH  (1024*1024*1024)
+#define GGUF_MAX_ARRAY_ELEMENTS (1024*1024*1024)
+
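+// 64-bit file offsets - plain ftell/fseek use long, which is 32-bit on some platforms (e.g. Windows)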
+#ifdef _WIN32
+#    define gguf_ftell _ftelli64
+#    define gguf_fseek _fseeki64
+#else
+#    define gguf_ftell ftello
+#    define gguf_fseek fseeko
+#endif
+
 template <typename T>
 struct type_to_gguf_type;
 
@@ -217,17 +228,64 @@ struct gguf_context {
 };
 
 struct gguf_reader {
-    FILE * file;
+    gguf_reader(FILE * file) : file(file) {
+        // read the remaining bytes once and update on each read
+        nbytes_remain = file_remain(file);
+    }
 
-    gguf_reader(FILE * file) : file(file) {}
+    // helper: number of bytes from the current file position to the end of the file
+    static uint64_t file_remain(FILE * file) {
+        const int64_t cur = gguf_ftell(file);
+        if (cur < 0) {
+            return 0;
+        }
+        if (gguf_fseek(file, 0, SEEK_END) != 0) {
+            gguf_fseek(file, cur, SEEK_SET);
+
+            return 0;
+        }
+        const int64_t end = gguf_ftell(file);
+        if (end < 0) {
+            gguf_fseek(file, cur, SEEK_SET);
+
+            return 0;
+        }
+        gguf_fseek(file, cur, SEEK_SET);
+        return static_cast<uint64_t>(end - cur);
+    }
 
     template <typename T>
     bool read(T & dst) const {
-        return fread(&dst, 1, sizeof(dst), file) == sizeof(dst);
+        const size_t size = sizeof(dst);
+        if (nbytes_remain < size) {
+            return false;
+        }
+        const size_t nread = fread(&dst, 1, size, file);
+        nbytes_remain -= nread;
+        return nread == size;
     }
 
     template <typename T>
     bool read(std::vector<T> & dst, const size_t n) const {
+        if (n > GGUF_MAX_ARRAY_ELEMENTS) {
+            return false;
+        }
+        if constexpr (std::is_same<T, std::string>::value) {
+            // strings are prefixed with their length, so we need to account for that
+            if (n > SIZE_MAX / sizeof(uint64_t)) {
+                return false;
+            }
+            if (nbytes_remain < n * sizeof(uint64_t)) {
+                return false;
+            }
+        } else {
+            if (n > SIZE_MAX / sizeof(T)) {
+                return false;
+            }
+            if (nbytes_remain < n * sizeof(T)) {
+                return false;
+            }
+        }
         dst.resize(n);
         for (size_t i = 0; i < dst.size(); ++i) {
             if constexpr (std::is_same<T, bool>::value) {
@@ -273,17 +331,37 @@ struct gguf_reader {
     }
 
     bool read(std::string & dst) const {
-        uint64_t size = -1;
+        uint64_t size = 0;
         if (!read(size)) {
             return false;
         }
-        dst.resize(size);
-        return fread(dst.data(), 1, dst.length(), file) == dst.length();
+        if (size > GGUF_MAX_STRING_LENGTH) {
+            GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds maximum %" PRIu64 "\n", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH);
+            return false;
+        }
+        if (size > nbytes_remain) {
+            GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes_remain);
+            return false;
+        }
+        dst.resize(static_cast<size_t>(size));
+        const size_t nread = fread(dst.data(), 1, size, file);
+        nbytes_remain -= nread;
+        return nread == size;
     }
 
     bool read(void * dst, const size_t size) const {
-        return fread(dst, 1, size, file) == size;
+        if (size > nbytes_remain) {
+            return false;
+        }
+        const size_t nread = fread(dst, 1, size, file);
+        nbytes_remain -= nread;
+        return nread == size;
     }
+
+private:
+    FILE * file;
+
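+    // number of bytes left in the file; decremented by the number of bytes consumed on each read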
+    mutable uint64_t nbytes_remain;
 };
 
 struct gguf_context * gguf_init_empty(void) {
@@ -523,7 +601,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
 
         // tensor shape
         {
-            uint32_t n_dims = -1;
+            uint32_t n_dims = 0;
             ok = ok && gr.read(n_dims);
             if (n_dims > GGML_MAX_DIMS) {
                 GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
@@ -568,8 +646,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
 
             // check that tensor type is within defined range
             if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
-                GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n",
-                    __func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
+                GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d. should be in [0, %d)\n",
+                    __func__, info.t.name, info.t.type, GGML_TYPE_COUNT);
                 ok = false;
                 break;
             }
@@ -585,6 +663,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
                 break;
             }
 
+            // check that the size of the tensor in bytes is representable
+            if (ok && uint64_t(ggml_nelements(&info.t)/ggml_blck_size(info.t.type)) > SIZE_MAX/ggml_type_size(info.t.type)) {
+                GGML_LOG_ERROR("%s: tensor '%s' with shape (%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") has a size in bytes > %zu\n",
+                    __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], SIZE_MAX);
+                ok = false;
+                break;
+            }
+
             // calculate byte offsets given the tensor shape and type
             info.t.nb[0] = type_size;
             info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size);
@@ -610,14 +696,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
     GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
 
     // we require the data section to be aligned, so take into account any padding
-    if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
+    if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) {
         GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__);
         gguf_free(ctx);
         return nullptr;
     }
 
     // store the current file offset - this is where the data section starts
-    ctx->offset = ftell(file);
+    ctx->offset = gguf_ftell(file);
 
     // compute the total size of the data section, taking into account the alignment
     {
@@ -649,10 +735,34 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par
         //   the ggml_tensor structs to the appropriate locations in the binary blob
 
         // compute the exact size needed for the new ggml_context
-        const size_t mem_size =
-            params.no_alloc ?
-            (n_tensors    )*ggml_tensor_overhead() :
-            (n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
+        size_t mem_size = 0;
+        if (params.no_alloc) {
+            if (n_tensors != 0 && SIZE_MAX / n_tensors < ggml_tensor_overhead()) {
+                GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            const size_t overhead = n_tensors * ggml_tensor_overhead();
+
+            mem_size = overhead;
+        } else {
+            if ((n_tensors + 1) != 0 && SIZE_MAX / (n_tensors + 1) < ggml_tensor_overhead()) {
+                GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            const size_t overhead = (n_tensors + 1) * ggml_tensor_overhead();
+
+            if (SIZE_MAX - overhead < ctx->size) {
+                GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            mem_size = overhead + ctx->size;
+        }
 
         struct ggml_init_params pdata = {
             /*mem_size   =*/ mem_size,
@@ -734,7 +844,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     FILE * file = ggml_fopen(fname, "rb");
 
     if (!file) {
-        GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
+        GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno));
         return nullptr;
     }
 
@@ -1166,50 +1276,51 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo
     ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const
 }
 
-struct gguf_writer {
-    std::vector<int8_t> & buf;
+struct gguf_writer_base {
+    size_t written_bytes {0u};
 
-    gguf_writer(std::vector<int8_t> & buf) : buf(buf) {}
+    ~gguf_writer_base(void) = default;
+
+    // we bet on devirtualization
+    virtual void write(int8_t val) = 0;
+    virtual void write(const std::vector<int8_t> & val) = 0;
+    virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0;
 
     template <typename T>
-    void write(const T & val) const {
+    void write(const T & val) {
         for (size_t i = 0; i < sizeof(val); ++i) {
-            buf.push_back(reinterpret_cast<const int8_t *>(&val)[i]);
+            write(reinterpret_cast<const int8_t *>(&val)[i]);
         }
     }
 
-    void write(const std::vector<int8_t> & val) const {
-        buf.insert(buf.end(), val.begin(), val.end());
-    }
-
-    void write(const bool & val) const {
+    void write(const bool & val) {
         const int8_t val8 = val ? 1 : 0;
         write(val8);
     }
 
-    void write(const std::string & val) const {
+    void write(const std::string & val) {
         {
             const uint64_t n = val.length();
             write(n);
         }
         for (size_t i = 0; i < val.length(); ++i) {
-            buf.push_back(reinterpret_cast<const int8_t *>(val.data())[i]);
+            write((val.data())[i]);
         }
     }
 
-    void write(const char * val) const {
+    void write(const char * val) {
         write(std::string(val));
     }
 
-    void write(const enum ggml_type & val) const {
+    void write(const enum ggml_type & val) {
         write(int32_t(val));
     }
 
-    void write(const enum gguf_type & val) const {
+    void write(const enum gguf_type & val) {
         write(int32_t(val));
     }
 
-    void write(const struct gguf_kv & kv) const {
+    void write(const struct gguf_kv & kv) {
         const uint64_t ne = kv.get_ne();
 
         write(kv.get_key());
@@ -1250,7 +1361,7 @@ struct gguf_writer {
         }
     }
 
-    void write_tensor_meta(const struct gguf_tensor_info & info) const {
+    void write_tensor_meta(const struct gguf_tensor_info & info) {
         write(info.t.name);
 
         const uint32_t n_dims = ggml_n_dims(&info.t);
@@ -1263,14 +1374,33 @@ struct gguf_writer {
         write(info.offset);
     }
 
-    void pad(const size_t alignment) const {
-        while (buf.size() % alignment != 0) {
+    void pad(const size_t alignment) {
+        while (written_bytes % alignment != 0) {
             const int8_t zero = 0;
             write(zero);
         }
     }
+};
 
-    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const {
+// vector buffer based writer
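+// (accumulates the serialized data in the provided std::vector<int8_t>)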
+struct gguf_writer_buf final : public gguf_writer_base {
+    std::vector<int8_t> & buf;
+
+    gguf_writer_buf(std::vector<int8_t> & buf) : buf(buf) {}
+
+    using gguf_writer_base::write;
+
+    void write(const int8_t val) override {
+        buf.push_back(val);
+        written_bytes++;
+    }
+
+    void write(const std::vector<int8_t> & val) override {
+        buf.insert(buf.end(), val.begin(), val.end());
+        written_bytes += val.size();
+    }
+
+    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {
         GGML_ASSERT(buf.size() - offset_data == info.offset);
 
         GGML_ASSERT(ggml_is_contiguous(&info.t));
@@ -1284,14 +1414,58 @@ struct gguf_writer {
             GGML_ASSERT(info.t.data);
             memcpy(buf.data() + offset, info.t.data, nbytes);
         }
+        written_bytes += nbytes;
 
         pad(alignment);
     }
 };
 
-void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
-    const struct gguf_writer gw(buf);
+// file based writer
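+// (streams the data directly to the file instead of accumulating the entire output in memory first)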
+struct gguf_writer_file final : public gguf_writer_base {
+    FILE * file;
 
+    gguf_writer_file(FILE* file) : file(file) {}
+
+    using gguf_writer_base::write;
+
+    void write(const int8_t val) override {
+        const auto real_val = static_cast<unsigned char>(val);
+        const auto ret = fputc(real_val, file);
+        written_bytes++;
+        if (ret != real_val) {
+            throw std::runtime_error("unexpected fputc result '" + std::to_string(ret) + "' instead of '" + std::to_string((int)real_val) + "'");
+        }
+    }
+
+    void write(const std::vector<int8_t> & val) override {
+        const auto ret = fwrite(val.data(), 1, val.size(), file);
+        written_bytes += val.size();
+        if (ret != val.size()) {
+            throw std::runtime_error("unexpected fwrite number of bytes written, '" + std::to_string(ret) + "' instead of '" + std::to_string(val.size()) + "'");
+        }
+    }
+
+    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {
+        GGML_ASSERT(written_bytes - offset_data == info.offset);
+
+        GGML_ASSERT(ggml_is_contiguous(&info.t));
+        const size_t nbytes = ggml_nbytes(&info.t);
+
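+        // stage the tensor data in a temporary buffer - handles both backend tensors and plain host data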
+        std::vector<int8_t> buf(nbytes);
+        if (info.t.buffer) {
+            ggml_backend_tensor_get(&info.t, buf.data(), 0, nbytes);
+        } else {
+            GGML_ASSERT(info.t.data);
+            memcpy(buf.data(), info.t.data, nbytes);
+        }
+        write(buf);
+
+        pad(alignment);
+    }
+};
+
+template <typename writer_t>
+static void gguf_write_out(const struct gguf_context * ctx, writer_t & gw, bool only_meta) {
     const int64_t n_kv      = gguf_get_n_kv(ctx);
     const int64_t n_tensors = gguf_get_n_tensors(ctx);
 
@@ -1321,7 +1495,7 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & bu
         return;
     }
 
-    const size_t offset_data = gw.buf.size();
+    const size_t offset_data = gw.written_bytes;
 
     // write tensor data
     for (int64_t i = 0; i < n_tensors; ++i) {
@@ -1329,6 +1503,11 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & bu
     }
 }
 
+void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
+    gguf_writer_buf gw(buf);
+    gguf_write_out(ctx, gw, only_meta);
+}
+
 bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
     FILE * file = ggml_fopen(fname, "wb");
 
@@ -1337,11 +1516,17 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
         return false;
     }
 
-    std::vector<int8_t> buf;
-    gguf_write_to_buf(ctx, buf, only_meta);
-    const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size();
+    try {
+        gguf_writer_file gw(file);
+        gguf_write_out(ctx, gw, only_meta);
+    } catch (const std::runtime_error& ex) {
+        GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what());
+        fclose(file);
+        return false;
+    }
+
     fclose(file);
-    return ok;
+    return true;
 }
 
 size_t gguf_get_meta_size(const struct gguf_context * ctx) {
diff --git a/models/convert-whisper-to-openvino.py b/models/convert-whisper-to-openvino.py
index 3124dd3d..a17e5355 100644
--- a/models/convert-whisper-to-openvino.py
+++ b/models/convert-whisper-to-openvino.py
@@ -2,7 +2,6 @@ import argparse
 import torch
 from whisper import load_model
 import os
-from openvino.tools import mo
 from openvino.frontend import FrontEndManager
 from openvino.runtime import serialize
 import shutil
diff --git a/models/requirements-openvino.txt b/models/requirements-openvino.txt
index 5bfd95db..707fa58a 100644
--- a/models/requirements-openvino.txt
+++ b/models/requirements-openvino.txt
@@ -1,2 +1,2 @@
-openvino-dev[pytorch,onnx]
-openai-whisper
\ No newline at end of file
+openvino>=2023.3.0
+openai-whisper
diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh
index 1f87e231..bc7c1b2f 100755
--- a/scripts/sync-ggml-am.sh
+++ b/scripts/sync-ggml-am.sh
@@ -60,8 +60,8 @@ while read c; do
         cmake/common.cmake \
         cmake/ggml-config.cmake.in \
         src/ggml-cpu/cmake/FindSIMD.cmake \
-        src/ggml*.h \
         src/ggml* \
+        src/gguf* \
         include/ggml*.h \
         include/gguf*.h \
         examples/common.h \
@@ -105,6 +105,7 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then
     # src/ggml-cpu/cmake/FindSIMD.cmake -> ggml/src/ggml-cpu/cmake/FindSIMD.cmake
     #
     # src/ggml* -> ggml/src/ggml*.c
+    # src/gguf* -> ggml/src/gguf*.c
     #
     # include/ggml*.h -> ggml/include/ggml*.h
     # include/gguf*.h -> ggml/include/gguf*.h
@@ -126,6 +127,7 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then
         -e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \
         -e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)/\1ggml\/src\/ggml\2/g' \
+        -e 's/([[:space:]]| [ab]\/)src\/gguf(.*)/\1ggml\/src\/gguf\2/g' \
         -e 's/(^[[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \
         -e 's/(^[[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \
         -e 's/(^[[:space:]]| [ab]\/)examples\/common\.h/\1examples\/common.h/g' \
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index 44fa890d..709d00d4 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-b6d1f0f247adcfa25c0ca1ffe97e651fe1afd5e2
+9d0addf420778b42c257cd3837fbd38ca4599f3b
diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh
index 4296ddf5..099d5445 100755
--- a/scripts/sync-ggml.sh
+++ b/scripts/sync-ggml.sh
@@ -7,6 +7,7 @@ cp -rpv ../ggml/cmake/*              ./ggml/cmake/
 cp -rpv ../ggml/src/ggml-cpu/cmake/* ./ggml/src/ggml-cpu/cmake/
 
 cp -rpv ../ggml/src/ggml* ./ggml/src/
+cp -rpv ../ggml/src/gguf* ./ggml/src/
 
 cp -rpv ../ggml/include/ggml*.h ./ggml/include/
 cp -rpv ../ggml/include/gguf*.h ./ggml/include/
diff --git a/src/whisper.cpp b/src/whisper.cpp
index 5b6e4b4b..796bccfb 100644
--- a/src/whisper.cpp
+++ b/src/whisper.cpp
@@ -6026,6 +6026,19 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
     return txt[0] == ' ';
 }
 
+// Count UTF-8 characters (not bytes) in a string
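+// e.g. utf8_len("café") == 4 while strlen("café") == 5, since "é" is encoded as two bytes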
+static int utf8_len(const char * str) {
+    int count = 0;
+    while (*str) {
+        // Skip continuation bytes (10xxxxxx)
+        if ((*str & 0xC0) != 0x80) {
+            count++;
+        }
+        str++;
+    }
+    return count;
+}
+
 static void whisper_exp_compute_token_level_timestamps_dtw(
             struct whisper_context * ctx,
               struct whisper_state * state,
@@ -6054,7 +6067,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
         }
 
         const auto txt = whisper_token_to_str(&ctx, token.id);
-        const int cur = strlen(txt);
+        const int cur = utf8_len(txt);  // Use UTF-8 character count instead of byte count
 
         if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
             state.result_all.back().text = std::move(text);