cuda: reserve space for quantize kv-cache at startup (llama/23907)
* cuda: reserve space for quantize kv-cache at startup
* address review comments
* remove forward decl

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* remove assert in ggml-cuda.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
f110ff540c
commit
d5a49ebec8
|
|
@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)(
|
|||
typedef float (*vec_dot_KQ_t)(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
||||
|
||||
// Addresses of the extra F16 regions placed after a FLASH_ATTN_EXT tensor's own data,
// as computed by ggml_cuda_flash_attn_ext_get_f16_extra_data():
//   K   - start of the region for K converted to F16, or 0 if no region was assigned
//   V   - start of the region for V converted to F16, or 0 if no region was assigned;
//         set equal to K when V is detected to be a view of K
//   end - address one past the last reserved byte (dst->data + ggml_nbytes(dst) plus
//         any padding and F16 regions that were added)
struct ggml_cuda_flash_attn_ext_f16_extra_data {
|
||||
uintptr_t K;
|
||||
uintptr_t V;
|
||||
uintptr_t end;
|
||||
};
|
||||
|
||||
// Computes where the temporary F16 copies of K and V live relative to the output tensor.
//
// The layout starts right after the tensor's own bytes (dst->data + ggml_nbytes(dst)).
// Each region that is needed is placed at the next 128-byte boundary and is sized as
// ggml_nelements(src) * sizeof(F16). A region is only assigned when the corresponding
// need_f16_* flag is set AND the source tensor is not already F16; otherwise the field
// stays 0. If V is a view of K (same view source and offset), V reuses K's address and
// no additional space is reserved for it.
//
// dst        - tensor with op GGML_OP_FLASH_ATTN_EXT; src[1] is K, src[2] is V.
// need_f16_K - whether the caller requires K as F16.
// need_f16_V - whether the caller requires V as F16.
//
// Returns the K/V region addresses and the end address of the whole layout. Only
// addresses are computed here; nothing is allocated or written.
static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data(
        const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) {
    GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT);

    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    GGML_ASSERT(K != nullptr);
    GGML_ASSERT(V != nullptr);

    const size_t f16_size = ggml_type_size(GGML_TYPE_F16);

    // Running end-of-layout address; begins directly behind the tensor's own data.
    uintptr_t cursor = (uintptr_t) dst->data + ggml_nbytes(dst);

    // Pads the cursor to 128 bytes, claims room for an F16 copy of t and returns its start.
    const auto reserve_f16 = [&cursor, f16_size](const ggml_tensor * t) -> uintptr_t {
        cursor = GGML_PAD(cursor, 128);
        const uintptr_t region = cursor;
        cursor += ggml_nelements(t)*f16_size;
        return region;
    };

    ggml_cuda_flash_attn_ext_f16_extra_data extra = {};

    const bool convert_K = need_f16_K && K->type != GGML_TYPE_F16;
    const bool convert_V = need_f16_V && V->type != GGML_TYPE_F16;

    if (convert_K) {
        extra.K = reserve_f16(K);
    }

    if (convert_V) {
        // V aliases K when it views K directly, or when both view the same source at the same offset.
        bool shares_K = false;
        if (V->view_src) {
            shares_K = V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs);
        }

        if (shares_K) {
            extra.V = extra.K;
        } else {
            extra.V = reserve_f16(V);
        }
    }

    extra.end = cursor;
    return extra;
}
|
||||
|
||||
template <int D, int nthreads>
|
||||
static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
|
||||
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
||||
|
|
@ -952,8 +992,9 @@ void launch_fattn(
|
|||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra =
|
||||
ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V);
|
||||
|
||||
ggml_cuda_pool_alloc<int> KV_max(pool);
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
|
@ -972,10 +1013,11 @@ void launch_fattn(
|
|||
const size_t bs = ggml_blck_size(K->type);
|
||||
const size_t ts = ggml_type_size(K->type);
|
||||
|
||||
K_f16.alloc(ggml_nelements(K));
|
||||
GGML_ASSERT(f16_extra.K != 0);
|
||||
half * K_f16 = (half *) f16_extra.K;
|
||||
if (ggml_is_contiguously_allocated(K)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
||||
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
||||
to_fp16(K_data, K_f16, ggml_nelements(K), main_stream);
|
||||
|
||||
nb11 = nb11*bs*sizeof(half)/ts;
|
||||
nb12 = nb12*bs*sizeof(half)/ts;
|
||||
|
|
@ -986,13 +1028,13 @@ void launch_fattn(
|
|||
const int64_t s01 = nb11 / ts;
|
||||
const int64_t s02 = nb12 / ts;
|
||||
const int64_t s03 = nb13 / ts;
|
||||
to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
|
||||
to_fp16(K_data, K_f16, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
|
||||
|
||||
nb11 = K->ne[0] * sizeof(half);
|
||||
nb12 = K->ne[1] * nb11;
|
||||
nb13 = K->ne[2] * nb12;
|
||||
}
|
||||
K_data = (char *) K_f16.ptr;
|
||||
K_data = (char *) K_f16;
|
||||
}
|
||||
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
|
|
@ -1005,11 +1047,12 @@ void launch_fattn(
|
|||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
GGML_ASSERT(f16_extra.V != 0);
|
||||
half * V_f16 = (half *) f16_extra.V;
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
to_fp16(V_data, V_f16, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
|
|
@ -1020,13 +1063,13 @@ void launch_fattn(
|
|||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
to_fp16(V_data, V_f16, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
V_data = (char *) V_f16;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -537,6 +537,41 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
return BEST_FATTN_KERNEL_TILE;
|
||||
}
|
||||
|
||||
// Returns the number of bytes to reserve for a GGML_OP_FLASH_ATTN_EXT tensor on the given
// device: the distance from dst->data to the end of the layout computed by
// ggml_cuda_flash_attn_ext_get_f16_extra_data(), i.e. the tensor's own bytes plus any
// padded F16 regions for K and V.
//
// Which F16 copies are requested depends on the kernel chosen by
// ggml_cuda_get_best_fattn_kernel():
//   - TILE / WMMA_F16 / MMA_F16: both K and V are requested as F16
//   - VEC: a copy is requested only for an F32 K and/or an F32 V
//   - NONE: no copies are requested
// NOTE(review): presumably these flags have to stay in sync with the need_f16_K/need_f16_V
// values the individual kernels hand to launch_fattn - verify when touching either side.
size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) {
    GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT);

    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    GGML_ASSERT(K != nullptr);
    GGML_ASSERT(V != nullptr);

    bool want_f16_K = false;
    bool want_f16_V = false;

    const best_fattn_kernel selected = ggml_cuda_get_best_fattn_kernel(device, dst);
    switch (selected) {
        case BEST_FATTN_KERNEL_NONE:
            // Nothing extra to reserve.
            break;
        case BEST_FATTN_KERNEL_VEC:
            want_f16_K = (K->type == GGML_TYPE_F32);
            want_f16_V = (V->type == GGML_TYPE_F32);
            break;
        case BEST_FATTN_KERNEL_TILE:
        case BEST_FATTN_KERNEL_WMMA_F16:
        case BEST_FATTN_KERNEL_MMA_F16:
            want_f16_K = true;
            want_f16_V = true;
            break;
    }

    const ggml_cuda_flash_attn_ext_f16_extra_data layout =
        ggml_cuda_flash_attn_ext_get_f16_extra_data(dst, want_f16_K, want_f16_V);

    // The layout is expressed as absolute addresses; convert it back to a size.
    const uintptr_t base = (uintptr_t) dst->data;
    return layout.end - base;
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
|
||||
|
|
|
|||
|
|
@ -3,3 +3,5 @@
|
|||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);
|
||||
|
||||
// Returns the allocation size in bytes for a GGML_OP_FLASH_ATTN_EXT tensor on `device`,
// covering the tensor's own data plus any extra F16 regions for K and V.
size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst);
|
||||
|
|
|
|||
|
|
@ -801,7 +801,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty
|
|||
}
|
||||
|
||||
static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
size_t size = ggml_nbytes(tensor);
|
||||
ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context;
|
||||
|
||||
size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT
|
||||
? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor)
|
||||
: ggml_nbytes(tensor);
|
||||
int64_t ne0 = tensor->ne[0];
|
||||
|
||||
if (ggml_is_quantized(tensor->type)) {
|
||||
|
|
@ -812,8 +816,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t
|
|||
}
|
||||
|
||||
return size;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue