hex-mm: hmx optimize scatter/transpose and use HMX intrinsics
commit f8737609a5 (parent f14e9c6f4c), committed by Max Krasnyansky
@@ -49,7 +49,8 @@ static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = {
0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128,
8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128,
24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128
};

// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
@@ -253,7 +254,7 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_scales = hvx_vec_splat_f16(*scale);
// q4x4x2 stores two int4 values per byte. Keep only the selected nibble.
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
// Shuffle before LUT
v_quants = Q6_Vb_vshuff_Vb(v_quants);
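The old/new pair above replaces a conditional nibble select with a single shift. A scalar sketch of the same idea (plain C, not the HVX path; upper_nibbles is assumed to be 0 or 1, as in the kernel):

#include <stdint.h>

// Scalar model of the change: shift by 4 * upper_nibbles (0 or 4) and mask,
// which maps onto Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles) followed by the
// Q6_V_vand_VV with 0x0F above.
static inline uint8_t select_nibble(uint8_t packed, int upper_nibbles) {
    return (uint8_t)((packed >> (4 * upper_nibbles)) & 0x0F);
}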
@@ -277,7 +278,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
// Load all 128 packed bytes (4 contiguous 32-byte groups)
HVX_Vector vq = hvx_vmemu(packed_128);
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles);
v_quants = Q6_V_vand_VV(v_quants, mask_h4);

// Shuffle before LUT
@@ -297,10 +298,8 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx(
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));

// Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter
out[0] = v_lo; // group0 already in [0:63]
out[1] = Q6_V_vror_VR(v_lo, 64); // group1 rotated to [0:63]
out[2] = v_hi; // group2 already in [0:63]
out[3] = Q6_V_vror_VR(v_hi, 64); // group3 rotated to [0:63]
out[0] = v_lo; // group0 already in [0:63]
out[1] = v_hi; // group2 already in [0:63]
}

// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes.
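A reading of this change, inferred from the widened scatter later in the patch rather than stated in it: each 128-byte output vector already holds two 64-byte groups, so the rotated copies no longer need to be materialized.

// Layout of the two remaining outputs (bytes within each 128-byte HVX vector):
//   out[0] = [ group0 : bytes 0..63 | group1 : bytes 64..127 ]
//   out[1] = [ group2 : bytes 0..63 | group3 : bytes 64..127 ]
// The unmasked scatter over a 2 * HMX_FP16_TILE_SIZE region in the @@ -418
// hunk below addresses both halves of each vector directly.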
@@ -404,8 +403,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
size_t row_stride, int weight_type,
int start_tile, int end_tile) {

const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);
const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS;
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block;

const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
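For reference, the qrow_size selection before and after this hunk, read directly off the two versions of the line:

// old: weight_type == HTP_TYPE_Q8_0          -> k_block,     otherwise -> k_block / 2
// new: weight_type is Q4_0 or IQ4_NL (is_q4) -> k_block / 2, otherwise -> k_block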
@@ -418,47 +418,46 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes)

for (int t = start_tile; t < end_tile; ) {
int ct = t / n_k_tiles; // column tile index
int kt = t % n_k_tiles; // K tile index
unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index
unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index
for (unsigned t = start_tile; t < end_tile; ) {
if (kt >= n_k_tiles) { kt = 0; ct++; }
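This hunk moves the divide and modulo out of the loop: (ct, kt) are derived once from start_tile and then carried forward as t advances. A scalar sketch of that bookkeeping (a hypothetical stand-alone helper; in the kernel t and kt are bumped by 1 in the fallback path and by 4 in the batch-4 fast path):

// Derive (column tile, K tile) once, then keep them in step with t so the
// loop body needs no division or modulo per tile.
static void walk_tiles(int start_tile, int end_tile, int n_k_tiles) {
    unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index
    unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index
    for (unsigned t = (unsigned)start_tile; t < (unsigned)end_tile; ++t, ++kt) {
        if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; }
        /* process tile t at (column tile ct, K tile kt) */
    }
}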
// --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
((t + 3) / n_k_tiles == ct)) {
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
bool upper = (sub_blk_base >= 4);
int packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
bool upper = (sub_blk_base >= 4);
unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE
+ sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales
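A worked example of the fast-path offsets, assuming QK_Q4_0x4x2 == 256 (which the "128 contiguous packed bytes" comment implies; HMX_X4X2_DBLK_SIZE is left symbolic):

// For kt = 12:
//   kt * 32      = 384
//   blk_idx      = 384 / 256        = 1
//   sub_blk_base = (384 % 256) / 32 = 4   -> upper = true
//   packed_off   = 1 * (256 / 2)    = 128
//   scale_off    = qrow_size + 1 * HMX_X4X2_DBLK_SIZE + 4 * sizeof(__fp16)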
__fp16 *tile_bases[4];
for (int g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }
for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; }

HVX_Vector v_off = v_scat_base;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;
const uint8_t *r0 = vtcm_src + row0 * row_stride;
const uint8_t *r1 = vtcm_src + row1 * row_stride;

HVX_Vector v0[4], v1[4];
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;

for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
HVX_Vector v0[2];
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
if (row1 < n_cols) {
dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt, v1);
} else {
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
}

for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); }
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); }
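The scatter change appears to work as follows (an interpretation based on the doubled region length and the offset table added at the top of the patch, presumably loaded into v_scat_base, not on any patch description): the offsets reach up to 31*128 = 3968 bytes, so one unmasked Q6_vscatter_RMVwV over a 2 * HMX_FP16_TILE_SIZE - 1 region lets a single 128-byte vector land in two adjacent tiles, replacing four masked scatters per row with two. A scalar model of the word scatter being relied on (an approximation of the intrinsic's behaviour, not the architecture manual):

#include <stdint.h>
#include <stddef.h>

// Approximate model of Q6_vscatter_RMVwV as used above: each of the 32 word
// lanes stores values[i] at base + offsets[i]; limit is the region size - 1
// that the offsets are expected to stay within.
static void scatter_words_model(uint8_t *base, size_t limit,
                                const int32_t *offsets, const uint32_t *values) {
    for (int i = 0; i < 32; ++i) {
        if ((size_t)offsets[i] <= limit) {
            *(uint32_t *)(base + offsets[i]) = values[i];
        }
    }
}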

r0 = vtcm_src + row_offset; row_offset += row_stride;
dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0);
Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]);
Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]);
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
}

for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); }

t += 4;
t += 4; kt += 4;
continue;
}
@@ -515,20 +514,19 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
// --- Single-tile fallback ---
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;

if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
bool upper = (sub_blk >= 4);
int byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);
if (is_q4) {
unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2;
unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
bool upper = (sub_blk >= 4);
unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16);

HVX_Vector v_off = v_scat_base; // reset to column 0
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
int row1 = row0 + 1;

const uint8_t *r0 = vtcm_src + row0 * row_stride;
const uint8_t *r1 = vtcm_src + row1 * row_stride;
unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride;
unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1;
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) {
const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride;
const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride;

HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(
r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt);
@@ -605,7 +603,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
}
(void) *(volatile HVX_Vector *)(tile_base);
}
++t;
++t; ++kt;
}

// Drain HVX scatter write buffer: a vmem load on the same HW thread retires
@@ -673,9 +671,13 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles(
// --- End x4x2 dequantizers ---

// requires external HMX lock
static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales,
static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales,
int n_row_tiles, int n_col_tiles, int n_dot_tiles) {
hmx_set_output_scales(scales);
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);

Q6_bias_mxmem2_A((void *)scales);

for (int r = 0; r < n_row_tiles; ++r) {
for (int c = 0; c < n_col_tiles; ++c) {
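The __builtin_assume calls added here tell the compiler the tile counts are positive, so it can drop zero-trip checks on the loop nests. A minimal stand-alone illustration of the pattern (clang builtin, unrelated to the kernel itself):

// With the assume in place, the compiler may omit the n <= 0 early-exit path.
static int sum_first_n(const int *v, int n) {
    __builtin_assume(n > 0);
    int s = 0;
    for (int i = 0; i < n; ++i) s += v[i];
    return s;
}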
@@ -685,12 +687,14 @@ static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const
const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS;

for (int k = 0; k < n_dot_tiles; ++k) {
int offset = k * HMX_FP16_TILE_N_ELMS;
hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset);
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}

__fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS;
hmx_consume_accumulator_fp16(out_tile);
Q6_mxmem_AR_after_hf(out_tile, 0);
}
}
}
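The literal 2047 passed to the intrinsics matches the tile geometry spelled out in the helper comments removed further down in this patch; the arithmetic, for reference:

// limit = n_tiles * HMX_FP16_TILE_SIZE - 1 with n_tiles = 1 gives Rt = 2047, so
//   HMX_FP16_TILE_SIZE   = 2048 bytes (one tile, 16 HVX vectors)
//   HMX_FP16_TILE_N_ELMS = 2048 / sizeof(__fp16) = 1024 fp16 elements
// i.e. each mxmem access window is confined to exactly one tile.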
@@ -1510,10 +1514,13 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
}

// C += AB
void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile,
void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile,
int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) {
__builtin_assume(n_row_tiles > 0);
__builtin_assume(n_col_tiles > 0);
__builtin_assume(n_dot_tiles > 0);

hmx_set_output_scales(col_scales);
Q6_bias_mxmem2_A((void *)col_scales);

for (int i = 0; i < n_row_tiles; ++i) {
for (int j = 0; j < n_col_tiles; ++j) {
@@ -1524,15 +1531,17 @@ void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp

__fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS;
if (!zero_init) {
hmx_load_tile_pair_fp16(accum_tile, eye_tile);
Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047);
}

for (int k = 0; k < n_dot_tiles; ++k) {
int offset = k * HMX_FP16_TILE_N_ELMS;
hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset);
Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047);
Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047);
row_tiles += HMX_FP16_TILE_N_ELMS;
col_tiles += HMX_FP16_TILE_N_ELMS;
}

hmx_consume_accumulator_fp16(accum_tile);
Q6_mxmem_AR_after_hf(accum_tile, 0);
}
}
}
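The eye_tile preload is what makes this an accumulate: when zero_init is false, the existing C tile is fed through the engine multiplied by an identity tile, which reloads C into the HMX accumulator before the A·B products are added on top. A scalar sketch of the per-output-tile result, with that preload folded into the initial value (element layout abstracted away; T stands for the tile dimension and is an assumption of this sketch):

// c is a T x T output tile, a is T x (n_dot_tiles * T), b is (n_dot_tiles * T) x T.
static void mma_tile_reference(float *c, const float *a, const float *b,
                               int T, int n_dot_tiles, bool zero_init) {
    for (int i = 0; i < T; ++i) {
        for (int j = 0; j < T; ++j) {
            // "C * I" preload: start from the existing value unless zero_init.
            float acc = zero_init ? 0.0f : c[i * T + j];
            for (int k = 0; k < n_dot_tiles * T; ++k) {
                acc += a[i * (n_dot_tiles * T) + k] * b[k * T + j];
            }
            c[i * T + j] = acc;
        }
    }
}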
@@ -14,10 +14,6 @@

#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline))

static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) {
asm volatile("bias = mxmem2(%0)" :: "r"(scales));
}

// Initialise aligned 256-byte area with scale vector + zero padding.
static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) {
HVX_Vector *pv = (HVX_Vector *)out_scales;
@@ -25,58 +21,6 @@ static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vecto
*pv = Q6_V_vzero();
}

// Load multiple contiguous tiles with :deep streaming.
// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt].
// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank
// boundary, otherwise the mxmem instruction will raise a precise bus error.
// Callers must ensure their VTCM layout satisfies this constraint.
static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1;
asm volatile(
"{ activation.hf = mxmem(%0, %1):deep\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit)
: "memory");
}

// Load a single activation+weight tile pair (no :deep streaming).
// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula
// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047.
// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation
// places a tile near a 4 MB bank boundary, the oversized region crosses it and
// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly
// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047).
static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile,
const __fp16 *wt_tile) {
asm volatile(
"{ activation.hf = mxmem(%0, %1)\n"
"weight.hf = mxmem(%2, %3) }\n"
:: "r"(act_tile), "r"(2047),
"r"(wt_tile), "r"(2047)
: "memory");
}

static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) {
// Use the combined convert-and-store instruction (matches the reference
// Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence
// "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter.
asm volatile(
"mxmem(%0, %1):after.hf = acc\n"
:: "r"(out), "r"(0)
: "memory");
}

// Compute inner product of two vectors of tiles and store result.
static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out,
const __fp16 *row_tiles,
const __fp16 *col_tiles,
size_t n_tiles) {
hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles);
hmx_consume_accumulator_fp16(out);
}

// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) ---

static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) {
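The inline-asm helpers deleted in this hunk map one-for-one onto the HMX intrinsics introduced in the kernels earlier in the patch:

// hmx_set_output_scales(s)          -> Q6_bias_mxmem2_A((void *)s);
// hmx_load_tile_pair_fp16(a, w)     -> Q6_activation_hf_mxmem_RR((unsigned int)a, 2047);
//                                      Q6_weight_hf_mxmem_RR((unsigned int)w, 2047);
// hmx_consume_accumulator_fp16(out) -> Q6_mxmem_AR_after_hf(out, 0);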
@@ -116,9 +116,14 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) {
}

static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) {
#if __HVX_ARCH__ >= 81
HVX_Vector q0 = Q6_Vqf32_equals_Vsf(v0);
HVX_Vector q1 = Q6_Vqf32_equals_Vsf(v1);
#else
const HVX_Vector zero = Q6_V_vzero();
HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero);
HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero);
#endif
return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0));
}
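Both branches produce the same qf32 values: on HVX architectures without the direct sf-to-qf32 convert, adding zero in qf32 arithmetic performs the conversion, since x + 0 is exact. A scalar approximation of the function's overall effect (the exact lane interleaving of the HVX combine/convert is an assumption here, hence the "_shuff" in the name):

#include <stddef.h>

// 2*n floats in (two source vectors), 2*n halfs out, interleaved.
static void f32_pairs_to_f16_model(const float *v0, const float *v1,
                                   __fp16 *out, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        out[2 * i]     = (__fp16)v0[i];
        out[2 * i + 1] = (__fp16)v1[i];
    }
}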