diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 9ec94464b..96586ea46 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -324,6 +324,11 @@ void ggml_sycl_free_device(void *ptr, sycl::queue &q); void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams={}); +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + namespace sycl_ex = sycl::ext::oneapi::experimental; struct ggml_backend_sycl_context { int device; @@ -421,6 +426,8 @@ struct ggml_backend_sycl_context { std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; + std::vector mmid_row_mapping_host; + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f029f6325..376d4376f 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4224,11 +4224,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } -struct mmid_row_mapping { - int32_t i1; - int32_t i2; -}; - __dpct_inline__ static void k_copy_src1_to_contiguous( const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, const mmid_row_mapping *__restrict__ row_mapping, @@ -4399,6 +4394,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, SYCL_CHECK(CHECK_TRY_ERROR( stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); + + // also ensures ctx.mmid_row_mapping_host is drained before we use it again SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); ggml_tensor src0_row = *src0; @@ -4456,7 +4453,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, // where each expert's slice starts and the previous ends (row indices, right-exclusive) std::vector expert_row_offsets; // the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous - std::vector routed_row_src; + std::vector & routed_row_src = ctx.mmid_row_mapping_host; mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows, expert_row_counts, expert_row_offsets, routed_row_src);