SYCL: fix use-after-free bug with async memcpy in MoE prefill (llama/24676)

* SYCL: fix a bug with async memcpy

* make mmid_row_mapping_host persistent

* comment on stream->wait

* Apply suggestion from @sanmai

* Apply suggestion from @sanmai

* Apply suggestion from @sanmai
This commit is contained in:
Alexey Kopytko 2026-06-17 14:57:29 +09:00 committed by Georgi Gerganov
parent dd1a6ca897
commit fddcda58a3
2 changed files with 10 additions and 6 deletions

View File

@ -324,6 +324,11 @@ void ggml_sycl_free_device(void *ptr, sycl::queue &q);
void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
namespace sycl_ex = sycl::ext::oneapi::experimental;
struct ggml_backend_sycl_context {
int device;
@ -421,6 +426,8 @@ struct ggml_backend_sycl_context {
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
std::vector<mmid_row_mapping> mmid_row_mapping_host;
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);

View File

@ -4224,11 +4224,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
}
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
__dpct_inline__ static void k_copy_src1_to_contiguous(
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
const mmid_row_mapping *__restrict__ row_mapping,
@ -4399,6 +4394,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
// also ensures ctx.mmid_row_mapping_host is drained before we use it again
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
ggml_tensor src0_row = *src0;
@ -4456,7 +4453,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
// where each expert's slice starts and the previous ends (row indices, right-exclusive)
std::vector<int64_t> expert_row_offsets;
// the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
std::vector<mmid_row_mapping> routed_row_src;
std::vector<mmid_row_mapping> & routed_row_src = ctx.mmid_row_mapping_host;
mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows,
expert_row_counts, expert_row_offsets, routed_row_src);