SYCL: fix use-after-free bug with async memcpy in MoE prefill (llama/24676)
* SYCL: fix a bug with async memcpy * make mmid_row_mapping_host persistent * comment on stream->wait * Apply suggestion from @sanmai * Apply suggestion from @sanmai * Apply suggestion from @sanmai
This commit is contained in:
parent
dd1a6ca897
commit
fddcda58a3
|
|
@ -324,6 +324,11 @@ void ggml_sycl_free_device(void *ptr, sycl::queue &q);
|
|||
|
||||
void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
namespace sycl_ex = sycl::ext::oneapi::experimental;
|
||||
struct ggml_backend_sycl_context {
|
||||
int device;
|
||||
|
|
@ -421,6 +426,8 @@ struct ggml_backend_sycl_context {
|
|||
|
||||
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
|
||||
|
||||
std::vector<mmid_row_mapping> mmid_row_mapping_host;
|
||||
|
||||
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
|
||||
|
||||
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
|
||||
|
|
|
|||
|
|
@ -4224,11 +4224,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
}
|
||||
|
||||
|
||||
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
__dpct_inline__ static void k_copy_src1_to_contiguous(
|
||||
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
|
||||
const mmid_row_mapping *__restrict__ row_mapping,
|
||||
|
|
@ -4399,6 +4394,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
|||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
|
||||
|
||||
// also ensures ctx.mmid_row_mapping_host is drained before we use it again
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
|
||||
|
||||
ggml_tensor src0_row = *src0;
|
||||
|
|
@ -4456,7 +4453,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
|||
// where each expert's slice starts and the previous ends (row indices, right-exclusive)
|
||||
std::vector<int64_t> expert_row_offsets;
|
||||
// the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
|
||||
std::vector<mmid_row_mapping> routed_row_src;
|
||||
std::vector<mmid_row_mapping> & routed_row_src = ctx.mmid_row_mapping_host;
|
||||
|
||||
mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows,
|
||||
expert_row_counts, expert_row_offsets, routed_row_src);
|
||||
|
|
|
|||
Loading…
Reference in New Issue