diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 2e3f97c72..3022249c7 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -67,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1) __shared__ CubTempStorage cub_temp_storage; BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA); + __syncthreads(); BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0); #else const int stride_s0 = src0_nb2 / sizeof(float); @@ -105,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1) regs0[n] = state; } y_block[i * stride_y + threadIdx.x] = sumf; + __syncthreads(); } #ifdef USE_CUB @@ -249,9 +251,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); - const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); switch (n_tok) { case 1: