hexagon: apply repl optimization in flash attn softmax as #22993 (llama/23455)

This commit is contained in:
Yiwei Shao 2026-05-23 19:56:59 -07:00 committed by Georgi Gerganov
parent 511f8602b1
commit b84d03487c
1 changed files with 4 additions and 3 deletions

View File

@ -852,9 +852,10 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1);
// Splat m_prev[r], m_prev[r+1] from the per-row accumulator.
// vror brings the target lane to lane 0, then extract + re-splat.
HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)));
HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)));
// vror brings the target lane to lane 0, then vdelta replicates it
// across all lanes — stays in the vector domain (no store/reload).
HVX_Vector v_m_prev0 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2));
HVX_Vector v_m_prev1 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2));
// HVX max — both operands are splats, so result is splat of m_new.
HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0);