ggml : implement fast walsh-hadamard transform for kv rotation (#21352) (llama/22631)
This commit is contained in:
parent
fd184cf07a
commit
3d3cc92541
|
|
@ -438,6 +438,12 @@ extern "C" {
|
|||
GGML_PREC_F32 = 10,
|
||||
};
|
||||
|
||||
// op hint
|
||||
enum ggml_op_hint {
|
||||
GGML_HINT_NONE = 0,
|
||||
GGML_HINT_SRC0_IS_HADAMARD = 1,
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum ggml_ftype {
|
||||
GGML_FTYPE_UNKNOWN = -1,
|
||||
|
|
@ -1419,6 +1425,11 @@ extern "C" {
|
|||
struct ggml_tensor * a,
|
||||
enum ggml_prec prec);
|
||||
|
||||
// change the hint of a matrix multiplication
|
||||
GGML_API void ggml_mul_mat_set_hint(
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_op_hint hint);
|
||||
|
||||
// indirect matrix multiplication
|
||||
GGML_API struct ggml_tensor * ggml_mul_mat_id(
|
||||
struct ggml_context * ctx,
|
||||
|
|
|
|||
|
|
@ -1245,6 +1245,12 @@ void ggml_compute_forward_mul_mat(
|
|||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
const int32_t hint = ggml_get_op_params_i32(dst, 1);
|
||||
if (hint == GGML_HINT_SRC0_IS_HADAMARD && !params->use_ref) {
|
||||
ggml_compute_forward_fwht(params, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
|
|
|
|||
|
|
@ -11212,3 +11212,91 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int64_t n = ne10;
|
||||
GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2
|
||||
|
||||
const int64_t nr = ne11 * ne12 * ne13;
|
||||
const int64_t rows_per_thread = (nr + nth - 1) / nth;
|
||||
const int64_t start_row = ith * rows_per_thread;
|
||||
const int64_t end_row = MIN(start_row + rows_per_thread, nr);
|
||||
|
||||
const float scale = 1.0f / sqrtf((float)n);
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f);
|
||||
#endif
|
||||
|
||||
for (int64_t r = start_row; r < end_row; r++) {
|
||||
const int64_t i13 = r / (ne11 * ne12);
|
||||
const int64_t i12 = (r - i13 * ne11 * ne12) / ne11;
|
||||
const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11;
|
||||
|
||||
const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13);
|
||||
float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3);
|
||||
|
||||
for (int64_t j = 0; j < n; j++) {
|
||||
dst_row[j] = src_row[j] * scale;
|
||||
}
|
||||
|
||||
// Scalar passes
|
||||
#if defined(GGML_SIMD)
|
||||
const int step = GGML_F32_EPR;
|
||||
#else
|
||||
const int step = n;
|
||||
#endif
|
||||
for (int64_t len = 1; len < step && len < n; len <<= 1) {
|
||||
for (int64_t i = 0; i < n; i += 2 * len) {
|
||||
for (int64_t j = 0; j < len; j++) {
|
||||
float u = dst_row[i + j];
|
||||
float v = dst_row[i + len + j];
|
||||
dst_row[i + j] = u + v;
|
||||
dst_row[i + len + j] = u - v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SIMD passes using GGML_F32_VEC_* macros for multi-architecture support
|
||||
#if defined(GGML_SIMD)
|
||||
for (int64_t len = step; len < n; len <<= 1) {
|
||||
for (int64_t i = 0; i < n; i += 2 * len) {
|
||||
for (int64_t j = 0; j < len; j += step) {
|
||||
GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j);
|
||||
GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j);
|
||||
|
||||
GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v));
|
||||
GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
switch (src1->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_fwht_f32(params, dst);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error - fwht is F32 only");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
|
|||
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_fwht(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3264,6 +3264,16 @@ void ggml_mul_mat_set_prec(
|
|||
ggml_set_op_params_i32(a, 0, prec_i32);
|
||||
}
|
||||
|
||||
void ggml_mul_mat_set_hint(
|
||||
struct ggml_tensor * a,
|
||||
enum ggml_op_hint hint) {
|
||||
GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
|
||||
|
||||
const int32_t hint_i32 = (int32_t) hint;
|
||||
|
||||
ggml_set_op_params_i32(a, 1, hint_i32);
|
||||
}
|
||||
|
||||
// ggml_mul_mat_id
|
||||
|
||||
/*
|
||||
|
|
|
|||
Loading…
Reference in New Issue