diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c
index 8a6d7c14ed..4a4ff0b331 100644
--- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c
+++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c
@@ -760,8 +760,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
     // ALiBi slopes — only needed when has_alibi (scheme A)
     HVX_Vector v_slope0, v_slope1;
     if (args->has_alibi) {
-        v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]);
-        v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero();
+        HVX_Vector v_s = hvx_vmemu(args->slopes + r);
+        v_slope0 = hvx_vec_repl_f16(v_s);
+        v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_repl_f16(Q6_V_vror_VR(v_s, 2)) : Q6_V_vzero();
     }
 
     const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c)
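Note on the hunk above: the two per-row scalar splats become one unaligned vector load of the slope array plus in-register replication. Below is a minimal scalar model of the two primitives this relies on, assuming 128-byte HVX vectors; the *_ref names are illustrative, not part of the patch:

    /* Scalar sketch of the new slope-splat path (reference only; the real
     * code runs on HVX). A vector is modeled as 128 bytes = 64 fp16 lanes. */
    #include <stdint.h>

    #define VLEN 128

    /* Q6_V_vror_VR(v, n): rotate right by n bytes, out[i] = in[(i + n) % VLEN].
     * n = 2 moves fp16 element 1 down into element 0. */
    static void vror_ref(uint8_t out[VLEN], const uint8_t in[VLEN], int n) {
        for (int i = 0; i < VLEN; ++i) {
            out[i] = in[(i + n) % VLEN];
        }
    }

    /* hvx_vec_repl_f16: broadcast the first fp16 (bytes 0..1) to all 64 lanes. */
    static void repl_f16_ref(uint8_t out[VLEN], const uint8_t in[VLEN]) {
        for (int i = 0; i < VLEN; i += 2) {
            out[i + 0] = in[0];
            out[i + 1] = in[1];
        }
    }

So v_slope0 broadcasts slopes[r], while v_slope1 first rotates the loaded vector by 2 bytes so that slopes[r + 1] sits in lane 0 before the same broadcast.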
diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
index 9e8c9966e0..e05ccfd5fc 100644
--- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
+++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c
@@ -180,12 +180,10 @@ next_nc:
 // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes.
 // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles
 // of the same 32 packed bytes.
-static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
-    const uint8_t *packed_32, bool upper_nibbles,
-    const __fp16 *scale, const HVX_Vector vlut_cvt) {
+static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) {
     HVX_Vector vq = hvx_vmemu(packed_32);
     const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
-    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
+    HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
     // q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
     HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
     v_quants = Q6_V_vand_VV(v_quants, mask_h4);
@@ -223,9 +221,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
     HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16]
 
     // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b
-    HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
-    HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1]));
-    HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3]));
+    volatile HVX_Vector vscale = hvx_vmemu(scales_4);
+
+    HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale);
+    HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4));
 
     v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
     v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
@@ -237,10 +236,10 @@
 
 // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
 static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) {
-    HVX_Vector vq = hvx_vmemu(quants_32);
-    HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
-    HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
-    HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0);
+    HVX_Vector vq       = hvx_vmemu(quants_32);
+    HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale));
+    HVX_Vector v0       = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq));
+    HVX_Vector v_hf     = Q6_Vhf_equals_Vh(v0);
     return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
 }
 
@@ -521,12 +520,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
         const uint8_t *r0 = vtcm_src + row0 * row_stride;
         const uint8_t *r1 = vtcm_src + row1 * row_stride;
 
-        HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx(
-            (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
-        HVX_Vector v1 = (row1 < n_cols)
-            ? dequantize_x4x2_q8_0_group_hvx(
-                  (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off))
-            : Q6_V_vzero();
+        HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off));
+        HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero();
 
         Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
         v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
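The dequantizers above now broadcast the per-group scale with hvx_vec_repl_f16(hvx_vmemu(scale)) instead of hvx_vec_splat_f16(*scale): the fp16 scale is loaded straight into a vector register and replicated there, rather than being read through a scalar register first. For reference, a scalar sketch of what one x4x2 group computes, assuming the standard ggml Q4_0/Q8_0 mapping (value = (nibble - 8) * scale for Q4_0), which the vlut_cvt table is presumed to encode; names are illustrative:

    #include <stdint.h>

    /* Scalar reference for one x4x2 Q4_0 group: 32 packed bytes hold two
     * groups of 32 int4 values; sub-blocks 0..3 live in the low nibbles,
     * sub-blocks 4..7 in the high nibbles of the same bytes. */
    static void dequant_q4_0_group_ref(float out[32], const uint8_t packed_32[32],
                                       int upper_nibbles, float scale) {
        for (int i = 0; i < 32; ++i) {
            int q  = upper_nibbles ? (packed_32[i] >> 4) : (packed_32[i] & 0x0F);
            out[i] = scale * (float) (q - 8);
        }
    }

    /* Scalar reference for one x4x2 Q8_0 group: 32 signed bytes times scale. */
    static void dequant_q8_0_group_ref(float out[32], const int8_t quants_32[32],
                                       float scale) {
        for (int i = 0; i < 32; ++i) {
            out[i] = scale * (float) quants_32[i];
        }
    }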
diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h
index 68f174d693..f448ee3372 100644
--- a/ggml/src/ggml-hexagon/htp/hmx-utils.h
+++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h
@@ -77,16 +77,18 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
         const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
 
-        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
-        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
-        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+        __fp16 *        tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
+        const uint8_t * p0        = (const uint8_t *) (vtcm_src + r * src_stride);
+        const uint8_t * p1        = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+
+        assert(hex_is_aligned(p0, 128));
+        assert(hex_is_aligned(p1, 128));
+        assert(c_byte_step % 128 == 0);
 
         if (p1) {
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
-                HVX_Vector v1 = hvx_vmemu(p1);
-                p1 += c_byte_step;
+                HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
+                HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1);
                 tile_base += dst_step;
@@ -94,8 +96,7 @@
         } else {
             const HVX_Vector vzero = Q6_V_vzero();
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
+                HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
                 Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
                 tile_base += dst_step;
@@ -116,16 +117,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
         const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
         const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
 
-        __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
-        const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride);
-        const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
+        __fp16 *        tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS;
+        const uint8_t * p0        = (const uint8_t *) (vtcm_src + r * src_stride);
+        const uint8_t * p1        = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
 
         if (p1) {
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
-                HVX_Vector v1 = hvx_vmemu(p1);
-                p1 += c_byte_step;
+                HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
+                HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1);
                 tile_base += dst_step;
@@ -133,8 +132,7 @@
         } else {
            const HVX_Vector vzero = Q6_V_vzero();
             for (int i = 0; i < n_c_iters; ++i) {
-                HVX_Vector v0 = hvx_vmemu(p0);
-                p0 += c_byte_step;
+                HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
                 Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
                 tile_base += dst_step;
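The new asserts in the first (pair_region) variant document the contract that lets its loads switch from hvx_vmemu to hvx_vmem: with both row pointers 128-byte aligned and c_byte_step a multiple of 128, every address touched in the loop stays vector-aligned, so each iteration is a single aligned vmem rather than an unaligned access. The second (single_region) variant keeps hvx_vmemu. A sketch of the presumed semantics of the check, assuming hex_is_aligned is the usual power-of-two mask test (note a NULL p1 passes it trivially):

    #include <stdbool.h>
    #include <stdint.h>

    /* Assumed model of hex_is_aligned; align must be a power of two. */
    static bool hex_is_aligned_ref(const void * p, uintptr_t align) {
        return ((uintptr_t) p & (align - 1)) == 0;
    }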
diff --git a/ggml/src/ggml-hexagon/htp/hvx-repl.h b/ggml/src/ggml-hexagon/htp/hvx-repl.h
new file mode 100644
index 0000000000..fdc7e6c7d2
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-repl.h
@@ -0,0 +1,74 @@
+#ifndef HVX_REPL_H
+#define HVX_REPL_H
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "hvx-base.h"
+
+static inline HVX_Vector hvx_vec_repl(HVX_Vector v, const uint8_t * ctrl) {
+    return Q6_V_vdelta_VV(v, hvx_vmem(ctrl));
+}
+
+static inline HVX_Vector hvx_vec_repl_u32(HVX_Vector v) {
+    // vdelta control to replicate the first 4 bytes across all lanes
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+    };
+    return hvx_vec_repl(v, repl);
+}
+
+static inline HVX_Vector hvx_vec_repl_f32(HVX_Vector v) {
+    // vdelta control to replicate the first 4 bytes across all lanes
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+        0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
+    };
+    return hvx_vec_repl(v, repl);
+}
+
+static inline HVX_Vector hvx_vec_repl_f16(HVX_Vector v) {
+    // vdelta control to replicate the first two bytes across all lanes
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+    };
+    return hvx_vec_repl(v, repl);
+}
+
+static inline HVX_Vector hvx_vec_repl_2x_f16(HVX_Vector v) {
+    // vdelta control to splat a pair of f16s: first half = f16[0], second half = f16[1]
+    static const uint8_t __attribute__((aligned(128))) repl[128] = {
+        0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+        0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
+        0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
+        0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
+        0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04,
+    };
+    return hvx_vec_repl(v, repl);
+}
+
+#endif // HVX_REPL_H
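Every helper in the new header is one vdelta permute with a precomputed 128-byte control vector. Below is a scalar model of Q6_V_vdelta_VV that is consistent with the tables above (a sketch based on the butterfly-network description of the instruction; the HVX reference manual is authoritative): the network runs log2(128) = 7 stages with spans 64 down to 1, and at span s output lane j takes lane j ^ s whenever bit s is set in lane j's control byte.

    #include <stdint.h>
    #include <string.h>

    #define VLEN 128

    /* Scalar butterfly model of Q6_V_vdelta_VV (sketch, see note above). */
    static void vdelta_ref(uint8_t out[VLEN], const uint8_t in[VLEN],
                           const uint8_t ctrl[VLEN]) {
        uint8_t cur[VLEN], nxt[VLEN];
        memcpy(cur, in, VLEN);
        for (int s = VLEN / 2; s >= 1; s >>= 1) {
            for (int j = 0; j < VLEN; ++j) {
                nxt[j] = (ctrl[j] & s) ? cur[j ^ s] : cur[j];
            }
            memcpy(cur, nxt, VLEN);
        }
        memcpy(out, cur, VLEN);
    }

Under this model the hvx_vec_repl_f16 table routes bytes 0..1 into every fp16 lane, and the hvx_vec_repl_2x_f16 table routes bytes 0..1 into lanes 0..31 and bytes 2..3 into lanes 32..63, matching the comments in the header.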
diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h
index a518ad3733..e0452811ec 100644
--- a/ggml/src/ggml-hexagon/htp/hvx-utils.h
+++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h
@@ -5,6 +5,7 @@
 
 #include "hvx-types.h"
 #include "hvx-copy.h"
+#include "hvx-repl.h"
 #include "hvx-scale.h"
 #include "hvx-exp.h"
 #include "hvx-inverse.h"
diff --git a/scripts/snapdragon/adb/run-completion.sh b/scripts/snapdragon/adb/run-completion.sh
index 638823b882..1b904c77b3 100755
--- a/scripts/snapdragon/adb/run-completion.sh
+++ b/scripts/snapdragon/adb/run-completion.sh
@@ -70,5 +70,5 @@ adb $adbserial $adbhost shell " \
     ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
         --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
         --ctx-size 8192 --ubatch-size 256 -fa on \
-        -ngl 99 -no-cnv --device $device $cli_opts $@ \
+        -ngl 99 --device $device $cli_opts $@ \
 "