hexagon: eliminate scalar VTCM loads via HVX splat helpers (#22993)

* hexagon: add hvx_vec_repl helpers and use those for splat-from-vtcm usecase
* hmx-mm: optimize per-group scale handling
* hmx-fa: optimize slope load from vtcm
* hmx-fa: use aligned access where possible in hmx-utils
* hexagon: add hvx_vec_repl_2x_f16 helper and consolidate repl helpers

Co-authored-by: Max Krasnyansky <maxk@qti.qualcomm.com>
@@ -760,8 +760,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
         // ALiBi slopes — only needed when has_alibi (scheme A)
         HVX_Vector v_slope0, v_slope1;
         if (args->has_alibi) {
-            v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]);
-            v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero();
+            HVX_Vector v_s = hvx_vmemu(args->slopes + r);
+            v_slope0 = hvx_vec_repl_f16(v_s);
+            v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_repl_f16(Q6_V_vror_VR(v_s, 2)) : Q6_V_vzero();
        }

        const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c)
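The hunk above replaces two scalar reads from VTCM, each feeding hvx_vec_splat_f16, with a single unaligned vector load that serves both broadcasts; Q6_V_vror_VR rotates the loaded vector right by 2 bytes so slopes[r + 1] sits in lane 0 before replication. A minimal scalar model of the new path (hypothetical function name; fp16 values kept as raw bit patterns):

#include <stdint.h>

// Scalar sketch of the new slope path: one load of slopes[r..] feeds both
// broadcasts; indexing r + 1 models the 2-byte Q6_V_vror_VR rotate.
static void slopes_model(const uint16_t * slopes, int r, int n_rows_g,
                         uint16_t v_slope0[64], uint16_t v_slope1[64]) {
    for (int i = 0; i < 64; ++i) {
        v_slope0[i] = slopes[r];                              // hvx_vec_repl_f16(v_s)
        v_slope1[i] = (r + 1 < n_rows_g) ? slopes[r + 1] : 0; // repl after vror, else vzero
    }
}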
@@ -180,12 +180,10 @@ next_nc:
 // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes.
 // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
 // of the same 32 packed bytes.
-static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
-    const uint8_t *packed_32, bool upper_nibbles,
-    const __fp16 *scale, const HVX_Vector vlut_cvt) {
+static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
     HVX_Vector vq = hvx_vmemu(packed_32);
     const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
-    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
+    HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
     // q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
     HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
     v_quants = Q6_V_vand_VV(v_quants, mask_h4);
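Besides joining the signature onto one line, the functional change is the scale broadcast: hvx_vec_repl_f16(hvx_vmemu(scale)) loads the fp16 scale as part of a vector and replicates it in-register, instead of a scalar VTCM load feeding a splat. The nibble selection is unchanged; a scalar model of that step (hypothetical helper name; the int4-to-fp16 mapping through vlut_cvt is not modeled here):

#include <stdint.h>

// x4x2 packs two int4 quants per byte: sub-blocks 0..3 live in the low
// nibbles, sub-blocks 4..7 in the high nibbles of the same 32 bytes.
static void select_nibbles_model(const uint8_t packed_32[32], int upper_nibbles,
                                 uint8_t quants[32]) {
    for (int i = 0; i < 32; ++i) {
        // Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles), then AND with mask_h4
        quants[i] = (uint8_t) ((packed_32[i] >> (4 * upper_nibbles)) & 0x0F);
    }
}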
@@ -223,9 +221,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
     HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]

     // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
-    HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
-    HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1]));
-    HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3]));
+    volatile HVX_Vector vscale = hvx_vmemu(scales_4);
+
+    HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
+    HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));

     v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
     v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
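Here a single unaligned load of all four scales plus hvx_vec_repl_2x_f16 replaces the vsetq/vsplat/vsplat/vmux sequence per scale pair; rotating by 4 bytes brings scales_4[2] and scales_4[3] into lanes 0 and 1 for the second pair. A scalar model of what the 2x helper produces (hypothetical name; lanes are fp16 bit patterns):

#include <stdint.h>

// First 64 bytes (32 fp16 lanes) take lane 0, last 64 bytes take lane 1.
static void repl_2x_f16_model(const uint16_t v[64], uint16_t out[64]) {
    for (int i = 0; i < 32; ++i) {
        out[i]      = v[0]; // first half  <- f16[0]
        out[32 + i] = v[1]; // second half <- f16[1]
    }
}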
@@ -237,10 +236,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(

 // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
 static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
-    HVX_Vector vq = hvx_vmemu(quants_32);
-    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
-    HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
-    HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
+    HVX_Vector vq = hvx_vmemu(quants_32);
+    HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
+    HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
+    HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
     return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
 }
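The Q8_0 path gets the same treatment: the scale broadcast moves from a scalar load plus splat to a vector load plus vdelta replicate, while the quant handling is unchanged. A rough scalar model of the whole group dequant (hypothetical name; float stands in for the fp16/qf16 arithmetic purely for readability):

#include <stdint.h>

// 32 int8 quants -> 32 scaled values: widen, convert, multiply by the scale.
static void dequant_q8_0_model(const int8_t quants_32[32], float scale, float out[32]) {
    for (int i = 0; i < 32; ++i) {
        out[i] = (float) quants_32[i] * scale; // Q6_Vhf_equals_Vh + Q6_Vqf16_vmpy_VhfVhf
    }
}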
@@ -521,12 +520,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
         const uint8_t *r0 = vtcm_src + row0 * row_stride;
         const uint8_t *r1 = vtcm_src + row1 * row_stride;

-        HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx(
-            (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
-        HVX_Vector v1 = (row1 < n_cols)
-                            ? dequantize_x4x2_q8_0_group_hvx(
-                                  (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off))
-                            : Q6_V_vzero();
+        HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
+        HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();

         Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
         v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
@@ -77,16 +77,18 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
         const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);

-        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
-        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
-        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
+        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
+        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+
+        assert(hex_is_aligned(p0, 128));
+        assert(hex_is_aligned(p1, 128));
+        assert(c_byte_step % 128 == 0);

         if (p1) {
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
-                HVX_Vector v1 = hvx_vmemu(p1);
-                p1 += c_byte_step;
+                HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
+                HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
                 tile_base += dst_step;
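Switching this loop from hvx_vmemu to hvx_vmem trades the unaligned-capable load for the plain aligned one, and the new asserts document the contract that makes this safe: row pointers and c_byte_step must be vector-length multiples. A sketch of that contract (hypothetical helper; presumably what hex_is_aligned(p, 128) checks):

#include <stdbool.h>
#include <stdint.h>

// hvx_vmem requires 128-byte (vector-length) aligned addresses.
static bool is_vec_aligned(const void * p) {
    return ((uintptr_t) p & 127u) == 0;
}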
@@ -94,8 +96,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         } else {
             const HVX_Vector vzero = Q6_V_vzero();
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
+                HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
                 tile_base += dst_step;
@@ -116,16 +117,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
         const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);

-        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
-        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
-        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
+        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
+        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;

         if (p1) {
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
-                HVX_Vector v1 = hvx_vmemu(p1);
-                p1 += c_byte_step;
+                HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
+                HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
                 tile_base += dst_step;
@@ -133,8 +132,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         } else {
             const HVX_Vector vzero = Q6_V_vzero();
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
+                HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
                 tile_base += dst_step;
ggml/src/ggml-hexagon/htp/hvx-repl.h (new file, 74 lines)
@@ -0,0 +1,74 @@
+#ifndef HVX_REPL_H
+#define HVX_REPL_H
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+
+static inline HVX_Vector hvx_vec_repl(HVX_Vector v, const uint8_t * ctrl) {
+    return Q6_V_vdelta_VV(v, hvx_vmem(ctrl));
+}
+
+static inline HVX_Vector hvx_vec_repl_u32(HVX_Vector v) {
+    // vdelta control to replicate first 4 bytes across all lanes
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+    };
+    return hvx_vec_repl(v, repl);
+}
+
+static inline HVX_Vector hvx_vec_repl_f32(HVX_Vector v) {
+    // vdelta control to replicate first 4 bytes across all lanes
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+    };
+    return hvx_vec_repl(v, repl);
+}
+
+static inline HVX_Vector hvx_vec_repl_f16(HVX_Vector v) {
+    // vdelta control to replicate first two bytes across all lanes
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+    };
+    return hvx_vec_repl(v, repl);
+}
+
+static inline HVX_Vector hvx_vec_repl_2x_f16(HVX_Vector v) {
+    // vdelta control to splat a pair of f16s: first half = f16[0], second half = f16[1]
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
+        0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
+        0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
+        0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
+    };
+    return hvx_vec_repl(v, repl);
+}
+
+#endif // HVX_REPL_H
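All four helpers are a single Q6_V_vdelta_VV byte permutation driven by a precomputed, 128-byte-aligned control table, so one aligned table load replaces each scalar-load-plus-splat; the u32 and f32 variants share the same table because the shuffle operates on bytes, not typed lanes. A scalar reference for what the control patterns produce (hypothetical names; the buffers model one 128-byte HVX vector, and the 2x pair variant was modeled earlier):

#include <stdint.h>
#include <string.h>

// hvx_vec_repl_u32 / hvx_vec_repl_f32: broadcast the first 4 bytes.
static void repl_4byte_model(const uint8_t v[128], uint8_t out[128]) {
    for (int i = 0; i < 128; i += 4) memcpy(out + i, v, 4);
}

// hvx_vec_repl_f16: broadcast the first 2 bytes.
static void repl_f16_model(const uint8_t v[128], uint8_t out[128]) {
    for (int i = 0; i < 128; i += 2) memcpy(out + i, v, 2);
}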
@@ -5,6 +5,7 @@

 #include "hvx-types.h"
 #include "hvx-copy.h"
+#include "hvx-repl.h"
 #include "hvx-scale.h"
 #include "hvx-exp.h"
 #include "hvx-inverse.h"
@@ -70,5 +70,5 @@ adb $adbserial $adbhost shell " \
     ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
         --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
         --ctx-size 8192 --ubatch-size 256 -fa on \
-        -ngl 99 -no-cnv --device $device $cli_opts $@ \
+        -ngl 99 --device $device $cli_opts $@ \
 "