Templating to differenciate the block_q4_0
This commit is contained in:
parent
d3e56e5be5
commit
b0c631cfb6
|
|
@ -1419,7 +1419,14 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0: {
|
||||
ggml_compute_forward_get_rows_q4_0x8(params, dst);
|
||||
if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
|
||||
if (src0->ne[1] % 8 == 0) {
|
||||
ggml_compute_forward_get_rows_q4_0<block_q4_0x8>(params, dst, 8);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("Unsupported block interleaved size for get_rows function");
|
||||
}
|
||||
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -1427,9 +1434,11 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_get_rows_q4_0x8(
|
||||
template<typename BLOCK_TYPE>
|
||||
static void ggml_compute_forward_get_rows_q4_0(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
ggml_tensor * dst,
|
||||
int nrows_interleaved) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
|
|
@ -1453,8 +1462,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||
const int ir0 = dr * ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
constexpr int nrows_interleaved = 8;
|
||||
const size_t sizeof_one_repacked_block = sizeof(block_q4_0x8);
|
||||
const size_t sizeof_one_repacked_block = sizeof(BLOCK_TYPE);
|
||||
|
||||
const int num_repacked_blocks_per_row_width = nc / QK4_0;
|
||||
|
||||
|
|
@ -1473,11 +1481,11 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||
|
||||
const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;
|
||||
|
||||
// Pointer to the first block_q4_0x8 of the identified row_group_idx
|
||||
const block_q4_0x8 * p_first_repacked_block_of_group_x8 = (const block_q4_0x8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);
|
||||
// Pointer to the first <BLOCK_TYPE> of the identified row_group_idx
|
||||
const BLOCK_TYPE * p_first_repacked_block_of_group_block_type = (const BLOCK_TYPE *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);
|
||||
|
||||
dequantize_row_q4_0x8(
|
||||
p_first_repacked_block_of_group_x8,
|
||||
dequantize_row_q4_0<block_q4_0x8>(
|
||||
p_first_repacked_block_of_group_block_type,
|
||||
(float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
|
||||
}
|
||||
}
|
||||
|
|
@ -1490,8 +1498,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||
* @param k Total number of elements (columns) in the logical row.
|
||||
* @param row_idx_in_group Index (0-7) of the logical row to dequantize.
|
||||
*/
|
||||
static void dequantize_row_q4_0x8(
|
||||
const block_q4_0x8 * GGML_RESTRICT p_repacked_group_column_blocks,
|
||||
template<typename BLOCK_TYPE>
|
||||
static void dequantize_row_q4_0(
|
||||
const BLOCK_TYPE * GGML_RESTRICT p_repacked_group_column_blocks,
|
||||
float * GGML_RESTRICT y,
|
||||
int64_t k,
|
||||
int row_idx_in_group) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue