sycl : enhance set_rows to support q1_0, mxfp4, nvfp4 (llama/24564)
This commit is contained in:
parent
3cb087c42a
commit
d20057908a
|
|
@ -48,6 +48,287 @@ inline void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
|||
}
|
||||
}
|
||||
|
||||
inline void cpy_blck_f32_q1_0(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_q1_0 * dsti = (block_q1_0 *) cdsti;
|
||||
|
||||
float sum_abs = 0.0f;
|
||||
for (int j = 0; j < QK1_0; ++j) {
|
||||
sum_abs += sycl::fabs((float) xi[j]);
|
||||
}
|
||||
|
||||
dsti->d = sum_abs / QK1_0;
|
||||
|
||||
for (int j = 0; j < QK1_0 / 8; ++j) {
|
||||
dsti->qs[j] = 0;
|
||||
}
|
||||
|
||||
for (int j = 0; j < QK1_0; ++j) {
|
||||
if (xi[j] >= 0.0f) {
|
||||
dsti->qs[j / 8] |= (1u << (j % 8));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline int best_index_mxfp4(const float x, const float e) {
|
||||
int best_index = 0;
|
||||
float best_err = sycl::fabs((float) (kvalues_mxfp4[0] * e - x));
|
||||
for (int i = 1; i < 16; ++i) {
|
||||
const float err = sycl::fabs((float) (kvalues_mxfp4[i] * e - x));
|
||||
if (err < best_err) {
|
||||
best_index = i;
|
||||
best_err = err;
|
||||
}
|
||||
}
|
||||
return best_index;
|
||||
}
|
||||
|
||||
inline int nearest_int_sycl(float x) {
|
||||
const float val = x + 12582912.0f;
|
||||
int i;
|
||||
memcpy(&i, &val, sizeof(int));
|
||||
return (i & 0x007fffff) - 0x00400000;
|
||||
}
|
||||
|
||||
inline int nearest_int_ggml_sycl(float x) {
|
||||
return (int) sycl::round((float) x);
|
||||
}
|
||||
|
||||
inline uint8_t clamp_u8(const int x, const int lo, const int hi) {
|
||||
return (uint8_t) dpct::max(lo, dpct::min(hi, x));
|
||||
}
|
||||
|
||||
inline int8_t clamp_i8(const int x, const int lo, const int hi) {
|
||||
return (int8_t) dpct::max(lo, dpct::min(hi, x));
|
||||
}
|
||||
|
||||
constexpr float GROUP_MAX_EPS_SYCL = 1e-15f;
|
||||
|
||||
inline float make_qx_quants_sycl(int n, int nmax, const float * x, int8_t * L, int rmse_type, const float * qw) {
|
||||
float max = 0.0f;
|
||||
float amax = 0.0f;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const float ax = sycl::fabs(x[i]);
|
||||
if (ax > amax) {
|
||||
amax = ax;
|
||||
max = x[i];
|
||||
}
|
||||
}
|
||||
if (amax < GROUP_MAX_EPS_SYCL) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
L[i] = 0;
|
||||
}
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
float iscale = -nmax / max;
|
||||
if (rmse_type == 0) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int l = nearest_int_ggml_sycl(iscale * x[i]);
|
||||
L[i] = (int8_t) (nmax + dpct::max(-nmax, dpct::min(nmax - 1, l)));
|
||||
}
|
||||
return 1.0f / iscale;
|
||||
}
|
||||
|
||||
bool return_early = false;
|
||||
if (rmse_type < 0) {
|
||||
rmse_type = -rmse_type;
|
||||
return_early = true;
|
||||
}
|
||||
|
||||
float sumlx = 0.0f;
|
||||
float suml2 = 0.0f;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int l = nearest_int_ggml_sycl(iscale * x[i]);
|
||||
l = dpct::max(-nmax, dpct::min(nmax - 1, l));
|
||||
L[i] = (int8_t) (l + nmax);
|
||||
|
||||
const float w = qw ? qw[i] : (rmse_type == 1 ? x[i] * x[i] :
|
||||
rmse_type == 2 ? 1.0f : rmse_type == 3 ? sycl::fabs(x[i]) : sycl::sqrt(sycl::fabs(x[i])));
|
||||
|
||||
sumlx += w * x[i] * l;
|
||||
suml2 += w * l * l;
|
||||
}
|
||||
|
||||
float scale = suml2 ? sumlx / suml2 : 0.0f;
|
||||
if (return_early) {
|
||||
return suml2 > 0.0f ? 0.5f * (scale + 1.0f / iscale) : 1.0f / iscale;
|
||||
}
|
||||
|
||||
float best = scale * sumlx;
|
||||
for (int is = -9; is <= 9; ++is) {
|
||||
if (is == 0) {
|
||||
continue;
|
||||
}
|
||||
iscale = -(nmax + 0.1f * is) / max;
|
||||
sumlx = 0.0f;
|
||||
suml2 = 0.0f;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int l = nearest_int_ggml_sycl(iscale * x[i]);
|
||||
l = dpct::max(-nmax, dpct::min(nmax - 1, l));
|
||||
const float w = qw ? qw[i] : (rmse_type == 1 ? x[i] * x[i] :
|
||||
rmse_type == 2 ? 1.0f : rmse_type == 3 ? sycl::fabs(x[i]) : sycl::sqrt(sycl::fabs(x[i])));
|
||||
sumlx += w * x[i] * l;
|
||||
suml2 += w * l * l;
|
||||
}
|
||||
|
||||
if (suml2 > 0.0f && sumlx * sumlx > best * suml2) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int l = nearest_int_ggml_sycl(iscale * x[i]);
|
||||
L[i] = (int8_t) (nmax + dpct::max(-nmax, dpct::min(nmax - 1, l)));
|
||||
}
|
||||
scale = sumlx / suml2;
|
||||
best = scale * sumlx;
|
||||
}
|
||||
}
|
||||
|
||||
return scale;
|
||||
}
|
||||
|
||||
inline float make_q3_quants_sycl(int n, int nmax, const float * x, int8_t * L, bool do_rmse) {
|
||||
float max = 0.0f;
|
||||
float amax = 0.0f;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const float ax = sycl::fabs(x[i]);
|
||||
if (ax > amax) {
|
||||
amax = ax;
|
||||
max = x[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (amax < GROUP_MAX_EPS_SYCL) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
L[i] = 0;
|
||||
}
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
const float iscale = -nmax / max;
|
||||
if (do_rmse) {
|
||||
float sumlx = 0.0f;
|
||||
float suml2 = 0.0f;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int l = nearest_int_ggml_sycl(iscale * x[i]);
|
||||
l = dpct::max(-nmax, dpct::min(nmax - 1, l));
|
||||
L[i] = (int8_t) l;
|
||||
const float w = x[i] * x[i];
|
||||
sumlx += w * x[i] * l;
|
||||
suml2 += w * l * l;
|
||||
}
|
||||
|
||||
for (int itry = 0; itry < 5; ++itry) {
|
||||
int n_changed = 0;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const float w = x[i] * x[i];
|
||||
float slx = sumlx - w * x[i] * L[i];
|
||||
if (slx > 0.0f) {
|
||||
float sl2 = suml2 - w * L[i] * L[i];
|
||||
int new_l = nearest_int_ggml_sycl(x[i] * sl2 / slx);
|
||||
new_l = dpct::max(-nmax, dpct::min(nmax - 1, new_l));
|
||||
if (new_l != L[i]) {
|
||||
slx += w * x[i] * new_l;
|
||||
sl2 += w * new_l * new_l;
|
||||
if (sl2 > 0.0f && slx * slx * suml2 > sumlx * sumlx * sl2) {
|
||||
L[i] = (int8_t) new_l;
|
||||
sumlx = slx;
|
||||
suml2 = sl2;
|
||||
++n_changed;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!n_changed) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n; ++i) {
|
||||
L[i] += nmax;
|
||||
}
|
||||
return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int l = nearest_int_ggml_sycl(iscale * x[i]);
|
||||
l = dpct::max(-nmax, dpct::min(nmax - 1, l));
|
||||
L[i] = (int8_t) (l + nmax);
|
||||
}
|
||||
|
||||
return 1.0f / iscale;
|
||||
}
|
||||
|
||||
inline void set_scale_min_k4(int j, uint8_t * q, uint8_t d, uint8_t m) {
|
||||
if (j < 4) {
|
||||
q[j] = (q[j] & 0xC0) | (d & 0x3F);
|
||||
q[j + 4] = (q[j + 4] & 0xC0) | (m & 0x3F);
|
||||
} else {
|
||||
q[j + 4] = (d & 0x0F) | ((m & 0x0F) << 4);
|
||||
q[j - 4] = (q[j - 4] & 0x3F) | ((d >> 4) << 6);
|
||||
q[j - 0] = (q[j - 0] & 0x3F) | ((m >> 4) << 6);
|
||||
}
|
||||
}
|
||||
|
||||
inline void get_scale_min_k4_local(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
|
||||
if (j < 4) {
|
||||
d = q[j] & 63;
|
||||
m = q[j + 4] & 63;
|
||||
} else {
|
||||
d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
|
||||
m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
|
||||
}
|
||||
}
|
||||
|
||||
inline void cpy_blck_f32_mxfp4(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_mxfp4 * dsti = (block_mxfp4 *) cdsti;
|
||||
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < QK_MXFP4; ++j) {
|
||||
amax = sycl::fmax(amax, sycl::fabs((float) xi[j]));
|
||||
}
|
||||
|
||||
const uint8_t e = amax > 0.0f ? (uint8_t) (sycl::floor(sycl::log2(amax)) - 2 + 127) : 0;
|
||||
const float d = GGML_E8M0_TO_FP32_HALF(e);
|
||||
|
||||
dsti->e = e;
|
||||
|
||||
for (int j = 0; j < QK_MXFP4 / 2; ++j) {
|
||||
const uint8_t x0 = best_index_mxfp4(xi[0 + j], d);
|
||||
const uint8_t x1 = best_index_mxfp4(xi[QK_MXFP4 / 2 + j], d);
|
||||
|
||||
dsti->qs[j] = x0;
|
||||
dsti->qs[j] |= x1 << 4;
|
||||
}
|
||||
}
|
||||
|
||||
inline void cpy_blck_f32_nvfp4(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_nvfp4 * dsti = (block_nvfp4 *) cdsti;
|
||||
|
||||
constexpr int n_sub = QK_NVFP4 / QK_NVFP4_SUB;
|
||||
|
||||
for (int s = 0; s < n_sub; ++s) {
|
||||
const float * xb = xi + s * QK_NVFP4_SUB;
|
||||
|
||||
float amax = 0.0f;
|
||||
for (int j = 0; j < QK_NVFP4_SUB; ++j) {
|
||||
amax = sycl::fmax(amax, sycl::fabs((float) xb[j]));
|
||||
}
|
||||
|
||||
const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f);
|
||||
dsti->d[s] = ue;
|
||||
const float d = ggml_ue4m3_to_fp32(ue);
|
||||
|
||||
for (int j = 0; j < QK_NVFP4_SUB / 2; ++j) {
|
||||
const uint8_t x0 = best_index_mxfp4(xb[0 + j], d);
|
||||
const uint8_t x1 = best_index_mxfp4(xb[QK_NVFP4_SUB / 2 + j], d);
|
||||
|
||||
dsti->qs[s * (QK_NVFP4_SUB / 2) + j] = x0 | (x1 << 4);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_q4_0 * dsti = (block_q4_0 *) cdsti;
|
||||
|
|
|
|||
|
|
@ -5242,10 +5242,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
|
||||
|
||||
auto res = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
|
||||
op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
|
||||
op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
|
||||
op->type == GGML_TYPE_Q1_0 ||
|
||||
op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL ||
|
||||
op->type == GGML_TYPE_MXFP4 || op->type == GGML_TYPE_NVFP4) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 &&
|
||||
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
|
||||
return res;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CPY:
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ static void set_rows_sycl(
|
|||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(grid_size * block_size, block_size),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
[=](sycl::nd_item<1> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
k_set_rows<TIn, TIdx, TOut>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02,
|
||||
|
|
@ -202,6 +202,9 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s
|
|||
case GGML_TYPE_Q8_0:
|
||||
set_rows_sycl_q<TIdx, block_q8_0, QK8_0, cpy_blck_f32_q8_0>(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q1_0:
|
||||
set_rows_sycl_q<TIdx, block_q1_0, QK1_0, cpy_blck_f32_q1_0>(src0_d, src1_d, (block_q1_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
set_rows_sycl_q<TIdx, block_q5_1, QK5_1, cpy_blck_f32_q5_1>(src0_d, src1_d, (block_q5_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
|
|
@ -217,7 +220,12 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s
|
|||
case GGML_TYPE_IQ4_NL:
|
||||
set_rows_sycl_q<TIdx, block_iq4_nl, QK4_NL, cpy_blck_f32_iq4_nl>(src0_d, src1_d, (block_iq4_nl *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
set_rows_sycl_q<TIdx, block_mxfp4, QK_MXFP4, cpy_blck_f32_mxfp4>(src0_d, src1_d, (block_mxfp4 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
set_rows_sycl_q<TIdx, block_nvfp4, QK_NVFP4, cpy_blck_f32_nvfp4>(src0_d, src1_d, (block_nvfp4 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported tensor type!");
|
||||
break;
|
||||
|
|
|
|||
Loading…
Reference in New Issue