sycl : enhance set_rows to support q1_0, mxfp4, nvfp4 (llama/24564)

This commit is contained in:
Neo Zhang 2026-06-15 15:01:40 +08:00 committed by Georgi Gerganov
parent 3cb087c42a
commit d20057908a
3 changed files with 298 additions and 4 deletions

View File

@ -48,6 +48,287 @@ inline void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
}
}
inline void cpy_blck_f32_q1_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q1_0 * dsti = (block_q1_0 *) cdsti;
float sum_abs = 0.0f;
for (int j = 0; j < QK1_0; ++j) {
sum_abs += sycl::fabs((float) xi[j]);
}
dsti->d = sum_abs / QK1_0;
for (int j = 0; j < QK1_0 / 8; ++j) {
dsti->qs[j] = 0;
}
for (int j = 0; j < QK1_0; ++j) {
if (xi[j] >= 0.0f) {
dsti->qs[j / 8] |= (1u << (j % 8));
}
}
}
inline int best_index_mxfp4(const float x, const float e) {
int best_index = 0;
float best_err = sycl::fabs((float) (kvalues_mxfp4[0] * e - x));
for (int i = 1; i < 16; ++i) {
const float err = sycl::fabs((float) (kvalues_mxfp4[i] * e - x));
if (err < best_err) {
best_index = i;
best_err = err;
}
}
return best_index;
}
inline int nearest_int_sycl(float x) {
const float val = x + 12582912.0f;
int i;
memcpy(&i, &val, sizeof(int));
return (i & 0x007fffff) - 0x00400000;
}
inline int nearest_int_ggml_sycl(float x) {
return (int) sycl::round((float) x);
}
inline uint8_t clamp_u8(const int x, const int lo, const int hi) {
return (uint8_t) dpct::max(lo, dpct::min(hi, x));
}
inline int8_t clamp_i8(const int x, const int lo, const int hi) {
return (int8_t) dpct::max(lo, dpct::min(hi, x));
}
constexpr float GROUP_MAX_EPS_SYCL = 1e-15f;
inline float make_qx_quants_sycl(int n, int nmax, const float * x, int8_t * L, int rmse_type, const float * qw) {
float max = 0.0f;
float amax = 0.0f;
for (int i = 0; i < n; ++i) {
const float ax = sycl::fabs(x[i]);
if (ax > amax) {
amax = ax;
max = x[i];
}
}
if (amax < GROUP_MAX_EPS_SYCL) {
for (int i = 0; i < n; ++i) {
L[i] = 0;
}
return 0.0f;
}
float iscale = -nmax / max;
if (rmse_type == 0) {
for (int i = 0; i < n; ++i) {
int l = nearest_int_ggml_sycl(iscale * x[i]);
L[i] = (int8_t) (nmax + dpct::max(-nmax, dpct::min(nmax - 1, l)));
}
return 1.0f / iscale;
}
bool return_early = false;
if (rmse_type < 0) {
rmse_type = -rmse_type;
return_early = true;
}
float sumlx = 0.0f;
float suml2 = 0.0f;
for (int i = 0; i < n; ++i) {
int l = nearest_int_ggml_sycl(iscale * x[i]);
l = dpct::max(-nmax, dpct::min(nmax - 1, l));
L[i] = (int8_t) (l + nmax);
const float w = qw ? qw[i] : (rmse_type == 1 ? x[i] * x[i] :
rmse_type == 2 ? 1.0f : rmse_type == 3 ? sycl::fabs(x[i]) : sycl::sqrt(sycl::fabs(x[i])));
sumlx += w * x[i] * l;
suml2 += w * l * l;
}
float scale = suml2 ? sumlx / suml2 : 0.0f;
if (return_early) {
return suml2 > 0.0f ? 0.5f * (scale + 1.0f / iscale) : 1.0f / iscale;
}
float best = scale * sumlx;
for (int is = -9; is <= 9; ++is) {
if (is == 0) {
continue;
}
iscale = -(nmax + 0.1f * is) / max;
sumlx = 0.0f;
suml2 = 0.0f;
for (int i = 0; i < n; ++i) {
int l = nearest_int_ggml_sycl(iscale * x[i]);
l = dpct::max(-nmax, dpct::min(nmax - 1, l));
const float w = qw ? qw[i] : (rmse_type == 1 ? x[i] * x[i] :
rmse_type == 2 ? 1.0f : rmse_type == 3 ? sycl::fabs(x[i]) : sycl::sqrt(sycl::fabs(x[i])));
sumlx += w * x[i] * l;
suml2 += w * l * l;
}
if (suml2 > 0.0f && sumlx * sumlx > best * suml2) {
for (int i = 0; i < n; ++i) {
int l = nearest_int_ggml_sycl(iscale * x[i]);
L[i] = (int8_t) (nmax + dpct::max(-nmax, dpct::min(nmax - 1, l)));
}
scale = sumlx / suml2;
best = scale * sumlx;
}
}
return scale;
}
inline float make_q3_quants_sycl(int n, int nmax, const float * x, int8_t * L, bool do_rmse) {
float max = 0.0f;
float amax = 0.0f;
for (int i = 0; i < n; ++i) {
const float ax = sycl::fabs(x[i]);
if (ax > amax) {
amax = ax;
max = x[i];
}
}
if (amax < GROUP_MAX_EPS_SYCL) {
for (int i = 0; i < n; ++i) {
L[i] = 0;
}
return 0.0f;
}
const float iscale = -nmax / max;
if (do_rmse) {
float sumlx = 0.0f;
float suml2 = 0.0f;
for (int i = 0; i < n; ++i) {
int l = nearest_int_ggml_sycl(iscale * x[i]);
l = dpct::max(-nmax, dpct::min(nmax - 1, l));
L[i] = (int8_t) l;
const float w = x[i] * x[i];
sumlx += w * x[i] * l;
suml2 += w * l * l;
}
for (int itry = 0; itry < 5; ++itry) {
int n_changed = 0;
for (int i = 0; i < n; ++i) {
const float w = x[i] * x[i];
float slx = sumlx - w * x[i] * L[i];
if (slx > 0.0f) {
float sl2 = suml2 - w * L[i] * L[i];
int new_l = nearest_int_ggml_sycl(x[i] * sl2 / slx);
new_l = dpct::max(-nmax, dpct::min(nmax - 1, new_l));
if (new_l != L[i]) {
slx += w * x[i] * new_l;
sl2 += w * new_l * new_l;
if (sl2 > 0.0f && slx * slx * suml2 > sumlx * sumlx * sl2) {
L[i] = (int8_t) new_l;
sumlx = slx;
suml2 = sl2;
++n_changed;
}
}
}
}
if (!n_changed) {
break;
}
}
for (int i = 0; i < n; ++i) {
L[i] += nmax;
}
return suml2 > 0.0f ? sumlx / suml2 : 0.0f;
}
for (int i = 0; i < n; ++i) {
int l = nearest_int_ggml_sycl(iscale * x[i]);
l = dpct::max(-nmax, dpct::min(nmax - 1, l));
L[i] = (int8_t) (l + nmax);
}
return 1.0f / iscale;
}
inline void set_scale_min_k4(int j, uint8_t * q, uint8_t d, uint8_t m) {
if (j < 4) {
q[j] = (q[j] & 0xC0) | (d & 0x3F);
q[j + 4] = (q[j + 4] & 0xC0) | (m & 0x3F);
} else {
q[j + 4] = (d & 0x0F) | ((m & 0x0F) << 4);
q[j - 4] = (q[j - 4] & 0x3F) | ((d >> 4) << 6);
q[j - 0] = (q[j - 0] & 0x3F) | ((m >> 4) << 6);
}
}
inline void get_scale_min_k4_local(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63;
m = q[j + 4] & 63;
} else {
d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
}
}
inline void cpy_blck_f32_mxfp4(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_mxfp4 * dsti = (block_mxfp4 *) cdsti;
float amax = 0.0f;
for (int j = 0; j < QK_MXFP4; ++j) {
amax = sycl::fmax(amax, sycl::fabs((float) xi[j]));
}
const uint8_t e = amax > 0.0f ? (uint8_t) (sycl::floor(sycl::log2(amax)) - 2 + 127) : 0;
const float d = GGML_E8M0_TO_FP32_HALF(e);
dsti->e = e;
for (int j = 0; j < QK_MXFP4 / 2; ++j) {
const uint8_t x0 = best_index_mxfp4(xi[0 + j], d);
const uint8_t x1 = best_index_mxfp4(xi[QK_MXFP4 / 2 + j], d);
dsti->qs[j] = x0;
dsti->qs[j] |= x1 << 4;
}
}
inline void cpy_blck_f32_nvfp4(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_nvfp4 * dsti = (block_nvfp4 *) cdsti;
constexpr int n_sub = QK_NVFP4 / QK_NVFP4_SUB;
for (int s = 0; s < n_sub; ++s) {
const float * xb = xi + s * QK_NVFP4_SUB;
float amax = 0.0f;
for (int j = 0; j < QK_NVFP4_SUB; ++j) {
amax = sycl::fmax(amax, sycl::fabs((float) xb[j]));
}
const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f);
dsti->d[s] = ue;
const float d = ggml_ue4m3_to_fp32(ue);
for (int j = 0; j < QK_NVFP4_SUB / 2; ++j) {
const uint8_t x0 = best_index_mxfp4(xb[0 + j], d);
const uint8_t x1 = best_index_mxfp4(xb[QK_NVFP4_SUB / 2 + j], d);
dsti->qs[s * (QK_NVFP4_SUB / 2) + j] = x0 | (x1 << 4);
}
}
}
inline void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_0 * dsti = (block_q4_0 *) cdsti;

