diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index ed4e5de7..0f3f017b 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, size_t temp_storage_bytes = 0; + bool is_capturing = false; +#ifdef USE_CUDA_GRAPH + // Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does. + // See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149 + // TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release. + cudaStreamCaptureStatus capture_status; + CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status)); + is_capturing = (capture_status != cudaStreamCaptureStatusNone); +#endif // USE_CUDA_GRAPH + if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs( + nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, nrows, // num items, num segments - offset_iterator, offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, - stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } @@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream)); } } else { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, offset_iterator, - offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } }