hexagon: eliminate scalar VTCM loads via HVX splat helpers (#22993)

* hexagon: add hvx_vec_repl helpers and use those for splat-from-vtcm usecase

* hmx-mm: optimize per-group scale handling

* hmx-fa: optimize slope load from vtcm

* hmx-fa: use aligned access where possible in hmx-utils

* hexagon: add hvx_vec_repl_2x_f16 helper and consolidate repl helpers

---------

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
Author:       Trivikram Reddy
Date:         2026-05-12 19:28:02 -05:00
Committed by: GitHub
Parent:       a9883db8ee
Commit:       856c3adac1
6 changed files with 107 additions and 38 deletions
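The change that repeats across the hunks below: instead of reading one fp16 value out of VTCM with a scalar load and broadcasting it (hvx_vec_splat_f16 on a dereferenced pointer), the surrounding bytes are fetched with a single vector load and f16 lane 0 is replicated across the vector with the new vdelta-based helpers. A minimal before/after sketch of the pattern (scale_ptr is an illustrative VTCM-resident pointer, not a name from the patch):

// Before: the dereference forces a scalar load from VTCM, then a broadcast.
HVX_Vector v_before = hvx_vec_splat_f16(*scale_ptr);

// After: one unaligned vector load of the bytes holding the value, then
// replicate f16 lane 0 across all lanes entirely within HVX.
HVX_Vector v_after  = hvx_vec_repl_f16(hvx_vmemu(scale_ptr));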

@@ -760,8 +760,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
     // ALiBi slopes — only needed when has_alibi (scheme A)
     HVX_Vector v_slope0, v_slope1;
     if (args->has_alibi) {
-        v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]);
-        v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero();
+        HVX_Vector v_s = hvx_vmemu(args->slopes + r);
+        v_slope0 = hvx_vec_repl_f16(v_s);
+        v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_repl_f16(Q6_V_vror_VR(v_s, 2)) : Q6_V_vzero();
     }
     const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c)
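With the replicate helper, both slopes of a row pair come from a single unaligned vector load: lane 0 of v_s holds slopes[r], and rotating the vector right by 2 bytes moves slopes[r + 1] into lane 0 before replication. A quick sketch of the intended lane movement (assuming the slopes array holds __fp16 values, as the 2-byte rotate implies):

// v_s                                    = [ slopes[r], slopes[r+1], ... ]   (f16 lanes)
// hvx_vec_repl_f16(v_s)                  -> every lane = slopes[r]
// Q6_V_vror_VR(v_s, 2)                   = [ slopes[r+1], ... ]              (rotate right by 2 bytes)
// hvx_vec_repl_f16(Q6_V_vror_VR(v_s, 2)) -> every lane = slopes[r+1]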

@@ -180,12 +180,10 @@ next_nc:
 // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes.
 // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
 // of the same 32 packed bytes.
-static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
-        const uint8_t *packed_32, bool upper_nibbles,
-        const __fp16 *scale, const HVX_Vector vlut_cvt) {
+static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
     HVX_Vector vq = hvx_vmemu(packed_32);
     const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
-    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
+    HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
     // q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
     HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
     v_quants = Q6_V_vand_VV(v_quants, mask_h4);
@@ -223,9 +221,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
     HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]
     // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
     HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
-    HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1]));
-    HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3]));
+    volatile HVX_Vector vscale = hvx_vmemu(scales_4);
+    HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
+    HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
     v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
     v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
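For the per-group scales, hvx_vec_repl_2x_f16 produces the same layout the predicate/mux construction built before, straight from one vector load: the low 64 bytes take f16 lane 0 and the high 64 bytes take f16 lane 1, and a 4-byte rotate brings the next pair of scales into lanes 0..1. Intended lane layout (a sketch based on the helper's comment, not verified output):

// vscale                          = [ s0 s1 s2 s3 ... ]   (f16 lanes of scales_4)
// hvx_vec_repl_2x_f16(vscale)     = [ s0 x32 | s1 x32 ]   (low 64 bytes s0, high 64 bytes s1)
// Q6_V_vror_VR(vscale, 4)         = [ s2 s3 ... ]         (rotate right by 4 bytes = 2 f16 lanes)
// hvx_vec_repl_2x_f16(vror(...))  = [ s2 x32 | s3 x32 ]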
@@ -237,10 +236,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
 // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
 static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
-    HVX_Vector vq = hvx_vmemu(quants_32);
-    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
-    HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
-    HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
+    HVX_Vector vq = hvx_vmemu(quants_32);
+    HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
+    HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
+    HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
     return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
 }
@@ -521,12 +520,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
     const uint8_t *r0 = vtcm_src + row0 * row_stride;
     const uint8_t *r1 = vtcm_src + row1 * row_stride;
-    HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx(
-        (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
-    HVX_Vector v1 = (row1 < n_cols)
-        ? dequantize_x4x2_q8_0_group_hvx(
-              (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off))
-        : Q6_V_vzero();
+    HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
+    HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
     Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
     v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);

@@ -77,16 +77,18 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
         const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
-        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
-        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
-        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
+        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
+        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+        assert(hex_is_aligned(p0, 128));
+        assert(hex_is_aligned(p1, 128));
+        assert(c_byte_step % 128 == 0);
         if (p1) {
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
-                HVX_Vector v1 = hvx_vmemu(p1);
-                p1 += c_byte_step;
+                HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
+                HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
                 tile_base += dst_step;
@@ -94,8 +96,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         } else {
             const HVX_Vector vzero = Q6_V_vzero();
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
+                HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
                 tile_base += dst_step;
@@ -116,16 +117,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
         const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
-        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
-        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
-        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
+        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
+        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
         if (p1) {
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
-                HVX_Vector v1 = hvx_vmemu(p1);
-                p1 += c_byte_step;
+                HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
+                HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
                 tile_base += dst_step;
@@ -133,8 +132,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         } else {
             const HVX_Vector vzero = Q6_V_vzero();
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
+                HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
                 tile_base += dst_step;
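The asserts added in the first loop are what allow it to switch from hvx_vmemu to the aligned hvx_vmem load: both source rows and the per-iteration byte step are checked to be 128-byte multiples, so every load in that loop hits a vector-aligned VTCM address, while the masked (q_mask64) path keeps the unaligned form. A minimal sketch of that choice (load_row is an illustrative helper, not part of the patch; hex_is_aligned, hvx_vmem and hvx_vmemu are used as in the hunks above):

// Illustration only: prefer the aligned HVX load when the pointer is known
// to sit on a 128-byte (vector) boundary, otherwise fall back to unaligned.
static inline HVX_Vector load_row(const uint8_t * p) {
    return hex_is_aligned(p, 128) ? hvx_vmem(p) : hvx_vmemu(p);
}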

@@ -0,0 +1,74 @@
#ifndef HVX_REPL_H
#define HVX_REPL_H
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "hvx-base.h"
static inline HVX_Vector hvx_vec_repl(HVX_Vector v, const uint8_t * ctrl) {
return Q6_V_vdelta_VV(v, hvx_vmem(ctrl));
}
static inline HVX_Vector hvx_vec_repl_u32(HVX_Vector v) {
// vdelta control to replicate first 4 bytes across all lanes
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
return hvx_vec_repl(v, repl);
}
static inline HVX_Vector hvx_vec_repl_f32(HVX_Vector v) {
// vdelta control to replicate first 4 bytes across all lanes
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
};
return hvx_vec_repl(v, repl);
}
static inline HVX_Vector hvx_vec_repl_f16(HVX_Vector v) {
// vdelta control to replicate first two bytes across all lanes
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
return hvx_vec_repl(v, repl);
}
static inline HVX_Vector hvx_vec_repl_2x_f16(HVX_Vector v) {
// vdelta control to splat a pair of f16s: first half = f16[0], second half = f16[1]
static const uint8_t __attribute__((aligned(128))) repl[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
};
return hvx_vec_repl(v, repl);
}
#endif // HVX_REPL_H
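As a plain-C cross-check of what the vdelta control tables above are meant to compute, here is an illustrative reference model of the replication helpers, written against the semantics stated in the comments (scalar code for illustration only, not part of the header):

#include <stdint.h>
#include <string.h>

// hvx_vec_repl_f16: every 2-byte (f16) lane of the result holds f16 lane 0 of the input.
static void repl_f16_ref(uint8_t dst[128], const uint8_t src[128]) {
    for (int i = 0; i < 64; ++i) memcpy(dst + 2 * i, src, 2);
}

// hvx_vec_repl_2x_f16: low 64 bytes replicate f16 lane 0, high 64 bytes replicate f16 lane 1.
static void repl_2x_f16_ref(uint8_t dst[128], const uint8_t src[128]) {
    for (int i = 0;  i < 32; ++i) memcpy(dst + 2 * i, src,     2);
    for (int i = 32; i < 64; ++i) memcpy(dst + 2 * i, src + 2, 2);
}

// hvx_vec_repl_u32 / hvx_vec_repl_f32: every 4-byte lane holds the first 4 bytes of the input.
static void repl_u32_ref(uint8_t dst[128], const uint8_t src[128]) {
    for (int i = 0; i < 32; ++i) memcpy(dst + 4 * i, src, 4);
}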

@@ -5,6 +5,7 @@
#include "hvx-types.h"
#include "hvx-copy.h"
#include "hvx-repl.h"
#include "hvx-scale.h"
#include "hvx-exp.h"
#include "hvx-inverse.h"

@@ -70,5 +70,5 @@ adb $adbserial $adbhost shell " \
     ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
         --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
         --ctx-size 8192 --ubatch-size 256 -fa on \
-        -ngl 99 -no-cnv --device $device $cli_opts $@ \
+        -ngl 99 --device $device $cli_opts $@ \
     "