Add more types in GET_ROWS OP (llama/23710)

* add support for Q1_0, NVFP4, IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ1_S, IQ1_M, IQ3_S, IQ4_NL, IQ4_XS, I32, MXFP4, Q2_K, Q3_K, Q5_K, and Q6_K to the GET_ROWS op

* correct the link
This commit is contained in:
Neo Zhang 2026-06-01 14:53:04 +08:00 committed by Georgi Gerganov
parent 687fbcb149
commit 20323e48c4
3 changed files with 565 additions and 3 deletions

View File

@ -20,6 +20,10 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int
typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs,
const int iqs, dfloat2 &v);
#if QK_K == 256
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m);
#endif
static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q4_0 * x = (const block_q4_0 *) vx;
@ -90,6 +94,474 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
#endif // GGML_SYCL_F16
}
// Dequantizes elements `iqs` and `iqs + 1` of Q4_K block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_q4_K(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_q4_K & blk = ((const block_q4_K *) vx)[ib];
    const sycl::half2  dm   = blk.dm;
    const float        dall = dm[0];
    const float        dmin = dm[1];

    // Each run of 64 elements shares 32 bytes of `qs`: the first 32 elements take the
    // low nibbles, the last 32 take the high nibbles, each half with its own scale/min.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int  group    = idx / 64;
        const int  pos      = idx % 64;
        const bool hi_half  = pos >= 32;
        const int  scale_id = 2 * group + (hi_half ? 1 : 0);

        uint8_t sc;
        uint8_t m;
        get_scale_min_k4(scale_id, blk.scales, sc, m);

        const uint8_t packed = blk.qs[32 * group + (pos & 31)];
        const uint8_t quant  = hi_half ? (packed >> 4) : (packed & 0xF);
        return sycl::fma((dfloat) quant, (dfloat) (dall * sc), (dfloat) (-dmin * m));
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("Q4_K dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of Q2_K block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_q2_K(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_q2_K & blk  = ((const block_q2_K *) vx)[ib];
    const float        dall = blk.dm[0];
    const float        dmin = blk.dm[1];

    // Every 128 elements share 32 bytes of `qs`; each byte holds four 2-bit values,
    // one per 32-element plane. A scale byte carries the scale in its low nibble and
    // the min in its high nibble.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int chunk  = idx / 128;
        const int within = idx % 128;
        const int plane  = within / 32;
        const int lane   = within % 32;

        const int     scale_base = 8 * chunk + lane / 16;
        const uint8_t packed     = blk.qs[32 * chunk + lane];
        const uint8_t sc         = blk.scales[scale_base + 2 * plane];

        const float d = dall * (sc & 0xF);
        const float m = dmin * (sc >> 4);
        return sycl::fma((dfloat) ((packed >> (2 * plane)) & 3), (dfloat) d, (dfloat) (-m));
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("Q2_K dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of Q3_K block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_q3_K(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_q3_K & blk   = ((const block_q3_K *) vx)[ib];
    const float        d_all = blk.d;

    auto dequantize_one = [&](const int idx) -> dfloat {
        const int chunk  = idx / 128;
        const int within = idx % 128;
        const int plane  = within / 32;
        const int lane   = within % 32;
        const int is     = 8 * chunk + 2 * plane + lane / 16;

        // Reassemble the packed scale: four low bits come from a nibble of
        // scales[0..7], two high bits from a 2-bit field of scales[8..11].
        int packed_scale;
        if (is < 4) {
            packed_scale = (blk.scales[is - 0] & 0xF) | (((blk.scales[is + 8] >> 0) & 3) << 4);
        } else if (is < 8) {
            packed_scale = (blk.scales[is - 0] & 0xF) | (((blk.scales[is + 4] >> 2) & 3) << 4);
        } else if (is < 12) {
            packed_scale = (blk.scales[is - 8] >> 4) | (((blk.scales[is + 0] >> 4) & 3) << 4);
        } else {
            packed_scale = (blk.scales[is - 8] >> 4) | (((blk.scales[is - 4] >> 6) & 3) << 4);
        }
        const int8_t us = static_cast<int8_t>(packed_scale);
        const float  dl = d_all * (us - 32);

        const int     shift     = 2 * plane;
        const uint8_t high_mask = 1 << (4 * chunk + plane);
        const uint8_t packed    = blk.qs[32 * chunk + lane];
        const uint8_t high      = blk.hmask[lane];

        // A cleared hmask bit lowers the 2-bit value by 4.
        const int8_t quant = ((packed >> shift) & 3) - ((high & high_mask) ? 0 : 4);
        return (dfloat) (dl * quant);
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("Q3_K dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of Q5_K block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_q5_K(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_q5_K & blk  = ((const block_q5_K *) vx)[ib];
    const float        dall = blk.dm[0];
    const float        dmin = blk.dm[1];

    // Low four bits come from a nibble of `qs` (low nibble for the first 32 elements
    // of each 64-element group, high nibble for the rest); the fifth bit is one bit
    // of qh[lane], selected per 32-element half.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int group = idx / 64;
        const int pos   = idx % 64;
        const int upper = (pos >= 32) ? 1 : 0;
        const int pair  = (pos & 31) / 2;
        const int odd   = pos & 1;

        const uint8_t packed = blk.qs[32 * group + 2 * pair + odd];
        const uint8_t high   = blk.qh[2 * pair + odd];
        const uint8_t low4   = (pos >= 32) ? (packed >> 4) : (packed & 0xF);

        uint8_t sc;
        uint8_t m;
        get_scale_min_k4(2 * group + upper, blk.scales, sc, m);

        const float   d         = dall * sc;
        const float   mn        = dmin * m;
        const uint8_t high_mask = 1 << (2 * group + upper);
        return sycl::fma((dfloat) (low4 + ((high & high_mask) ? 16 : 0)), (dfloat) d, (dfloat) (-mn));
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("Q5_K dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of Q6_K block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_q6_K(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_q6_K & blk = ((const block_q6_K *) vx)[ib];
    const float        d   = blk.d;

    // Every 128 elements use 64 bytes of `ql`, 32 bytes of `qh` and 8 scales. The
    // 32-element plane picks which `ql` nibble and which 2-bit field of `qh` form
    // the 6-bit value.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int chunk  = idx / 128;
        const int within = idx % 128;
        const int lane   = within & 31;
        const int plane  = within / 32;

        const uint8_t  lo_a   = blk.ql[64 * chunk + lane];
        const uint8_t  lo_b   = blk.ql[64 * chunk + lane + 32];
        const uint8_t  high   = blk.qh[32 * chunk + lane];
        const int8_t * scales = blk.scales + (8 * chunk + lane / 16);

        uint8_t quant;
        int8_t  scale;
        switch (plane) {
            case 0:
                quant = (lo_a & 0xF) | (((high >> 0) & 3) << 4);
                scale = scales[0];
                break;
            case 1:
                quant = (lo_b & 0xF) | (((high >> 2) & 3) << 4);
                scale = scales[2];
                break;
            case 2:
                quant = (lo_a >> 4) | (((high >> 4) & 3) << 4);
                scale = scales[4];
                break;
            default:
                quant = (lo_b >> 4) | (((high >> 6) & 3) << 4);
                scale = scales[6];
                break;
        }
        // Stored values are offset by 32.
        return (dfloat) (d * scale * ((int8_t) quant - 32));
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("Q6_K dequantize not supported for QK_K != 256");
#endif
}
// Decodes both nibbles of byte `iqs` of MXFP4 block `ib` (blocks start at `vx`):
// the low nibble goes to v.x(), the high nibble to v.y(). Each nibble indexes
// kvalues_mxfp4; the result is scaled by the block's converted `e` field and by 0.5.
static __dpct_inline__ void dequantize_mxfp4(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
    const block_mxfp4 & blk    = ((const block_mxfp4 *) vx)[ib];
    const float         scale  = ggml_sycl_e8m0_to_fp32(blk.e);
    const uint8_t       packed = blk.qs[iqs];

    v.x() = scale * kvalues_mxfp4[packed & 0xF] * 0.5f;
    v.y() = scale * kvalues_mxfp4[packed >> 4] * 0.5f;
}
// Dequantizes elements `iqs` and `iqs + 1` of Q1_0 block `ib` (blocks start at `vx`).
// Each element is a single bit of `qs`: a set bit yields +d, a clear bit yields -d.
static __dpct_inline__ void dequantize_q1_0(const void *vx, const int64_t ib,
                                            const int iqs, dfloat2 &v) {
    const block_q1_0 & blk = ((const block_q1_0 *) vx)[ib];
    const dfloat       d   = blk.d;

    // Bit `pos` is bit (pos % 8) of byte (pos / 8).
    auto bit_at = [&](const int pos) -> int {
        return (blk.qs[pos / 8] >> (pos % 8)) & 1;
    };

    const int first  = bit_at(iqs + 0);
    const int second = bit_at(iqs + 1);

    v.x() = (2 * first - 1) * d;
    v.y() = (2 * second - 1) * d;
}
// Dequantizes elements `iqs` and `iqs + 1` of NVFP4 block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y().
static __dpct_inline__ void dequantize_nvfp4(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
    const block_nvfp4 * x = (const block_nvfp4 *) vx;

    // Each sub-block of QK_NVFP4_SUB elements has its own scale byte and
    // QK_NVFP4_SUB / 2 data bytes: the first half of the sub-block reads low
    // nibbles, the second half reads high nibbles of the same bytes.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int half_sub = QK_NVFP4_SUB / 2;
        const int sub      = idx / QK_NVFP4_SUB;
        const int pos      = idx % QK_NVFP4_SUB;
        const int byte_pos = pos % half_sub;

        const float   d      = ggml_sycl_ue4m3_to_fp32(x[ib].d[sub]);
        const uint8_t packed = x[ib].qs[sub * half_sub + byte_pos];
        const uint8_t code   = (pos < half_sub) ? (packed & 0x0F) : (packed >> 4);
        return d * kvalues_mxfp4[code];
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ2_XXS block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_iq2_xxs(const void *vx, const int64_t ib,
                                               const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_iq2_xxs & blk = ((const block_iq2_xxs *) vx)[ib];

    // Every 32 elements use four uint16 words of `qs`: the first two hold four
    // byte-sized grid indices, the last two form a 32-bit word carrying four 7-bit
    // sign selectors and, in its top 4 bits, the scale.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int ib32   = idx / 32;
        const int within = idx % 32;
        const int grp    = within / 8;
        const int lane   = within % 8;

        const uint16_t * words    = blk.qs + 4 * ib32;
        const uint8_t *  grid_ids = (const uint8_t *) words;
        const uint8_t *  grid     = (const uint8_t *) (iq2xxs_grid + grid_ids[grp]);

        const uint32_t aux32 = words[2] | (words[3] << 16);
        const float    d     = (float) blk.d * (0.5f + (aux32 >> 28)) * 0.25f;
        const uint8_t  signs = ksigns_iq2xs[(aux32 >> (7 * grp)) & 127];
        const float    sign  = (signs & kmask_iq2xs[lane]) ? -1.f : 1.f;
        return d * grid[lane] * sign;
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("IQ2_XXS dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ2_XS block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_iq2_xs(const void *vx, const int64_t ib,
                                              const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_iq2_xs & blk = ((const block_iq2_xs *) vx)[ib];

    // Each uint16 of `qs` covers 8 elements: the low 9 bits index the grid, the top
    // 7 bits select the sign pattern. One scale byte covers 32 elements, with one
    // nibble per 16 elements.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int ib32   = idx / 32;
        const int within = idx % 32;
        const int grp    = within / 8;
        const int lane   = within % 8;

        const uint16_t * words = blk.qs + 4 * ib32;
        const uint8_t *  grid  = (const uint8_t *) (iq2xs_grid + (words[grp] & 511));

        const int     nibble = (blk.scales[ib32] >> (4 * (grp / 2))) & 0xf;
        const float   d      = (float) blk.d * (0.5f + nibble) * 0.25f;
        const uint8_t signs  = ksigns_iq2xs[words[grp] >> 9];
        const float   sign   = (signs & kmask_iq2xs[lane]) ? -1.f : 1.f;
        return d * grid[lane] * sign;
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("IQ2_XS dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ2_S block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_iq2_s(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_iq2_s & blk = ((const block_iq2_s *) vx)[ib];

    // The 10-bit grid index is a byte of `qs` plus two bits of qh[ib32]. Sign bytes
    // are read from the same `qs` array, starting QK_K / 8 bytes in.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int ib32   = idx / 32;
        const int within = idx % 32;
        const int grp    = within / 8;
        const int lane   = within % 8;
        const int slot   = 4 * ib32 + grp;

        const uint16_t  grid_id = blk.qs[slot] | ((blk.qh[ib32] << (8 - 2 * grp)) & 0x300);
        const uint8_t * grid    = (const uint8_t *) (iq2s_grid + grid_id);

        const int     nibble = (blk.scales[ib32] >> (4 * (grp / 2))) & 0xf;
        const float   d      = (float) blk.d * (0.5f + nibble) * 0.25f;
        const uint8_t signs  = blk.qs[QK_K / 8 + slot];
        const float   sign   = (signs & kmask_iq2xs[lane]) ? -1.f : 1.f;
        return d * grid[lane] * sign;
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("IQ2_S dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ3_XXS block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_iq3_xxs(const void *vx, const int64_t ib,
                                               const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_iq3_xxs & blk = ((const block_iq3_xxs *) vx)[ib];

    // The first QK_K / 4 bytes of `qs` are grid indices (two per 8 elements, each
    // grid entry supplying 4 values). The bytes after that are read as 32-bit words,
    // one per 32 elements, holding four 7-bit sign selectors and a 4-bit scale.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int ib32   = idx / 32;
        const int within = idx % 32;
        const int grp    = within / 8;
        const int lane   = within % 8;

        const uint8_t *  grid_ids = blk.qs + 8 * ib32;
        const uint16_t * gas      = (const uint16_t *) (blk.qs + QK_K / 4) + 2 * ib32;
        const uint8_t *  grid_lo  = (const uint8_t *) (iq3xxs_grid + grid_ids[2 * grp + 0]);
        const uint8_t *  grid_hi  = (const uint8_t *) (iq3xxs_grid + grid_ids[2 * grp + 1]);

        const uint32_t aux32 = gas[0] | (gas[1] << 16);
        const float    d     = (float) blk.d * (0.5f + (aux32 >> 28)) * 0.5f;
        const uint8_t  signs = ksigns_iq2xs[(aux32 >> (7 * grp)) & 127];
        const float    sign  = (signs & kmask_iq2xs[lane + 0]) ? -1.f : 1.f;

        // Lanes 0..3 come from the first grid entry, lanes 4..7 from the second.
        const uint8_t magnitude = (lane < 4) ? grid_lo[lane] : grid_hi[lane - 4];
        return d * magnitude * sign;
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("IQ3_XXS dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ3_S block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_iq3_s(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_iq3_s & blk = ((const block_iq3_s *) vx)[ib];

    // Each group of 8 elements uses two 9-bit grid indices (a byte of `qs` plus one
    // bit of qh[ib32]) and one sign byte. Each scale byte holds two 4-bit scales,
    // one per 32 elements.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int ib32   = idx / 32;
        const int within = idx % 32;
        const int grp    = within / 8;
        const int lane   = within % 8;

        const uint8_t * grid_ids = blk.qs + 8 * ib32;
        const uint16_t  id_lo    = grid_ids[2 * grp + 0] | ((blk.qh[ib32] << (8 - 2 * grp)) & 256);
        const uint16_t  id_hi    = grid_ids[2 * grp + 1] | ((blk.qh[ib32] << (7 - 2 * grp)) & 256);
        const uint8_t * grid_lo  = (const uint8_t *) (iq3s_grid + id_lo);
        const uint8_t * grid_hi  = (const uint8_t *) (iq3s_grid + id_hi);

        const int     nibble = (blk.scales[ib32 / 2] >> (4 * (ib32 % 2))) & 0xf;
        const float   d      = (float) blk.d * (1 + 2 * nibble);
        const uint8_t signs  = blk.signs[4 * ib32 + grp];
        const float   sign   = (signs & kmask_iq2xs[lane + 0]) ? -1.f : 1.f;

        // Lanes 0..3 come from the first grid entry, lanes 4..7 from the second.
        const uint8_t magnitude = (lane < 4) ? grid_lo[lane] : grid_hi[lane - 4];
        return d * magnitude * sign;
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("IQ3_S dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ1_S block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_iq1_s(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_iq1_s & blk = ((const block_iq1_s *) vx)[ib];

    // qh[ib32] packs, for 32 elements: four 3-bit grid-index extensions (bits 0..11),
    // a 3-bit scale (bits 12..14) and the delta sign (bit 15).
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int ib32   = idx / 32;
        const int within = idx % 32;
        const int grp    = within / 8;
        const int lane   = within % 8;

        const uint16_t high  = blk.qh[ib32];
        const float    delta = (high & 0x8000) ? (-1.f - IQ1S_DELTA) : (-1.f + IQ1S_DELTA);
        const float    d     = (float) blk.d * (2 * ((high >> 12) & 7) + 1);

        const uint16_t grid_id = blk.qs[4 * ib32 + grp] | (((high >> (3 * grp)) & 7) << 8);
        const uint32_t packed  = iq1s_grid_gpu[grid_id];

        // Lanes 0..3 are the low nibbles of the four bytes, lanes 4..7 the high nibbles.
        const int    shift = (lane < 4) ? (8 * lane) : (8 * (lane - 4) + 4);
        const int8_t quant = (packed >> shift) & 0x0F;
        return d * (quant + delta);
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("IQ1_S dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ1_M block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_iq1_m(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_iq1_m & blk = ((const block_iq1_m *) vx)[ib];

    auto dequantize_one = [&](const int idx) -> dfloat {
        const int ib32   = idx / 32;
        const int within = idx % 32;
        const int grp    = within / 8;
        const int lane   = within % 8;

        // The block scale is assembled from the top 4 bits of each of the four
        // 16-bit scale words and read back through the f16 member.
        const uint16_t * sc = (const uint16_t *) blk.scales;
        iq1m_scale_t     scale;
        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

        // The remaining bits hold one 3-bit sub-scale per 16 elements.
        const int   ib16 = 2 * ib32 + grp / 2;
        const float d    = (float) scale.f16 * (2 * ((sc[ib16 / 4] >> (3 * (ib16 % 4))) & 0x7) + 1);

        // Each qh byte serves two groups of 8: per nibble, 3 grid-index bits plus a delta-sign bit.
        const int     nib_shift = 4 * (grp % 2);
        const uint8_t high      = blk.qh[2 * ib32 + grp / 2];
        const float   delta     = (high & (0x08 << nib_shift)) ? (-1.f - IQ1M_DELTA) : (-1.f + IQ1M_DELTA);

        const uint16_t grid_id = blk.qs[4 * ib32 + grp] | (((high >> nib_shift) & 7) << 8);
        const uint32_t packed  = iq1s_grid_gpu[grid_id];

        // Lanes 0..3 are the low nibbles of the four bytes, lanes 4..7 the high nibbles.
        const int    shift = (lane < 4) ? (8 * lane) : (8 * (lane - 4) + 4);
        const int8_t quant = (packed >> shift) & 0x0F;
        return d * (quant + delta);
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("IQ1_M dequantize not supported for QK_K != 256");
#endif
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ4_NL block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y().
static __dpct_inline__ void dequantize_iq4_nl(const void *vx, const int64_t ib,
                                              const int iqs, dfloat2 &v) {
    const block_iq4_nl & blk = ((const block_iq4_nl *) vx)[ib];
    const float          d   = (float) blk.d;

    // Elements below 16 use the low nibble of qs[idx]; the others use the high
    // nibble of qs[idx - 16]. The nibble indexes kvalues_iq4nl.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const bool    low_half = idx < 16;
        const uint8_t packed   = blk.qs[low_half ? idx : idx - 16];
        const int     code     = low_half ? (packed & 0xF) : (packed >> 4);
        return d * kvalues_iq4nl[code];
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
}
// Dequantizes elements `iqs` and `iqs + 1` of IQ4_XS block `ib` (blocks start at `vx`)
// and stores them in v.x() / v.y(). Aborts when QK_K != 256.
static __dpct_inline__ void dequantize_iq4_xs(const void *vx, const int64_t ib,
                                              const int iqs, dfloat2 &v) {
#if QK_K == 256
    const block_iq4_xs & blk = ((const block_iq4_xs *) vx)[ib];

    // Every 32 elements use 16 bytes of `qs` (low nibbles first, then high nibbles)
    // and a 6-bit scale: 4 bits from `scales_l`, 2 bits from `scales_h`, offset by 32.
    auto dequantize_one = [&](const int idx) -> dfloat {
        const int  ib32     = idx / 32;
        const int  within   = idx % 32;
        const bool low_half = within < 16;

        const uint8_t packed = blk.qs[16 * ib32 + (low_half ? within : (within - 16))];
        const uint8_t code   = low_half ? (packed & 0x0F) : (packed >> 4);

        const int scale_lo = (blk.scales_l[ib32 / 2] >> (4 * (ib32 % 2))) & 0xf;
        const int scale_hi = ((blk.scales_h >> (2 * ib32)) & 3) << 4;
        const float d      = (float) blk.d * ((scale_lo | scale_hi) - 32);
        return d * kvalues_iq4nl[code];
    };

    v.x() = dequantize_one(iqs + 0);
    v.y() = dequantize_one(iqs + 1);
#else
    GGML_ABORT("IQ4_XS dequantize not supported for QK_K != 256");
#endif
}
static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q5_0 * x = (const block_q5_0 *) vx;

View File

@ -129,11 +129,11 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
GGML_UNUSED(ctx);
}
template <typename src0_t>
template <typename src0_t, typename dst_t>
static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const src0_t *src0_dd, const int32_t *src1_dd,
float *dst_dd, queue_ptr stream) {
dst_t *dst_dd, queue_ptr stream) {
GGML_TENSOR_BINARY_OP_LOCALS
@ -170,7 +170,7 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32 );
GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type));
GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type));
@ -191,6 +191,66 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_I32:
get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const int32_t *)dst->src[0]->data,
src1_i32, (int32_t *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q1_0:
get_rows_sycl<QK1_0, 1, dequantize_q1_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_MXFP4:
get_rows_sycl<QK_MXFP4, 2, dequantize_mxfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_NVFP4:
get_rows_sycl<QK_NVFP4, 1, dequantize_nvfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ2_XXS:
get_rows_sycl<QK_K, 1, dequantize_iq2_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ2_XS:
get_rows_sycl<QK_K, 1, dequantize_iq2_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ2_S:
get_rows_sycl<QK_K, 1, dequantize_iq2_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ3_XXS:
get_rows_sycl<QK_K, 1, dequantize_iq3_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ1_S:
get_rows_sycl<QK_K, 1, dequantize_iq1_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ1_M:
get_rows_sycl<QK_K, 1, dequantize_iq1_m>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ3_S:
get_rows_sycl<QK_K, 1, dequantize_iq3_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ4_NL:
get_rows_sycl<QK4_NL, 1, dequantize_iq4_nl>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_IQ4_XS:
get_rows_sycl<QK_K, 1, dequantize_iq4_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q2_K:
get_rows_sycl<QK_K, 1, dequantize_q2_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q3_K:
get_rows_sycl<QK_K, 1, dequantize_q3_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q4_0:
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
@ -199,6 +259,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q4_K:
get_rows_sycl<QK_K, 1, dequantize_q4_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q5_0:
get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
@ -207,6 +271,14 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q5_K:
get_rows_sycl<QK_K, 1, dequantize_q5_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q6_K:
get_rows_sycl<QK_K, 1, dequantize_q6_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());
break;
case GGML_TYPE_Q8_0:
get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
src1_i32, (float *)dst->data, ctx.stream());

View File

@ -5301,13 +5301,31 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_GET_ROWS:
{
switch (op->src[0]->type) {
case GGML_TYPE_I32:
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_F32:
case GGML_TYPE_Q1_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_Q8_0:
return true;
default: