hexagon: apply repl optimization in flash attn softmax as #22993 (llama/23455)
This commit is contained in:
parent
511f8602b1
commit
b84d03487c
|
|
@ -852,9 +852,10 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
|
|||
v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1);
|
||||
|
||||
// Splat m_prev[r], m_prev[r+1] from the per-row accumulator.
|
||||
// vror brings the target lane to lane 0, then extract + re-splat.
|
||||
HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)));
|
||||
HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)));
|
||||
// vror brings the target lane to lane 0, then vdelta replicates it
|
||||
// across all lanes — stays in the vector domain (no store/reload).
|
||||
HVX_Vector v_m_prev0 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2));
|
||||
HVX_Vector v_m_prev1 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2));
|
||||
|
||||
// HVX max — both operands are splats, so result is splat of m_new.
|
||||
HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0);
|
||||
|
|
|
|||
Loading…
Reference in New Issue