Templating to differenciate the block_q4_0's in get_rows function

This commit is contained in:
Swetha B S 2025-07-16 01:01:39 -07:00
parent b0c631cfb6
commit d39e4e6eb0
1 changed files with 4 additions and 5 deletions

View File

@ -1419,7 +1419,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
switch (src0->type) {
case GGML_TYPE_Q4_0: {
if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
if (ggml_cpu_has_avx2()) {
if (src0->ne[1] % 8 == 0) {
ggml_compute_forward_get_rows_q4_0<block_q4_0x8>(params, dst, 8);
}
@ -1484,23 +1484,22 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
// Pointer to the first <BLOCK_TYPE> of the identified row_group_idx
const BLOCK_TYPE * p_first_repacked_block_of_group_block_type = (const BLOCK_TYPE *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);
dequantize_row_q4_0<block_q4_0x8>(
dequantize_row_q4_0(
p_first_repacked_block_of_group_block_type,
(float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
}
}
/**
* Dequantizes a single logical row from data repacked with quant interleaving.
* Dequantizes a single logical row from data repacked with quant interleaving for repacked block_q4_0x8
*
* @param p_repacked_group_column_blocks Pointer to the start of 'block_q4_0x8' for the row group.
* @param y Output buffer for the dequantized float values.
* @param k Total number of elements (columns) in the logical row.
* @param row_idx_in_group Index (0-7) of the logical row to dequantize.
*/
template<typename BLOCK_TYPE>
static void dequantize_row_q4_0(
const BLOCK_TYPE * GGML_RESTRICT p_repacked_group_column_blocks,
const block_q4_0x8 * GGML_RESTRICT p_repacked_group_column_blocks,
float * GGML_RESTRICT y,
int64_t k,
int row_idx_in_group) {