View File

@ -5242,10 +5242,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SET_ROWS:
{
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
auto res = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
op->type == GGML_TYPE_Q1_0 ||
op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL ||
op->type == GGML_TYPE_MXFP4 || op->type == GGML_TYPE_NVFP4) &&
op->src[0]->type == GGML_TYPE_F32 &&
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
return res;
}
break;
case GGML_OP_CPY:

View File

@ -135,7 +135,7 @@ static void set_rows_sycl(
stream->parallel_for(
sycl::nd_range<1>(grid_size * block_size, block_size),
[=](sycl::nd_item<1> item_ct1) {
[=](sycl::nd_item<1> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
k_set_rows<TIn, TIdx, TOut>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02,
@ -202,6 +202,9 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s
case GGML_TYPE_Q8_0:
set_rows_sycl_q<TIdx, block_q8_0, QK8_0, cpy_blck_f32_q8_0>(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q1_0:
set_rows_sycl_q<TIdx, block_q1_0, QK1_0, cpy_blck_f32_q1_0>(src0_d, src1_d, (block_q1_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q5_1:
set_rows_sycl_q<TIdx, block_q5_1, QK5_1, cpy_blck_f32_q5_1>(src0_d, src1_d, (block_q5_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
@ -217,7 +220,12 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s
case GGML_TYPE_IQ4_NL:
set_rows_sycl_q<TIdx, block_iq4_nl, QK4_NL, cpy_blck_f32_iq4_nl>(src0_d, src1_d, (block_iq4_nl *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_MXFP4:
set_rows_sycl_q<TIdx, block_mxfp4, QK_MXFP4, cpy_blck_f32_mxfp4>(src0_d, src1_d, (block_mxfp4 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_NVFP4:
set_rows_sycl_q<TIdx, block_nvfp4, QK_NVFP4, cpy_blck_f32_nvfp4>(src0_d, src1_d, (block_nvfp4 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
default:
GGML_ABORT("Unsupported tensor type!");
break;