Use fp32 in cuBLAS V100 to avoid overflows, env variables to override cuBLAS compute type (llama/19959)
* Update ggml-cuda.cu * Update ggml-cuda.cu * Update build.md * Update build.md * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update ggml-cuda.cu * Update build.md * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Update build.md * Update ggml-cuda.cu * Update ggml-cuda.cu --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
96b163e874
commit
8ad5cb1e9d
|
|
@ -1242,6 +1242,34 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
|
|||
}
|
||||
}
|
||||
|
||||
struct cublas_force_compute_type {
|
||||
bool fp32 = false;
|
||||
bool fp16 = false;
|
||||
};
|
||||
|
||||
static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() {
|
||||
static const cublas_force_compute_type compute_type = [] {
|
||||
cublas_force_compute_type result;
|
||||
|
||||
const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr;
|
||||
const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr;
|
||||
|
||||
GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false);
|
||||
|
||||
if (ggml_cuda_force_cublas_compute_32f_env) {
|
||||
GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n");
|
||||
result.fp32 = true;
|
||||
} else if (ggml_cuda_force_cublas_compute_16f_env) {
|
||||
GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n");
|
||||
result.fp16 = true;
|
||||
}
|
||||
|
||||
return result;
|
||||
}();
|
||||
|
||||
return compute_type;
|
||||
}
|
||||
|
||||
static void ggml_cuda_op_mul_mat_cublas(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
|
|
@ -1324,7 +1352,13 @@ static void ggml_cuda_op_mul_mat_cublas(
|
|||
|
||||
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
|
||||
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
|
||||
|
||||
if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
|
||||
|| GGML_CUDA_CC_IS_RDNA4(cc)
|
||||
|| cc == GGML_CUDA_CC_VOLTA
|
||||
|| force_compute_type.fp32))
|
||||
{
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
CUBLAS_CHECK(
|
||||
|
|
@ -1923,10 +1957,23 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
|
|||
cudaDataType_t cu_data_type_b = traits::data_type;
|
||||
const void * alpha = traits::get_alpha();
|
||||
const void * beta = traits::get_beta();
|
||||
const float alpha_f32 = 1.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type();
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16;
|
||||
|
||||
// bf16 and fp32 are already being computed in fp32 (ensure it using static_assert),
|
||||
// so checking necessity of forced fp32 only for fp16 src0_type
|
||||
static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F);
|
||||
|
||||
const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc)
|
||||
|| GGML_CUDA_CC_IS_RDNA4(cc)
|
||||
|| cc == GGML_CUDA_CC_VOLTA
|
||||
|| force_compute_type.fp32);
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) {
|
||||
if constexpr (src0_type == GGML_TYPE_F32) {
|
||||
dst_t = (char *) dst_ddf; // Direct F32 output
|
||||
} else {
|
||||
|
|
@ -1936,18 +1983,10 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
|
|||
}
|
||||
} else {
|
||||
dst_t = (char *) dst_ddf;
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
cu_data_type = CUDA_R_32F;
|
||||
alpha = &alpha_f32;
|
||||
beta = &beta_f32;
|
||||
}
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
alpha = &alpha_f32;
|
||||
beta = &beta_f32;
|
||||
cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type;
|
||||
cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type;
|
||||
alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha();
|
||||
beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta();
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
|
|
|
|||
Loading…
Reference in New Issue