mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-16 14:04:05 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff934e29bc | ||
|
|
ee051c1e4e | ||
|
|
e6f6770515 | ||
|
|
48cda24c11 | ||
|
|
871f1a2d2f | ||
|
|
20197b6fe3 | ||
|
|
ba38f3becc |
@@ -41,6 +41,7 @@
|
||||
effectiveStdenv ? if useCuda then cudaPackages.backendStdenv else stdenv,
|
||||
enableStatic ? effectiveStdenv.hostPlatform.isStatic,
|
||||
precompileMetalShaders ? false,
|
||||
useWebUi ? true,
|
||||
}:
|
||||
|
||||
let
|
||||
@@ -164,6 +165,7 @@ effectiveStdenv.mkDerivation (finalAttrs: {
|
||||
cmakeFlags =
|
||||
[
|
||||
(cmakeBool "LLAMA_BUILD_SERVER" true)
|
||||
(cmakeBool "LLAMA_BUILD_WEBUI" useWebUi)
|
||||
(cmakeBool "BUILD_SHARED_LIBS" (!enableStatic))
|
||||
(cmakeBool "CMAKE_SKIP_BUILD_RPATH" true)
|
||||
(cmakeBool "GGML_NATIVE" false)
|
||||
|
||||
@@ -108,6 +108,7 @@ option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_WEBUI "llama: build the embedded Web UI for server" ON)
|
||||
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
|
||||
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
|
||||
|
||||
|
||||
@@ -1079,7 +1079,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.verbose_prompt = true;
|
||||
}
|
||||
));
|
||||
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL}));
|
||||
add_opt(common_arg(
|
||||
{"--display-prompt"},
|
||||
{"--no-display-prompt"},
|
||||
@@ -2843,6 +2843,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.webui_mcp_proxy = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY"));
|
||||
add_opt(common_arg(
|
||||
{"--tools"}, "TOOL1,TOOL2,...",
|
||||
"experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)\n"
|
||||
"specify \"all\" to enable all tools\n"
|
||||
"available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_tools = parse_csv_row(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS"));
|
||||
add_opt(common_arg(
|
||||
{"--webui"},
|
||||
{"--no-webui"},
|
||||
|
||||
@@ -613,6 +613,9 @@ struct common_params {
|
||||
bool endpoint_props = false; // only control POST requests, not GET
|
||||
bool endpoint_metrics = false;
|
||||
|
||||
// enable built-in tools
|
||||
std::vector<std::string> server_tools;
|
||||
|
||||
// router server configs
|
||||
std::string models_dir = ""; // directory containing models for the router server
|
||||
std::string models_preset = ""; // directory containing model presets for the router server
|
||||
|
||||
@@ -1406,6 +1406,13 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
||||
repack_q8_0_q8x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
// IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16])
|
||||
repack_q4_0_q4x4x2(tensor, data, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
@@ -1442,6 +1449,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
||||
repack_q8x4x2_q8_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
repack_q4x4x2_q4_0(data, tensor, size);
|
||||
break;
|
||||
|
||||
case GGML_TYPE_MXFP4:
|
||||
GGML_ASSERT(offset == 0);
|
||||
GGML_ASSERT(offset + size <= ggml_nbytes(tensor));
|
||||
@@ -1819,6 +1832,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
if (src0->ne[0] % 32) {
|
||||
return false;
|
||||
@@ -1868,6 +1882,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
if ((src0->ne[0] % 32)) {
|
||||
return false;
|
||||
@@ -2596,8 +2611,26 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
|
||||
delete backend;
|
||||
}
|
||||
|
||||
// Map weight type to its activation quantization family.
|
||||
// Types in the same family produce identical Q8 formats in VTCM and can
|
||||
// safely share quantized activation data via SKIP_QUANTIZE.
|
||||
// When adding a new quantized type, assign it the correct family here.
|
||||
static inline int act_quant_family(enum ggml_type wtype) {
|
||||
switch (wtype) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_MXFP4:
|
||||
return 1; // Q8x4x2
|
||||
default:
|
||||
return 0; // unknown / not quantized
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
|
||||
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
|
||||
return (op0 && op0->src[1] == op1->src[1] &&
|
||||
act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) &&
|
||||
act_quant_family(op0->src[0]->type) != 0);
|
||||
}
|
||||
|
||||
static inline bool is_compute_op(ggml_tensor *node)
|
||||
@@ -3364,6 +3397,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL,
|
||||
"please update hexagon_type to match ggml_type");
|
||||
|
||||
const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL");
|
||||
const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE");
|
||||
|
||||
@@ -30,6 +30,12 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
-8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0,
|
||||
};
|
||||
|
||||
// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value
|
||||
// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6
|
||||
static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0,
|
||||
};
|
||||
|
||||
static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = {
|
||||
-127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0,
|
||||
1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0,
|
||||
@@ -46,7 +52,8 @@ static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned
|
||||
|
||||
// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes
|
||||
#define HMX_X4X2_SCALES_PER_BLK 8
|
||||
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes
|
||||
#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL)
|
||||
#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4)
|
||||
|
||||
static inline void swap_ptr(void **p1, void **p2) {
|
||||
void *t = *p1;
|
||||
@@ -78,9 +85,11 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) {
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
|
||||
return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb
|
||||
case HTP_TYPE_Q8_0:
|
||||
return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
|
||||
return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb
|
||||
case HTP_TYPE_MXFP4:
|
||||
return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
@@ -284,6 +293,87 @@ static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales));
|
||||
}
|
||||
|
||||
// --- MXFP4 E8M0 scale conversion and dequantization ---
|
||||
//
|
||||
// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack.
|
||||
// Scalar loads from the stack array execute on the scalar pipeline, in parallel
|
||||
// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop.
|
||||
// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10
|
||||
// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15.
|
||||
|
||||
typedef struct {
|
||||
__fp16 v[8] __attribute__((aligned(16)));
|
||||
} mxfp4_scales_t;
|
||||
|
||||
static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) {
|
||||
mxfp4_scales_t s;
|
||||
HVX_Vector v = hvx_vmemu(e8m0_8);
|
||||
HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v));
|
||||
vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112));
|
||||
vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero());
|
||||
vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30));
|
||||
vh = Q6_Vh_vasl_VhR(vh, 10);
|
||||
hvx_vec_store_u(s.v, 16, vh);
|
||||
return s;
|
||||
}
|
||||
|
||||
static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) {
|
||||
return hvx_vec_splat_f16(scales.v[idx]);
|
||||
}
|
||||
|
||||
// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16.
|
||||
static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32,
|
||||
bool upper_nibbles,
|
||||
int sub_blk,
|
||||
const HVX_Vector vlut_cvt,
|
||||
mxfp4_scales_t scales) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_32);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_hf = Q6_V_lo_W(vp);
|
||||
|
||||
return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc));
|
||||
}
|
||||
|
||||
// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes).
|
||||
static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128,
|
||||
bool upper_nibbles,
|
||||
int sub_blk_base,
|
||||
const HVX_Vector vlut_cvt,
|
||||
mxfp4_scales_t scales,
|
||||
HVX_Vector out[4]) {
|
||||
HVX_Vector vq = hvx_vmemu(packed_128);
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq;
|
||||
v_quants = Q6_V_vand_VV(v_quants, mask_h4);
|
||||
|
||||
v_quants = Q6_Vb_vshuff_Vb(v_quants);
|
||||
|
||||
HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0);
|
||||
HVX_Vector v_lo = Q6_V_lo_W(vp);
|
||||
HVX_Vector v_hi = Q6_V_hi_W(vp);
|
||||
|
||||
HVX_VectorPred q64 = Q6_Q_vsetq_R(64);
|
||||
HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0),
|
||||
mxfp4_extract_splat(scales, sub_blk_base + 1));
|
||||
HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2),
|
||||
mxfp4_extract_splat(scales, sub_blk_base + 3));
|
||||
|
||||
v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01));
|
||||
v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23));
|
||||
|
||||
out[0] = v_lo;
|
||||
out[1] = Q6_V_vror_VR(v_lo, 64);
|
||||
out[2] = v_hi;
|
||||
out[3] = Q6_V_vror_VR(v_hi, 64);
|
||||
}
|
||||
|
||||
// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16.
|
||||
// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes.
|
||||
// Output: vtcm_dst in tile-major FP16 layout.
|
||||
@@ -295,11 +385,11 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
int start_tile, int end_tile) {
|
||||
|
||||
const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS;
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const int qrow_size = is_q4 ? (k_block / 2) : k_block;
|
||||
const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2);
|
||||
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL)
|
||||
? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut);
|
||||
const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) :
|
||||
(weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) :
|
||||
hvx_vmem(q4_0_to_fp16_lut);
|
||||
|
||||
// vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions.
|
||||
// Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128
|
||||
@@ -312,8 +402,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
int ct = t / n_k_tiles; // column tile index
|
||||
int kt = t % n_k_tiles; // K tile index
|
||||
|
||||
// --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
// --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row ---
|
||||
if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) &&
|
||||
((t + 3) / n_k_tiles == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
@@ -351,10 +442,60 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales ---
|
||||
if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) {
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4
|
||||
bool upper = (sub_blk_base >= 4);
|
||||
int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales
|
||||
|
||||
__fp16 * tile_bases[4];
|
||||
for (int g = 0; g < 4; g++) {
|
||||
tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0[4], v1[4];
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0);
|
||||
if (row1 < n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1);
|
||||
} else {
|
||||
v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero();
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
for (int g = 0; g < 4; g++) {
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]);
|
||||
}
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
|
||||
for (int g = 0; g < 4; g++) {
|
||||
(void) *(volatile HVX_Vector *) (tile_bases[g]);
|
||||
}
|
||||
|
||||
t += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
// --- Single-tile fallback ---
|
||||
__fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
if (is_q4) {
|
||||
if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) {
|
||||
int blk_idx = (kt * 32) / QK_Q4_0x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
@@ -382,6 +523,39 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task(
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
(void) *(volatile HVX_Vector *)(tile_base);
|
||||
} else if (weight_type == HTP_TYPE_MXFP4) {
|
||||
int blk_idx = (kt * 32) / QK_MXFP4x4x2;
|
||||
int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32;
|
||||
bool upper = (sub_blk >= 4);
|
||||
int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32;
|
||||
int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE;
|
||||
|
||||
HVX_Vector v_off = v_scat_base;
|
||||
for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) {
|
||||
int row0 = ct * HMX_FP16_TILE_N_COLS + r;
|
||||
int row1 = row0 + 1;
|
||||
|
||||
const uint8_t * r0 = vtcm_src + row0 * row_stride;
|
||||
const uint8_t * r1 = vtcm_src + row1 * row_stride;
|
||||
|
||||
// Batch-convert all 8 E8M0 scales once per row (stays in HVX register)
|
||||
mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off);
|
||||
|
||||
HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8);
|
||||
HVX_Vector v1;
|
||||
if (row1 < n_cols) {
|
||||
mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off);
|
||||
v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8);
|
||||
} else {
|
||||
v1 = Q6_V_vzero();
|
||||
}
|
||||
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1);
|
||||
v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step);
|
||||
}
|
||||
(void) *(volatile HVX_Vector *) (tile_base);
|
||||
} else {
|
||||
// Q8_0
|
||||
int blk_idx = (kt * 32) / QK_Q8_0x4x2;
|
||||
@@ -1455,21 +1629,24 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
|
||||
{
|
||||
qweight_fetch_task_state_t s;
|
||||
|
||||
const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL);
|
||||
const int blk_start = kk / QK_Q4_0x4x2;
|
||||
const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2;
|
||||
const int full_qrow = is_q4 ? (k / 2) : k;
|
||||
const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2);
|
||||
const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz);
|
||||
const int scale_blk_size =
|
||||
(weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE;
|
||||
|
||||
s.dst = vtcm_scratch0;
|
||||
s.src = w + nc * row_stride;
|
||||
s.n_rows = n_blk_sz;
|
||||
s.src_stride = row_stride;
|
||||
s.dst_stride = sub_row_stride;
|
||||
s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2);
|
||||
s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2);
|
||||
s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE;
|
||||
s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE;
|
||||
s.quant_off =
|
||||
(weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2));
|
||||
s.quant_width =
|
||||
(weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2));
|
||||
s.scale_off = full_qrow + blk_start * scale_blk_size;
|
||||
s.scale_width = nb_sub * scale_blk_size;
|
||||
|
||||
// 2D DMA: quants sub-range
|
||||
dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
|
||||
|
||||
@@ -31,6 +31,12 @@ struct htp_context {
|
||||
|
||||
uint32_t opmask;
|
||||
|
||||
// Cached src1 spad position from the last quantize pass.
|
||||
// When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM
|
||||
// at this address; the matmul must read from here instead of recomputing
|
||||
// the offset (which depends on the current op's src0 size).
|
||||
uint8_t * prev_src1_spad;
|
||||
|
||||
// HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX)
|
||||
#ifdef HTP_HAS_HMX
|
||||
int hmx_enabled; // Runtime flag: HMX initialisation succeeded
|
||||
|
||||
@@ -1114,14 +1114,12 @@ static void proc_hmx_matmul_req(struct htp_context * ctx,
|
||||
return;
|
||||
}
|
||||
|
||||
// HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights.
|
||||
// Other types (e.g. MXFP4) fall back to HVX.
|
||||
// HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights.
|
||||
// Other types fall back to HVX.
|
||||
{
|
||||
uint32_t wtype = req->src0.type;
|
||||
if (wtype != HTP_TYPE_F16 &&
|
||||
wtype != HTP_TYPE_Q4_0 &&
|
||||
wtype != HTP_TYPE_Q8_0 &&
|
||||
wtype != HTP_TYPE_IQ4_NL) {
|
||||
if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL &&
|
||||
wtype != HTP_TYPE_MXFP4) {
|
||||
proc_matmul_req(ctx, req, bufs, n_bufs);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -60,6 +60,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
|
||||
0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20,
|
||||
};
|
||||
|
||||
// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue
|
||||
// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
|
||||
static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = {
|
||||
0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0,
|
||||
0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
};
|
||||
|
||||
static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
|
||||
0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0,
|
||||
0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
@@ -68,6 +78,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = {
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
};
|
||||
|
||||
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) {
|
||||
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
||||
|
||||
HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes)
|
||||
HVX_Vector v2_3 = vptr[1]; // ...
|
||||
HVX_Vector v4_5 = vptr[2]; // ...
|
||||
HVX_Vector v6_7 = vptr[3]; // ...
|
||||
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
|
||||
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4
|
||||
HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F
|
||||
HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4
|
||||
HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F
|
||||
HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4
|
||||
HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F
|
||||
HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4
|
||||
|
||||
v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
||||
v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
||||
v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0);
|
||||
v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0);
|
||||
v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0);
|
||||
v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0);
|
||||
v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0);
|
||||
v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0);
|
||||
|
||||
HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 };
|
||||
return r;
|
||||
}
|
||||
|
||||
static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) {
|
||||
const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr;
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2; // 256
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F);
|
||||
const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut;
|
||||
|
||||
HVX_Vector_x8 r;
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nb; i++) {
|
||||
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements
|
||||
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0);
|
||||
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vptr[i]; // 256 elements (128 bytes)
|
||||
HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements
|
||||
HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements
|
||||
HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:...
|
||||
r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0);
|
||||
r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0);
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales
|
||||
|
||||
static inline size_t q8x4x2_row_size(uint32_t ne) {
|
||||
@@ -921,6 +998,293 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float *
|
||||
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
||||
}
|
||||
|
||||
// ======== IQ4_NL x Q8_0 vec_dot kernels ========
|
||||
// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue).
|
||||
// Scale format is identical to Q4_0 (fp16 scales).
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n,
|
||||
float * restrict s0,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vy0) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales
|
||||
|
||||
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
||||
|
||||
HVX_Vector r0_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
}
|
||||
|
||||
r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
|
||||
|
||||
hvx_vec_store_u(s0, 4, r0_sum);
|
||||
}
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n,
|
||||
float * restrict s0,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vx1,
|
||||
const void * restrict vy0) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vx1 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales
|
||||
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first
|
||||
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales
|
||||
|
||||
const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first
|
||||
const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
|
||||
|
||||
HVX_Vector r0_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q));
|
||||
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe));
|
||||
HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe));
|
||||
|
||||
HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d)));
|
||||
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_dd = Q6_V_vand_QV(bmask, r0_dd);
|
||||
r1_dd = Q6_V_vand_QV(bmask, r1_dd);
|
||||
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
|
||||
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
|
||||
|
||||
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
|
||||
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
|
||||
|
||||
r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
|
||||
r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
|
||||
}
|
||||
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
|
||||
hvx_vec_store_u(s0, 8, rsum);
|
||||
}
|
||||
|
||||
static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n,
|
||||
float * restrict s0,
|
||||
float * restrict s1,
|
||||
const void * restrict vx0,
|
||||
const void * restrict vx1,
|
||||
const void * restrict vy0,
|
||||
const void * restrict vy1) {
|
||||
assert(n % 32 == 0);
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
assert((unsigned long) vx1 % 128 == 0);
|
||||
assert((unsigned long) vy0 % 128 == 0);
|
||||
assert((unsigned long) vy1 % 128 == 0);
|
||||
|
||||
const uint32_t qk = QK_Q4_0x4x2 * 4;
|
||||
|
||||
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t x_qblk_size = qk / 2; // int4
|
||||
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
|
||||
|
||||
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
|
||||
const uint32_t y_qblk_size = qk; // int8
|
||||
const uint32_t y_qrow_size = n; // int8 (not padded)
|
||||
|
||||
const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0;
|
||||
const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size;
|
||||
const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0;
|
||||
const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size;
|
||||
|
||||
const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0;
|
||||
const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size;
|
||||
const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0;
|
||||
const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size;
|
||||
|
||||
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
||||
|
||||
const uint32_t nb = n / qk;
|
||||
const uint32_t nloe = n % qk;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < nb; i++) {
|
||||
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size);
|
||||
|
||||
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
|
||||
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
|
||||
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
|
||||
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
|
||||
|
||||
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
||||
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
||||
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
||||
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
||||
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
||||
|
||||
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
||||
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
||||
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
||||
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
||||
|
||||
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
||||
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
||||
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
||||
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe);
|
||||
HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe);
|
||||
HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe);
|
||||
|
||||
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe));
|
||||
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe));
|
||||
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe));
|
||||
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe));
|
||||
|
||||
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
|
||||
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
|
||||
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
|
||||
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
|
||||
|
||||
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
|
||||
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
|
||||
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
|
||||
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
|
||||
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
|
||||
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
|
||||
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
|
||||
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
|
||||
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
|
||||
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
|
||||
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
|
||||
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
|
||||
|
||||
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
|
||||
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
|
||||
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
|
||||
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
|
||||
|
||||
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
|
||||
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
|
||||
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
|
||||
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
|
||||
}
|
||||
|
||||
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
||||
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
||||
|
||||
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum);
|
||||
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum);
|
||||
}
|
||||
|
||||
static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) {
|
||||
assert(n % 32 == 0); // min sub-block size
|
||||
assert((unsigned long) vx0 % 128 == 0);
|
||||
@@ -2393,6 +2757,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t
|
||||
mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1;
|
||||
mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2;
|
||||
return 0;
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
mmctx->type = "iq4nlx4x2-f32";
|
||||
mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1;
|
||||
mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1;
|
||||
mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2;
|
||||
return 0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
mmctx->type = "mxfp4x4x2-f32";
|
||||
mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1;
|
||||
@@ -2556,6 +2926,13 @@ int op_matmul(struct htp_ops_context * octx) {
|
||||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
// Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
// SKIP_QUANTIZE: Q8 data lives at the address written by the previous
|
||||
// quantize pass. The current op may have a different src0 size (e.g.
|
||||
// IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong.
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
@@ -2659,6 +3036,9 @@ int op_matmul_id(struct htp_ops_context * octx) {
|
||||
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
|
||||
mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs);
|
||||
octx->ctx->prev_src1_spad = octx->src1_spad.data;
|
||||
} else {
|
||||
octx->src1_spad.data = octx->ctx->prev_src1_spad;
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
|
||||
@@ -589,8 +589,10 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
||||
ggml_backend_buffer_t buffer = tensor->buffer;
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
|
||||
result.data = reinterpret_cast<uint64_t>(tensor->data);
|
||||
} else {
|
||||
result.buffer = 0;
|
||||
result.data = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
result.ne[i] = tensor->ne[i];
|
||||
@@ -606,7 +608,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
||||
}
|
||||
result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
|
||||
result.view_offs = tensor->view_offs;
|
||||
result.data = reinterpret_cast<uint64_t>(tensor->data);
|
||||
|
||||
// Avoid sending uninitialized data over the wire
|
||||
memset(result.name, 0, sizeof(result.name));
|
||||
@@ -1443,9 +1444,11 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
|
||||
const rpc_tensor * tensor = it_ptr->second;
|
||||
|
||||
struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
|
||||
if (result == nullptr || result->buffer == nullptr) {
|
||||
GGML_LOG_ERROR("[%s] invalid tensor: null %s (id=%" PRIu64 ")\n",
|
||||
__func__, result == nullptr ? "tensor" : "buffer", id);
|
||||
if (result == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (result->buffer == nullptr && result->data != nullptr) {
|
||||
GGML_LOG_ERROR("[%s] invalid data ptr", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
tensor_map[id] = result;
|
||||
|
||||
@@ -1377,6 +1377,16 @@ struct clip_model_loader {
|
||||
|
||||
// sanity check
|
||||
{
|
||||
if (hparams.image_size < 0) {
|
||||
// note: some models having hparams.image_size == 0, which means the image size is dynamic
|
||||
throw std::runtime_error(string_format("%s: image_size (%d) cannot be negative\n", __func__, hparams.image_size));
|
||||
}
|
||||
if (hparams.patch_size <= 0) {
|
||||
throw std::runtime_error(string_format("%s: patch_size (%d) must be greater than 0\n", __func__, hparams.patch_size));
|
||||
}
|
||||
if (hparams.n_embd <= 0) {
|
||||
throw std::runtime_error(string_format("%s: n_embd (%d) must be greater than 0\n", __func__, hparams.n_embd));
|
||||
}
|
||||
if (hparams.image_max_pixels < hparams.image_min_pixels) {
|
||||
throw std::runtime_error(string_format("%s: image_max_pixels (%d) is less than image_min_pixels (%d)\n", __func__, hparams.image_max_pixels, hparams.image_min_pixels));
|
||||
}
|
||||
|
||||
@@ -13,23 +13,20 @@
|
||||
|
||||
constexpr bool DEBUG = false;
|
||||
|
||||
void mtmd_audio_cache::fill_sin_cos_table(int n) {
|
||||
void mtmd_audio_cache::fill_sin_cos_table(uint32_t n) {
|
||||
sin_vals.resize(n);
|
||||
cos_vals.resize(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (uint32_t i = 0; i < n; i++) {
|
||||
double theta = (2 * M_PI * i) / n;
|
||||
sin_vals[i] = sinf(theta);
|
||||
cos_vals[i] = cosf(theta);
|
||||
}
|
||||
}
|
||||
|
||||
void mtmd_audio_cache::fill_hann_window(int length, bool periodic) {
|
||||
void mtmd_audio_cache::fill_hann_window(uint32_t length, bool periodic) {
|
||||
hann_window.resize(length);
|
||||
int offset = -1;
|
||||
if (periodic) {
|
||||
offset = 0;
|
||||
}
|
||||
for (int i = 0; i < length; i++) {
|
||||
int offset = periodic ? 0 : -1;
|
||||
for (uint32_t i = 0; i < length; i++) {
|
||||
hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
|
||||
}
|
||||
}
|
||||
@@ -165,6 +162,7 @@ static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, fl
|
||||
// false = input is complex-valued (interleaved real/imag, stride 2)
|
||||
template <bool Inverse, bool RealInput>
|
||||
static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) {
|
||||
GGML_ASSERT(N > 0);
|
||||
const int n_sin_cos_vals = cache.sin_vals.size();
|
||||
|
||||
if (N == 1) {
|
||||
@@ -407,6 +405,8 @@ static bool log_mel_spectrogram(
|
||||
}
|
||||
|
||||
|
||||
GGML_ASSERT(params.n_fft_bins > 0);
|
||||
GGML_ASSERT(params.hop_length > 0);
|
||||
out.n_mel = params.n_mel;
|
||||
out.n_len = (n_samples - frame_size) / frame_step + 1;
|
||||
// TODO: handle these checks better
|
||||
@@ -438,6 +438,7 @@ static bool log_mel_spectrogram(
|
||||
|
||||
const int effective_n_len = n_samples_in / frame_step;
|
||||
if (params.norm_per_feature) {
|
||||
GGML_ASSERT(effective_n_len > 1);
|
||||
for (int i = 0; i < out.n_mel; i++) {
|
||||
double mean = 0;
|
||||
for (int j = 0; j < effective_n_len; ++j) {
|
||||
@@ -639,6 +640,7 @@ mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length
|
||||
padding_to_remove((n_fft - hop_length) / 2),
|
||||
ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT
|
||||
ifft_out(n_fft * 2 * 4, 0.0f) {
|
||||
GGML_ASSERT(n_fft > 0 && hop_length > 0 && hop_length <= n_fft);
|
||||
cache.fill_sin_cos_table(n_fft);
|
||||
cache.fill_hann_window(n_fft, true);
|
||||
}
|
||||
|
||||
@@ -33,9 +33,9 @@ struct mtmd_audio_cache {
|
||||
|
||||
mtmd_audio_mel_filters filters;
|
||||
|
||||
void fill_sin_cos_table(int n);
|
||||
void fill_sin_cos_table(uint32_t n);
|
||||
|
||||
void fill_hann_window(int length, bool periodic);
|
||||
void fill_hann_window(uint32_t length, bool periodic);
|
||||
|
||||
// Build mel filterbank matrix [n_mel × n_fft_bins] at runtime.
|
||||
// n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257.
|
||||
|
||||
@@ -127,6 +127,7 @@ struct decode_embd_batch {
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) {
|
||||
GGML_ASSERT(n_tokens > 0 && n_pos_per_embd > 0 && n_mmproj_embd > 0);
|
||||
pos .resize(n_tokens * n_pos_per_embd);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
@@ -157,6 +158,7 @@ struct decode_embd_batch {
|
||||
// M-RoPE for image
|
||||
void set_position_mrope_2d(llama_pos pos_0, int nx, int ny, llama_seq_id seq_id) {
|
||||
GGML_ASSERT(n_pos_per_embd == 4);
|
||||
GGML_ASSERT(nx > 0 && ny > 0 && nx * ny == batch.n_tokens);
|
||||
seq_id_0[0] = seq_id;
|
||||
for (int y = 0; y < ny; y++) {
|
||||
for (int x = 0; x < nx; x++) {
|
||||
@@ -192,6 +194,7 @@ struct decode_embd_batch {
|
||||
}
|
||||
|
||||
llama_batch get_view(int offset, int n_tokens) {
|
||||
GGML_ASSERT(offset >= 0 && n_tokens > 0 && offset + n_tokens <= batch.n_tokens);
|
||||
llama_pos * pos_ptr;
|
||||
pos_view.clear();
|
||||
pos_view.reserve(n_tokens * n_pos_per_embd);
|
||||
@@ -235,6 +238,7 @@ int32_t mtmd_helper_decode_image_chunk(
|
||||
llama_seq_id seq_id,
|
||||
int32_t n_batch,
|
||||
llama_pos * new_n_past) {
|
||||
GGML_ASSERT(n_batch > 0);
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
const char * name = chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE ? "image" : "audio";
|
||||
if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
|
||||
@@ -312,6 +316,7 @@ int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
|
||||
int32_t n_batch,
|
||||
bool logits_last,
|
||||
llama_pos * new_n_past) {
|
||||
GGML_ASSERT(n_batch > 0);
|
||||
int32_t ret;
|
||||
llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
|
||||
auto chunk_type = mtmd_input_chunk_get_type(chunk);
|
||||
@@ -508,6 +513,11 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *
|
||||
fseek(f, 0, SEEK_END);
|
||||
long file_size = ftell(f);
|
||||
fseek(f, 0, SEEK_SET);
|
||||
if (file_size < 0) {
|
||||
LOG_ERR("Failed to get file size of %s\n", fname);
|
||||
fclose(f);
|
||||
return nullptr;
|
||||
}
|
||||
buf.resize(file_size);
|
||||
|
||||
size_t n_read = fread(buf.data(), 1, file_size, f);
|
||||
|
||||
@@ -99,6 +99,8 @@ struct img_tool {
|
||||
}
|
||||
|
||||
static void crop(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
|
||||
GGML_ASSERT(x >= 0 && y >= 0 && w > 0 && h > 0);
|
||||
GGML_ASSERT(x + w <= image.nx && y + h <= image.ny);
|
||||
dst.nx = w;
|
||||
dst.ny = h;
|
||||
dst.buf.resize(3 * w * h);
|
||||
@@ -196,6 +198,7 @@ struct img_tool {
|
||||
private:
|
||||
// Bilinear resize function
|
||||
static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
|
||||
GGML_ASSERT(src.nx >= 2 && src.ny >= 2);
|
||||
dst.nx = target_width;
|
||||
dst.ny = target_height;
|
||||
dst.buf.resize(3 * target_width * target_height);
|
||||
@@ -207,8 +210,8 @@ private:
|
||||
for (int x = 0; x < target_width; x++) {
|
||||
float px = x_ratio * x;
|
||||
float py = y_ratio * y;
|
||||
int x_floor = static_cast<int>(px);
|
||||
int y_floor = static_cast<int>(py);
|
||||
int x_floor = std::min(static_cast<int>(px), src.nx - 2);
|
||||
int y_floor = std::min(static_cast<int>(py), src.ny - 2);
|
||||
float x_lerp = px - x_floor;
|
||||
float y_lerp = py - y_floor;
|
||||
|
||||
@@ -347,6 +350,7 @@ private:
|
||||
// Returns: kernel size (ksize) - number of input pixels that contribute to each output pixel
|
||||
auto precompute_weights = [&](int inSize, int outSize,
|
||||
std::vector<int> & bounds, std::vector<int32_t> & weights) -> int {
|
||||
GGML_ASSERT(inSize > 0 && outSize > 0);
|
||||
double support, scale, filterscale;
|
||||
double center, ww, ss;
|
||||
int xx, x, ksize, xmin, xmax, xcnt;
|
||||
|
||||
@@ -641,6 +641,11 @@ struct mtmd_tokenizer {
|
||||
add_text(ctx->img_beg, true); // add image begin token
|
||||
}
|
||||
|
||||
// sanity check
|
||||
GGML_ASSERT(bitmap->nx > 0 && bitmap->ny > 0);
|
||||
GGML_ASSERT(bitmap->data.size() == (size_t)bitmap->nx * bitmap->ny * 3);
|
||||
GGML_ASSERT(ctx->image_preproc != nullptr);
|
||||
|
||||
// convert mtmd_bitmap to clip_image_u8
|
||||
clip_image_u8_ptr img_u8(clip_image_u8_init());
|
||||
img_u8->nx = bitmap->nx;
|
||||
@@ -649,7 +654,6 @@ struct mtmd_tokenizer {
|
||||
std::memcpy(img_u8->buf.data(), bitmap->data.data(), img_u8->nx * img_u8->ny * 3);
|
||||
|
||||
// preprocess image
|
||||
GGML_ASSERT(ctx->image_preproc != nullptr);
|
||||
clip_image_f32_batch batch_f32;
|
||||
bool ok = ctx->image_preproc->preprocess(*img_u8, batch_f32);
|
||||
if (!ok) {
|
||||
@@ -773,6 +777,11 @@ struct mtmd_tokenizer {
|
||||
add_text(ctx->aud_beg, true); // add audio begin token
|
||||
}
|
||||
|
||||
// sanity check
|
||||
GGML_ASSERT(ctx->audio_preproc != nullptr);
|
||||
GGML_ASSERT(bitmap->data.size() > sizeof(float));
|
||||
GGML_ASSERT(bitmap->data.size() % sizeof(float) == 0);
|
||||
|
||||
// preprocess audio
|
||||
std::vector<mtmd_audio_mel> mel_spec_chunks;
|
||||
const float * samples = (const float *)bitmap->data.data();
|
||||
|
||||
@@ -13,6 +13,8 @@ add_library(${TARGET} STATIC
|
||||
server-common.h
|
||||
server-context.cpp
|
||||
server-context.h
|
||||
server-tools.cpp
|
||||
server-tools.h
|
||||
)
|
||||
|
||||
if (BUILD_SHARED_LIBS)
|
||||
@@ -35,22 +37,29 @@ set(TARGET_SRCS
|
||||
server-models.cpp
|
||||
server-models.h
|
||||
)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html.gz
|
||||
loading.html
|
||||
)
|
||||
|
||||
foreach(asset ${PUBLIC_ASSETS})
|
||||
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
|
||||
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
|
||||
list(APPEND TARGET_SRCS ${output})
|
||||
add_custom_command(
|
||||
DEPENDS "${input}"
|
||||
OUTPUT "${output}"
|
||||
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
|
||||
option(LLAMA_BUILD_WEBUI "Build the embedded Web UI" ON)
|
||||
|
||||
if (LLAMA_BUILD_WEBUI)
|
||||
set(PUBLIC_ASSETS
|
||||
index.html.gz
|
||||
loading.html
|
||||
)
|
||||
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
|
||||
endforeach()
|
||||
|
||||
foreach(asset ${PUBLIC_ASSETS})
|
||||
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
|
||||
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
|
||||
list(APPEND TARGET_SRCS ${output})
|
||||
add_custom_command(
|
||||
DEPENDS "${input}"
|
||||
OUTPUT "${output}"
|
||||
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
|
||||
)
|
||||
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
|
||||
endforeach()
|
||||
add_definitions(-DLLAMA_BUILD_WEBUI)
|
||||
else()
|
||||
endif()
|
||||
|
||||
add_executable(${TARGET} ${TARGET_SRCS})
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
||||
@@ -125,6 +125,61 @@ The framework automatically starts a `llama-server` instance, sends requests, an
|
||||
|
||||
For detailed instructions, see the [test documentation](./tests/README.md).
|
||||
|
||||
### API for tools
|
||||
|
||||
This endpoint is intended to be used internally by the Web UI and subject to change or to be removed in the future.
|
||||
|
||||
**GET /tools**
|
||||
|
||||
Get a list of tools, each tool has these fields:
|
||||
- `tool` (string): the ID name of the tool, to be used in POST call. Example: `read_file`
|
||||
- `display_name` (string): the name to be displayed on UI. Example: `Read file`
|
||||
- `type` (string): always be `"builtin"` for now
|
||||
- `permissions` (object): a mapping string --> boolean that indicates the permission required by this tool. This is useful for the UI to ask the user before calling the tool. For now, the only permission supported is `"write"`
|
||||
- `definition` (object): the OAI-compat definition of this tool
|
||||
|
||||
**POST /tools**
|
||||
|
||||
Invoke a tool call, request body is a JSON object with:
|
||||
- `tool` (string): the name of the tool
|
||||
- `params` (object): a mapping from argument name (string) to argument value
|
||||
|
||||
Returns JSON object. There are two response formats:
|
||||
|
||||
Format 1: Plain text. The text will be placed into a field called `plain_text_response`, example:
|
||||
|
||||
```json
|
||||
{
|
||||
"plain_text_response": "this is a text response"
|
||||
}
|
||||
```
|
||||
|
||||
The client should extract this value and place it inside message content (note: content is no longer a JSON), example
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "this is a text response"
|
||||
}
|
||||
```
|
||||
|
||||
Format 2: Normal JSON response, example:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "cannot open this file"
|
||||
}
|
||||
```
|
||||
|
||||
That requires `JSON.stringify` when formatted to message content:
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "{\"error\":\"cannot open this file\"}"
|
||||
}
|
||||
```
|
||||
|
||||
### Notable Related PRs
|
||||
|
||||
- Initial server implementation: https://github.com/ggml-org/llama.cpp/pull/1443
|
||||
|
||||
@@ -36,7 +36,6 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--license` | show source code license and dependencies |
|
||||
| `-cl, --cache-list` | show list of models in cache |
|
||||
| `--completion-bash` | print source-able bash completion script for llama.cpp |
|
||||
| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
|
||||
| `-t, --threads N` | number of CPU threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
|
||||
| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
|
||||
| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
|
||||
@@ -194,6 +193,7 @@ For the full list of features, please refer to [server's changelog](https://gith
|
||||
| `--webui-config JSON` | JSON that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG) |
|
||||
| `--webui-config-file PATH` | JSON file that provides default WebUI settings (overrides WebUI defaults)<br/>(env: LLAMA_ARG_WEBUI_CONFIG_FILE) |
|
||||
| `--webui-mcp-proxy, --no-webui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)<br/>(env: LLAMA_ARG_WEBUI_MCP_PROXY) |
|
||||
| `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)<br/>specify "all" to enable all tools<br/>available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff<br/>(env: LLAMA_ARG_TOOLS) |
|
||||
| `--webui, --no-webui` | whether to enable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_WEBUI) |
|
||||
| `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
|
||||
| `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
|
||||
@@ -293,6 +293,12 @@ It is currently available in the following endpoints:
|
||||
|
||||
For more details, please refer to [multimodal documentation](../../docs/multimodal.md)
|
||||
|
||||
### Built-in tools support
|
||||
|
||||
The server includes a set of built-in tools that enable the LLM to access the local file system directly from the Web UI.
|
||||
|
||||
To use this feature, start the server with `--tools all`. You can also enable only specific tools by passing a comma-separated list: `--tools name1,name2,...`. Run `--help` for the full list of available tool names.
|
||||
|
||||
## Build
|
||||
|
||||
`llama-server` is built alongside everything else from the root of the project
|
||||
@@ -1438,6 +1444,14 @@ curl http://localhost:8080/v1/messages/count_tokens \
|
||||
{"input_tokens": 10}
|
||||
```
|
||||
|
||||
## Server built-in tools
|
||||
|
||||
The server exposes a REST API under `/tools` that allows the Web UI to call built-in tools. This endpoint is intended to be used internally by the Web UI and subject to change or to be removed in the future.
|
||||
|
||||
**Please do NOT use this endpoint in a downstream application**
|
||||
|
||||
For further documentation about this endpoint, please refer to [server internal documentation](./README-dev.md)
|
||||
|
||||
## Using multiple models
|
||||
|
||||
`llama-server` can be launched in a **router mode** that exposes an API for dynamically loading and unloading models. The main process (the "router") automatically forwards each request to the appropriate model instance.
|
||||
|
||||
Binary file not shown.
@@ -8,9 +8,11 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
// auto generated files (see README.md for details)
|
||||
#include "index.html.gz.hpp"
|
||||
#include "loading.html.hpp"
|
||||
#endif
|
||||
|
||||
//
|
||||
// HTTP implementation using cpp-httplib
|
||||
@@ -181,11 +183,14 @@ bool server_http_context::init(const common_params & params) {
|
||||
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
|
||||
bool ready = is_ready.load();
|
||||
if (!ready) {
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
auto tmp = string_split<std::string>(req.path, '.');
|
||||
if (req.path == "/" || tmp.back() == "html") {
|
||||
res.status = 503;
|
||||
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
|
||||
} else {
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
// no endpoints is allowed to be accessed when the server is not ready
|
||||
// this is to prevent any data races or inconsistent states
|
||||
res.status = 503;
|
||||
@@ -255,6 +260,7 @@ bool server_http_context::init(const common_params & params) {
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
#ifdef LLAMA_BUILD_WEBUI
|
||||
// using embedded static index.html
|
||||
srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
|
||||
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
|
||||
@@ -268,6 +274,7 @@ bool server_http_context::init(const common_params & params) {
|
||||
}
|
||||
return false;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
800
tools/server/server-tools.cpp
Normal file
800
tools/server/server-tools.cpp
Normal file
@@ -0,0 +1,800 @@
|
||||
#include "server-tools.h"
|
||||
|
||||
#include <sheredom/subprocess.h>
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <climits>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
//
|
||||
// internal helpers
|
||||
//
|
||||
|
||||
static std::vector<char *> to_cstr_vec(const std::vector<std::string> & v) {
|
||||
std::vector<char *> r;
|
||||
r.reserve(v.size() + 1);
|
||||
for (const auto & s : v) {
|
||||
r.push_back(const_cast<char *>(s.c_str()));
|
||||
}
|
||||
r.push_back(nullptr);
|
||||
return r;
|
||||
}
|
||||
|
||||
struct run_proc_result {
|
||||
std::string output;
|
||||
int exit_code = -1;
|
||||
bool timed_out = false;
|
||||
};
|
||||
|
||||
static run_proc_result run_process(
|
||||
const std::vector<std::string> & args,
|
||||
size_t max_output,
|
||||
int timeout_secs) {
|
||||
run_proc_result res;
|
||||
|
||||
subprocess_s proc;
|
||||
auto argv = to_cstr_vec(args);
|
||||
|
||||
int options = subprocess_option_no_window
|
||||
| subprocess_option_combined_stdout_stderr
|
||||
| subprocess_option_inherit_environment
|
||||
| subprocess_option_search_user_path;
|
||||
|
||||
if (subprocess_create(argv.data(), options, &proc) != 0) {
|
||||
res.output = "failed to spawn process";
|
||||
return res;
|
||||
}
|
||||
|
||||
std::atomic<bool> done{false};
|
||||
std::atomic<bool> timed_out{false};
|
||||
|
||||
std::thread timeout_thread([&]() {
|
||||
auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeout_secs);
|
||||
while (!done.load()) {
|
||||
if (std::chrono::steady_clock::now() >= deadline) {
|
||||
timed_out.store(true);
|
||||
subprocess_terminate(&proc);
|
||||
return;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
});
|
||||
|
||||
FILE * f = subprocess_stdout(&proc);
|
||||
std::string output;
|
||||
bool truncated = false;
|
||||
if (f) {
|
||||
char buf[4096];
|
||||
while (fgets(buf, sizeof(buf), f) != nullptr) {
|
||||
if (!truncated) {
|
||||
size_t len = strlen(buf);
|
||||
if (output.size() + len <= max_output) {
|
||||
output.append(buf, len);
|
||||
} else {
|
||||
output.append(buf, max_output - output.size());
|
||||
truncated = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done.store(true);
|
||||
if (timeout_thread.joinable()) {
|
||||
timeout_thread.join();
|
||||
}
|
||||
|
||||
subprocess_join(&proc, &res.exit_code);
|
||||
subprocess_destroy(&proc);
|
||||
|
||||
res.output = output;
|
||||
res.timed_out = timed_out.load();
|
||||
if (truncated) {
|
||||
res.output += "\n[output truncated]";
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
// simple glob: * matches non-/ chars, ** matches anything including /
|
||||
static bool glob_match(const char * pattern, const char * str) {
|
||||
if (*pattern == '\0') {
|
||||
return *str == '\0';
|
||||
}
|
||||
if (pattern[0] == '*' && pattern[1] == '*') {
|
||||
const char * p = pattern + 2;
|
||||
if (*p == '/') p++;
|
||||
if (glob_match(p, str)) return true;
|
||||
if (*str != '\0') return glob_match(pattern, str + 1);
|
||||
return false;
|
||||
}
|
||||
if (*pattern == '*') {
|
||||
const char * p = pattern + 1;
|
||||
for (; *str != '\0' && *str != '/'; str++) {
|
||||
if (glob_match(p, str)) return true;
|
||||
}
|
||||
return glob_match(p, str);
|
||||
}
|
||||
if (*pattern == '?' && *str != '\0' && *str != '/') {
|
||||
return glob_match(pattern + 1, str + 1);
|
||||
}
|
||||
if (*pattern == *str) {
|
||||
return glob_match(pattern + 1, str + 1);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool glob_match(const std::string & pattern, const std::string & str) {
|
||||
return glob_match(pattern.c_str(), str.c_str());
|
||||
}
|
||||
|
||||
json server_tool::to_json() {
|
||||
return {
|
||||
{"display_name", display_name},
|
||||
{"tool", name},
|
||||
{"type", "builtin"},
|
||||
{"permissions", json{
|
||||
{"write", permission_write}
|
||||
}},
|
||||
{"definition", get_definition()},
|
||||
};
|
||||
}
|
||||
|
||||
//
|
||||
// read_file: read a file with optional line range and line-number prefix
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_READ_FILE_MAX_SIZE = 16 * 1024; // 16 KB
|
||||
|
||||
struct server_tool_read_file : server_tool {
|
||||
server_tool_read_file() {
|
||||
name = "read_file";
|
||||
display_name = "Read file";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Read the contents of a file. Optionally specify a 1-based line range. "
|
||||
"If append_loc is true, each line is prefixed with its line number (e.g. \"1\u2192 ...\")."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path to the file"}}},
|
||||
{"start_line", {{"type", "integer"}, {"description", "First line to read, 1-based (default: 1)"}}},
|
||||
{"end_line", {{"type", "integer"}, {"description", "Last line to read, 1-based inclusive (default: end of file)"}}},
|
||||
{"append_loc", {{"type", "boolean"}, {"description", "Prefix each line with its line number"}}},
|
||||
}},
|
||||
{"required", json::array({"path"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
int start_line = json_value(params, "start_line", 1);
|
||||
int end_line = json_value(params, "end_line", -1); // -1 = no limit
|
||||
bool append_loc = json_value(params, "append_loc", false);
|
||||
|
||||
std::error_code ec;
|
||||
uintmax_t file_size = fs::file_size(path, ec);
|
||||
if (ec) {
|
||||
return {{"error", "cannot stat file: " + ec.message()}};
|
||||
}
|
||||
if (file_size > SERVER_TOOL_READ_FILE_MAX_SIZE && end_line == -1) {
|
||||
return {{"error", string_format(
|
||||
"file too large (%zu bytes, max %zu). Use start_line/end_line to read a portion.",
|
||||
(size_t)file_size, SERVER_TOOL_READ_FILE_MAX_SIZE)}};
|
||||
}
|
||||
|
||||
std::ifstream f(path);
|
||||
if (!f) {
|
||||
return {{"error", "failed to open file: " + path}};
|
||||
}
|
||||
|
||||
std::string result;
|
||||
std::string line;
|
||||
int lineno = 0;
|
||||
|
||||
while (std::getline(f, line)) {
|
||||
lineno++;
|
||||
if (lineno < start_line) continue;
|
||||
if (end_line != -1 && lineno > end_line) break;
|
||||
|
||||
std::string out_line;
|
||||
if (append_loc) {
|
||||
out_line = std::to_string(lineno) + "\u2192 " + line + "\n";
|
||||
} else {
|
||||
out_line = line + "\n";
|
||||
}
|
||||
|
||||
if (result.size() + out_line.size() > SERVER_TOOL_READ_FILE_MAX_SIZE) {
|
||||
result += "[output truncated]";
|
||||
break;
|
||||
}
|
||||
result += out_line;
|
||||
}
|
||||
|
||||
return {{"plain_text_response", result}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// file_glob_search: find files matching a glob pattern under a base directory
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_FILE_SEARCH_MAX_RESULTS = 100;
|
||||
|
||||
struct server_tool_file_glob_search : server_tool {
|
||||
server_tool_file_glob_search() {
|
||||
name = "file_glob_search";
|
||||
display_name = "File search";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Recursively search for files matching a glob pattern under a directory."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Base directory to search in"}}},
|
||||
{"include", {{"type", "string"}, {"description", "Glob pattern for files to include (e.g. \"**/*.cpp\"). Default: **"}}},
|
||||
{"exclude", {{"type", "string"}, {"description", "Glob pattern for files to exclude"}}},
|
||||
}},
|
||||
{"required", json::array({"path"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string base = params.at("path").get<std::string>();
|
||||
std::string include = json_value(params, "include", std::string("**"));
|
||||
std::string exclude = json_value(params, "exclude", std::string(""));
|
||||
|
||||
std::ostringstream output_text;
|
||||
size_t count = 0;
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : fs::recursive_directory_iterator(base,
|
||||
fs::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) continue;
|
||||
|
||||
std::string rel = fs::relative(entry.path(), base, ec).string();
|
||||
if (ec) continue;
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(include, rel)) continue;
|
||||
if (!exclude.empty() && glob_match(exclude, rel)) continue;
|
||||
|
||||
output_text << entry.path().string() << "\n";
|
||||
if (++count >= SERVER_TOOL_FILE_SEARCH_MAX_RESULTS) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
output_text << "\n---\nTotal matches: " << count << "\n";
|
||||
|
||||
return {{"plain_text_response", output_text.str()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// grep_search: search for a regex pattern in files
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_GREP_SEARCH_MAX_RESULTS = 100;
|
||||
|
||||
struct server_tool_grep_search : server_tool {
|
||||
server_tool_grep_search() {
|
||||
name = "grep_search";
|
||||
display_name = "Grep search";
|
||||
permission_write = false;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Search for a regex pattern in files under a path. Returns matching lines."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "File or directory to search in"}}},
|
||||
{"pattern", {{"type", "string"}, {"description", "Regular expression pattern to search for"}}},
|
||||
{"include", {{"type", "string"}, {"description", "Glob pattern to filter files (default: **)"}}},
|
||||
{"exclude", {{"type", "string"}, {"description", "Glob pattern to exclude files"}}},
|
||||
{"return_line_numbers", {{"type", "boolean"}, {"description", "If true, include line numbers in results"}}},
|
||||
}},
|
||||
{"required", json::array({"path", "pattern"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
std::string pat_str = params.at("pattern").get<std::string>();
|
||||
std::string include = json_value(params, "include", std::string("**"));
|
||||
std::string exclude = json_value(params, "exclude", std::string(""));
|
||||
bool show_lineno = json_value(params, "return_line_numbers", false);
|
||||
|
||||
std::regex pattern;
|
||||
try {
|
||||
pattern = std::regex(pat_str);
|
||||
} catch (const std::regex_error & e) {
|
||||
return {{"error", std::string("invalid regex: ") + e.what()}};
|
||||
}
|
||||
|
||||
std::ostringstream output_text;
|
||||
size_t total = 0;
|
||||
|
||||
auto search_file = [&](const fs::path & fpath) {
|
||||
std::ifstream f(fpath);
|
||||
if (!f) return;
|
||||
std::string line;
|
||||
int lineno = 0;
|
||||
while (std::getline(f, line) && total < SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) {
|
||||
lineno++;
|
||||
if (std::regex_search(line, pattern)) {
|
||||
output_text << fpath.string() << ":";
|
||||
if (show_lineno) {
|
||||
output_text << lineno << ":";
|
||||
}
|
||||
output_text << line << "\n";
|
||||
total++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::error_code ec;
|
||||
if (fs::is_regular_file(path, ec)) {
|
||||
search_file(path);
|
||||
} else if (fs::is_directory(path, ec)) {
|
||||
for (const auto & entry : fs::recursive_directory_iterator(path,
|
||||
fs::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) continue;
|
||||
if (total >= SERVER_TOOL_GREP_SEARCH_MAX_RESULTS) break;
|
||||
|
||||
std::string rel = fs::relative(entry.path(), path, ec).string();
|
||||
if (ec) continue;
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(include, rel)) continue;
|
||||
if (!exclude.empty() && glob_match(exclude, rel)) continue;
|
||||
|
||||
search_file(entry.path());
|
||||
}
|
||||
} else {
|
||||
return {{"error", "path does not exist: " + path}};
|
||||
}
|
||||
|
||||
output_text << "\n\n---\nTotal matches: " << total << "\n";
|
||||
|
||||
return {{"plain_text_response", output_text.str()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// exec_shell_command: run an arbitrary shell command
|
||||
//
|
||||
|
||||
static constexpr size_t SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE = 16 * 1024; // 16 KB
|
||||
static constexpr int SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT = 60; // seconds
|
||||
|
||||
struct server_tool_exec_shell_command : server_tool {
|
||||
server_tool_exec_shell_command() {
|
||||
name = "exec_shell_command";
|
||||
display_name = "Execute shell command";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Execute a shell command and return its output (stdout and stderr combined)."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"command", {{"type", "string"}, {"description", "Shell command to execute"}}},
|
||||
{"timeout", {{"type", "integer"}, {"description", string_format("Timeout in seconds (default 10, max %d)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT)}}},
|
||||
{"max_output_size", {{"type", "integer"}, {"description", string_format("Maximum output size in bytes (default %zu)", SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE)}}},
|
||||
}},
|
||||
{"required", json::array({"command"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string command = params.at("command").get<std::string>();
|
||||
int timeout = json_value(params, "timeout", 10);
|
||||
size_t max_output = (size_t) json_value(params, "max_output_size", (int) SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
|
||||
|
||||
timeout = std::min(timeout, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_TIMEOUT);
|
||||
max_output = std::min(max_output, SERVER_TOOL_EXEC_SHELL_COMMAND_MAX_OUTPUT_SIZE);
|
||||
|
||||
#ifdef _WIN32
|
||||
std::vector<std::string> args = {"cmd", "/c", command};
|
||||
#else
|
||||
std::vector<std::string> args = {"sh", "-c", command};
|
||||
#endif
|
||||
|
||||
auto res = run_process(args, max_output, timeout);
|
||||
|
||||
std::string text_output = res.output;
|
||||
text_output += string_format("\n[exit code: %d]", res.exit_code);
|
||||
if (res.timed_out) {
|
||||
text_output += " [exit due to timed out]";
|
||||
}
|
||||
|
||||
return {{"plain_text_response", text_output}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// write_file: create or overwrite a file
|
||||
//
|
||||
|
||||
struct server_tool_write_file : server_tool {
|
||||
server_tool_write_file() {
|
||||
name = "write_file";
|
||||
display_name = "Write file";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Write content to a file, creating it (including parent directories) if it does not exist. May use with edit_file for more complex edits."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path of the file to write"}}},
|
||||
{"content", {{"type", "string"}, {"description", "Content to write"}}},
|
||||
}},
|
||||
{"required", json::array({"path", "content"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
std::string content = params.at("content").get<std::string>();
|
||||
|
||||
std::error_code ec;
|
||||
fs::path fpath(path);
|
||||
if (fpath.has_parent_path()) {
|
||||
fs::create_directories(fpath.parent_path(), ec);
|
||||
if (ec) {
|
||||
return {{"error", "failed to create directories: " + ec.message()}};
|
||||
}
|
||||
}
|
||||
|
||||
std::ofstream f(path, std::ios::binary);
|
||||
if (!f) {
|
||||
return {{"error", "failed to open file for writing: " + path}};
|
||||
}
|
||||
f << content;
|
||||
if (!f) {
|
||||
return {{"error", "failed to write file: " + path}};
|
||||
}
|
||||
|
||||
return {{"result", "file written successfully"}, {"path", path}, {"bytes", content.size()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// edit_file: edit file content via line-based changes
|
||||
//
|
||||
|
||||
struct server_tool_edit_file : server_tool {
|
||||
server_tool_edit_file() {
|
||||
name = "edit_file";
|
||||
display_name = "Edit file";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description",
|
||||
"Edit a file by applying a list of line-based changes. "
|
||||
"Each change targets a 1-based inclusive line range and has a mode: "
|
||||
"\"replace\" (replace lines with content), "
|
||||
"\"delete\" (remove lines, content must be empty string), "
|
||||
"\"append\" (insert content after line_end). "
|
||||
"Set line_start to -1 to target the end of file (line_end is ignored in that case). "
|
||||
"Changes must not overlap. They are applied in reverse line order automatically."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"path", {{"type", "string"}, {"description", "Path to the file to edit"}}},
|
||||
{"changes", {
|
||||
{"type", "array"},
|
||||
{"description", "List of changes to apply"},
|
||||
{"items", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"mode", {{"type", "string"}, {"description", "\"replace\", \"delete\", or \"append\""}}},
|
||||
{"line_start", {{"type", "integer"}, {"description", "First line of the range (1-based); use -1 for end of file"}}},
|
||||
{"line_end", {{"type", "integer"}, {"description", "Last line of the range (1-based, inclusive); ignored when line_start is -1"}}},
|
||||
{"content", {{"type", "string"}, {"description", "Content to insert; must be empty string for delete mode"}}},
|
||||
}},
|
||||
{"required", json::array({"mode", "line_start", "line_end", "content"})},
|
||||
}},
|
||||
}},
|
||||
}},
|
||||
{"required", json::array({"path", "changes"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string path = params.at("path").get<std::string>();
|
||||
const json & changes = params.at("changes");
|
||||
|
||||
if (!changes.is_array()) {
|
||||
return {{"error", "\"changes\" must be an array"}};
|
||||
}
|
||||
|
||||
// read file into lines
|
||||
std::ifstream fin(path);
|
||||
if (!fin) {
|
||||
return {{"error", "failed to open file: " + path}};
|
||||
}
|
||||
std::vector<std::string> lines;
|
||||
{
|
||||
std::string line;
|
||||
while (std::getline(fin, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
}
|
||||
fin.close();
|
||||
|
||||
// validate and collect changes, then sort descending by line_start
|
||||
struct change_entry {
|
||||
std::string mode;
|
||||
int line_start; // 1-based
|
||||
int line_end; // 1-based inclusive
|
||||
std::string content;
|
||||
};
|
||||
std::vector<change_entry> entries;
|
||||
entries.reserve(changes.size());
|
||||
|
||||
for (const auto & ch : changes) {
|
||||
change_entry e;
|
||||
e.mode = ch.at("mode").get<std::string>();
|
||||
e.line_start = ch.at("line_start").get<int>();
|
||||
e.line_end = ch.at("line_end").get<int>();
|
||||
e.content = ch.at("content").get<std::string>();
|
||||
|
||||
if (e.mode != "replace" && e.mode != "delete" && e.mode != "append") {
|
||||
return {{"error", "invalid mode \"" + e.mode + "\"; must be replace, delete, or append"}};
|
||||
}
|
||||
if (e.mode == "delete" && !e.content.empty()) {
|
||||
return {{"error", "content must be empty string for delete mode"}};
|
||||
}
|
||||
int n = (int) lines.size();
|
||||
if (e.line_start == -1) {
|
||||
// -1 means end of file; line_end is ignored — normalize to point past last line
|
||||
e.line_start = n + 1;
|
||||
e.line_end = n + 1;
|
||||
} else {
|
||||
if (e.line_start < 1 || e.line_end < e.line_start) {
|
||||
return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}};
|
||||
}
|
||||
if (e.line_end > n) {
|
||||
return {{"error", string_format("line_end %d exceeds file length %d", e.line_end, n)}};
|
||||
}
|
||||
}
|
||||
entries.push_back(std::move(e));
|
||||
}
|
||||
|
||||
// sort descending so earlier-indexed changes don't shift later ones
|
||||
std::sort(entries.begin(), entries.end(), [](const change_entry & a, const change_entry & b) {
|
||||
return a.line_start > b.line_start;
|
||||
});
|
||||
|
||||
// apply changes (0-based indices internally)
|
||||
for (const auto & e : entries) {
|
||||
int idx_start = e.line_start - 1; // 0-based
|
||||
int idx_end = e.line_end - 1; // 0-based inclusive
|
||||
|
||||
// split content into lines (preserve trailing newline awareness)
|
||||
std::vector<std::string> new_lines;
|
||||
if (!e.content.empty()) {
|
||||
std::istringstream ss(e.content);
|
||||
std::string ln;
|
||||
while (std::getline(ss, ln)) {
|
||||
new_lines.push_back(ln);
|
||||
}
|
||||
// if content ends with \n, getline consumed it — no extra empty line needed
|
||||
// if content does NOT end with \n, last line is still captured correctly
|
||||
}
|
||||
|
||||
if (e.mode == "replace") {
|
||||
// erase [idx_start, idx_end] and insert new_lines
|
||||
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
|
||||
lines.insert(lines.begin() + idx_start, new_lines.begin(), new_lines.end());
|
||||
} else if (e.mode == "delete") {
|
||||
lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1);
|
||||
} else { // append
|
||||
// idx_end + 1 may equal lines.size() when line_start == -1 (end of file)
|
||||
lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end());
|
||||
}
|
||||
}
|
||||
|
||||
// write file back
|
||||
std::ofstream fout(path, std::ios::binary);
|
||||
if (!fout) {
|
||||
return {{"error", "failed to open file for writing: " + path}};
|
||||
}
|
||||
for (size_t i = 0; i < lines.size(); i++) {
|
||||
fout << lines[i];
|
||||
if (i + 1 < lines.size()) {
|
||||
fout << "\n";
|
||||
}
|
||||
}
|
||||
if (!lines.empty()) {
|
||||
fout << "\n";
|
||||
}
|
||||
if (!fout) {
|
||||
return {{"error", "failed to write file: " + path}};
|
||||
}
|
||||
|
||||
return {{"result", "file edited successfully"}, {"path", path}, {"lines", (int) lines.size()}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// apply_diff: apply a unified diff via git apply
|
||||
//
|
||||
|
||||
struct server_tool_apply_diff : server_tool {
|
||||
server_tool_apply_diff() {
|
||||
name = "apply_diff";
|
||||
display_name = "Apply diff";
|
||||
permission_write = true;
|
||||
}
|
||||
|
||||
json get_definition() override {
|
||||
return {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", name},
|
||||
{"description", "Apply a unified diff to edit one or more files using git apply. Use this instead of edit_file when the changes are complex."},
|
||||
{"parameters", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"diff", {{"type", "string"}, {"description", "Unified diff content in git diff format"}}},
|
||||
}},
|
||||
{"required", json::array({"diff"})},
|
||||
}},
|
||||
}},
|
||||
};
|
||||
}
|
||||
|
||||
json invoke(json params) override {
|
||||
std::string diff = params.at("diff").get<std::string>();
|
||||
|
||||
// write diff to a temporary file
|
||||
static std::atomic<int> counter{0};
|
||||
std::string tmp_path = (fs::temp_directory_path() /
|
||||
("llama_patch_" + std::to_string(++counter) + ".patch")).string();
|
||||
|
||||
{
|
||||
std::ofstream f(tmp_path, std::ios::binary);
|
||||
if (!f) {
|
||||
return {{"error", "failed to create temp patch file"}};
|
||||
}
|
||||
f << diff;
|
||||
}
|
||||
|
||||
auto res = run_process({"git", "apply", tmp_path}, 4096, 10);
|
||||
|
||||
std::error_code ec;
|
||||
fs::remove(tmp_path, ec);
|
||||
|
||||
if (res.exit_code != 0) {
|
||||
return {{"error", "git apply failed (exit " + std::to_string(res.exit_code) + "): " + res.output}};
|
||||
}
|
||||
return {{"result", "patch applied successfully"}};
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// public API
|
||||
//
|
||||
|
||||
static std::vector<std::unique_ptr<server_tool>> build_tools() {
|
||||
std::vector<std::unique_ptr<server_tool>> tools;
|
||||
tools.push_back(std::make_unique<server_tool_read_file>());
|
||||
tools.push_back(std::make_unique<server_tool_file_glob_search>());
|
||||
tools.push_back(std::make_unique<server_tool_grep_search>());
|
||||
tools.push_back(std::make_unique<server_tool_exec_shell_command>());
|
||||
tools.push_back(std::make_unique<server_tool_write_file>());
|
||||
tools.push_back(std::make_unique<server_tool_edit_file>());
|
||||
tools.push_back(std::make_unique<server_tool_apply_diff>());
|
||||
return tools;
|
||||
}
|
||||
|
||||
void server_tools::setup(const std::vector<std::string> & enabled_tools) {
|
||||
if (!enabled_tools.empty()) {
|
||||
std::unordered_set<std::string> enabled_set(enabled_tools.begin(), enabled_tools.end());
|
||||
auto all_tools = build_tools();
|
||||
|
||||
tools.clear();
|
||||
for (auto & t : all_tools) {
|
||||
if (enabled_set.count(t->name) > 0 || enabled_set.count("all") > 0) {
|
||||
tools.push_back(std::move(t));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
handle_get = [this](const server_http_req &) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json result = json::array();
|
||||
for (const auto & t : tools) {
|
||||
result.push_back(t->to_json());
|
||||
}
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
|
||||
handle_post = [this](const server_http_req & req) -> server_http_res_ptr {
|
||||
auto res = std::make_unique<server_http_res>();
|
||||
try {
|
||||
json body = json::parse(req.body);
|
||||
std::string tool_name = body.at("tool").get<std::string>();
|
||||
json params = body.value("params", json::object());
|
||||
json result = invoke(tool_name, params);
|
||||
res->data = safe_json_to_str(result);
|
||||
} catch (const json::exception & e) {
|
||||
res->status = 400;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||
} catch (const std::exception & e) {
|
||||
SRV_ERR("got exception: %s\n", e.what());
|
||||
res->status = 500;
|
||||
res->data = safe_json_to_str(format_error_response(e.what(), ERROR_TYPE_SERVER));
|
||||
}
|
||||
return res;
|
||||
};
|
||||
}
|
||||
|
||||
json server_tools::invoke(const std::string & name, const json & params) {
|
||||
for (auto & t : tools) {
|
||||
if (t->name == name) {
|
||||
return t->invoke(params);
|
||||
}
|
||||
}
|
||||
return {{"error", "unknown tool: " + name}};
|
||||
}
|
||||
26
tools/server/server-tools.h
Normal file
26
tools/server/server-tools.h
Normal file
@@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
|
||||
struct server_tool {
|
||||
std::string name;
|
||||
std::string display_name;
|
||||
bool permission_write = false;
|
||||
|
||||
virtual ~server_tool() = default;
|
||||
virtual json get_definition() = 0;
|
||||
virtual json invoke(json params) = 0;
|
||||
|
||||
json to_json();
|
||||
};
|
||||
|
||||
struct server_tools {
|
||||
std::vector<std::unique_ptr<server_tool>> tools;
|
||||
|
||||
void setup(const std::vector<std::string> & enabled_tools);
|
||||
json invoke(const std::string & name, const json & params);
|
||||
|
||||
server_http_context::handler_t handle_get;
|
||||
server_http_context::handler_t handle_post;
|
||||
};
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "server-http.h"
|
||||
#include "server-models.h"
|
||||
#include "server-cors-proxy.h"
|
||||
#include "server-tools.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
@@ -124,6 +125,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// register API routes
|
||||
server_routes routes(params, ctx_server);
|
||||
server_tools tools;
|
||||
|
||||
bool is_router_server = params.model.path.empty();
|
||||
std::optional<server_models_routes> models_routes{};
|
||||
@@ -211,6 +213,16 @@ int main(int argc, char ** argv) {
|
||||
ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get));
|
||||
ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post));
|
||||
}
|
||||
// EXPERIMENTAL built-in tools
|
||||
if (!params.server_tools.empty()) {
|
||||
tools.setup(params.server_tools);
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
SRV_WRN("%s", "Built-in tools are enabled, do not expose server to untrusted environments\n");
|
||||
SRV_WRN("%s", "This feature is EXPERIMENTAL and may be changed in the future\n");
|
||||
SRV_WRN("%s", "-----------------\n");
|
||||
ctx_http.get ("/tools", ex_wrapper(tools.handle_get));
|
||||
ctx_http.post("/tools", ex_wrapper(tools.handle_post));
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
|
||||
54
tools/server/webui/src/lib/actions/fade-in-view.svelte.ts
Normal file
54
tools/server/webui/src/lib/actions/fade-in-view.svelte.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Svelte action that fades in an element when it enters the viewport.
|
||||
* Uses IntersectionObserver for efficient viewport detection.
|
||||
*
|
||||
* If skipIfVisible is set and the element is already visible in the viewport
|
||||
* when the action attaches (e.g. a markdown block promoted from unstable
|
||||
* during streaming), the fade is skipped entirely to avoid a flash.
|
||||
*/
|
||||
export function fadeInView(
|
||||
node: HTMLElement,
|
||||
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
|
||||
) {
|
||||
const { duration = 300, y = 0, skipIfVisible = false } = options;
|
||||
|
||||
if (skipIfVisible) {
|
||||
const rect = node.getBoundingClientRect();
|
||||
const isAlreadyVisible =
|
||||
rect.top < window.innerHeight &&
|
||||
rect.bottom > 0 &&
|
||||
rect.left < window.innerWidth &&
|
||||
rect.right > 0;
|
||||
|
||||
if (isAlreadyVisible) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
node.style.opacity = '0';
|
||||
node.style.transform = `translateY(${y}px)`;
|
||||
node.style.transition = `opacity ${duration}ms ease-out, transform ${duration}ms ease-out`;
|
||||
|
||||
$effect(() => {
|
||||
const observer = new IntersectionObserver(
|
||||
(entries) => {
|
||||
for (const entry of entries) {
|
||||
if (entry.isIntersecting) {
|
||||
requestAnimationFrame(() => {
|
||||
node.style.opacity = '1';
|
||||
node.style.transform = 'translateY(0)';
|
||||
});
|
||||
observer.disconnect();
|
||||
}
|
||||
}
|
||||
},
|
||||
{ threshold: 0.05 }
|
||||
);
|
||||
|
||||
observer.observe(node);
|
||||
|
||||
return () => {
|
||||
observer.disconnect();
|
||||
};
|
||||
});
|
||||
}
|
||||
@@ -3,14 +3,12 @@
|
||||
ChatMessageAgenticContent,
|
||||
ChatMessageActions,
|
||||
ChatMessageStatistics,
|
||||
MarkdownContent,
|
||||
ModelBadge,
|
||||
ModelsSelector
|
||||
} from '$lib/components/app';
|
||||
import { getMessageEditContext } from '$lib/contexts';
|
||||
import { useProcessingState } from '$lib/hooks/use-processing-state.svelte';
|
||||
import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { agenticStreamingToolCall } from '$lib/stores/agentic.svelte';
|
||||
import { autoResizeTextarea, copyToClipboard, isIMEComposing } from '$lib/utils';
|
||||
import { tick } from 'svelte';
|
||||
import { fade } from 'svelte/transition';
|
||||
@@ -87,13 +85,7 @@
|
||||
const hasAgenticMarkers = $derived(
|
||||
messageContent?.includes(AGENTIC_TAGS.TOOL_CALL_START) ?? false
|
||||
);
|
||||
const hasStreamingToolCall = $derived(
|
||||
isChatStreaming() && agenticStreamingToolCall(message.convId) !== null
|
||||
);
|
||||
const hasReasoningMarkers = $derived(messageContent?.includes(REASONING_TAGS.START) ?? false);
|
||||
const isStructuredContent = $derived(
|
||||
hasAgenticMarkers || hasReasoningMarkers || hasStreamingToolCall
|
||||
);
|
||||
const processingState = useProcessingState();
|
||||
|
||||
let currentConfig = $derived(config());
|
||||
@@ -256,15 +248,13 @@
|
||||
{:else if message.role === MessageRole.ASSISTANT}
|
||||
{#if showRawOutput}
|
||||
<pre class="raw-output">{messageContent || ''}</pre>
|
||||
{:else if isStructuredContent}
|
||||
{:else}
|
||||
<ChatMessageAgenticContent
|
||||
content={messageContent || ''}
|
||||
isStreaming={isChatStreaming()}
|
||||
highlightTurns={highlightAgenticTurns}
|
||||
{message}
|
||||
/>
|
||||
{:else}
|
||||
<MarkdownContent content={messageContent || ''} attachments={message.extra} />
|
||||
{/if}
|
||||
{:else}
|
||||
<div class="text-sm whitespace-pre-wrap">
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { ChatMessage } from '$lib/components/app';
|
||||
import { setChatActionsContext } from '$lib/contexts';
|
||||
import { MessageRole } from '$lib/enums';
|
||||
@@ -140,13 +141,18 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="flex h-full flex-col space-y-10 pt-24 {className}" style="height: auto; ">
|
||||
<div
|
||||
class="flex h-full flex-col space-y-10 pt-24 {className}"
|
||||
style="height: auto; min-height: calc(100dvh - 14rem);"
|
||||
>
|
||||
{#each displayMessages as { message, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
<ChatMessage
|
||||
class="mx-auto w-full max-w-[48rem]"
|
||||
{message}
|
||||
{isLastAssistantMessage}
|
||||
{siblingInfo}
|
||||
/>
|
||||
<div use:fadeInView>
|
||||
<ChatMessage
|
||||
class="mx-auto w-full max-w-[48rem]"
|
||||
{message}
|
||||
{isLastAssistantMessage}
|
||||
{siblingInfo}
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
} from '$lib/components/app';
|
||||
import * as Alert from '$lib/components/ui/alert';
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { INITIAL_SCROLL_DELAY } from '$lib/constants';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import {
|
||||
@@ -48,7 +47,7 @@
|
||||
let showFileErrorDialog = $state(false);
|
||||
let uploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
|
||||
const autoScroll = createAutoScrollController();
|
||||
const autoScroll = createAutoScrollController({ isColumnReverse: true });
|
||||
|
||||
let fileErrorData = $state<{
|
||||
generallyUnsupported: File[];
|
||||
@@ -310,13 +309,15 @@
|
||||
|
||||
afterNavigate(() => {
|
||||
if (!disableAutoScroll) {
|
||||
setTimeout(() => autoScroll.scrollToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
autoScroll.enable();
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
autoScroll.startObserving();
|
||||
|
||||
if (!disableAutoScroll) {
|
||||
setTimeout(() => autoScroll.scrollToBottom('instant'), INITIAL_SCROLL_DELAY);
|
||||
autoScroll.enable();
|
||||
}
|
||||
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
@@ -333,10 +334,6 @@
|
||||
$effect(() => {
|
||||
autoScroll.setDisabled(disableAutoScroll);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.updateInterval(isCurrentConversationLoading);
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if isDragOver}
|
||||
@@ -351,7 +348,7 @@
|
||||
<div
|
||||
bind:this={chatScrollContainer}
|
||||
aria-label="Chat interface with file drop zone"
|
||||
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
|
||||
class="flex h-full flex-col-reverse overflow-y-auto px-4 md:px-6"
|
||||
ondragenter={handleDragEnter}
|
||||
ondragleave={handleDragLeave}
|
||||
ondragover={handleDragOver}
|
||||
@@ -359,57 +356,59 @@
|
||||
onscroll={handleScroll}
|
||||
role="main"
|
||||
>
|
||||
<ChatMessages
|
||||
class="mb-16 md:mb-24"
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
autoScroll.scrollToBottom();
|
||||
}}
|
||||
/>
|
||||
<div class="flex flex-col">
|
||||
<ChatMessages
|
||||
class="mb-16 md:mb-24"
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
autoScroll.scrollToBottom();
|
||||
}}
|
||||
/>
|
||||
|
||||
<div
|
||||
class="pointer-events-none sticky right-0 bottom-4 left-0 mt-auto"
|
||||
in:slide={{ duration: 150, axis: 'y' }}
|
||||
>
|
||||
<ChatScreenProcessingInfo />
|
||||
<div
|
||||
class="pointer-events-none sticky right-0 bottom-4 left-0 mt-auto"
|
||||
in:slide={{ duration: 150, axis: 'y' }}
|
||||
>
|
||||
<ChatScreenProcessingInfo />
|
||||
|
||||
{#if hasPropsError}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
|
||||
in:fly={{ y: 10, duration: 250 }}
|
||||
>
|
||||
<Alert.Root variant="destructive">
|
||||
<AlertTriangle class="h-4 w-4" />
|
||||
<Alert.Title class="flex items-center justify-between">
|
||||
<span>Server unavailable</span>
|
||||
<button
|
||||
onclick={() => serverStore.fetch()}
|
||||
disabled={isServerLoading}
|
||||
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
|
||||
>
|
||||
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
|
||||
{isServerLoading ? 'Retrying...' : 'Retry'}
|
||||
</button>
|
||||
</Alert.Title>
|
||||
<Alert.Description>{serverError()}</Alert.Description>
|
||||
</Alert.Root>
|
||||
{#if hasPropsError}
|
||||
<div
|
||||
class="pointer-events-auto mx-auto mb-4 max-w-[48rem] px-1"
|
||||
in:fly={{ y: 10, duration: 250 }}
|
||||
>
|
||||
<Alert.Root variant="destructive">
|
||||
<AlertTriangle class="h-4 w-4" />
|
||||
<Alert.Title class="flex items-center justify-between">
|
||||
<span>Server unavailable</span>
|
||||
<button
|
||||
onclick={() => serverStore.fetch()}
|
||||
disabled={isServerLoading}
|
||||
class="flex items-center gap-1.5 rounded-lg bg-destructive/20 px-2 py-1 text-xs font-medium hover:bg-destructive/30 disabled:opacity-50"
|
||||
>
|
||||
<RefreshCw class="h-3 w-3 {isServerLoading ? 'animate-spin' : ''}" />
|
||||
{isServerLoading ? 'Retrying...' : 'Retry'}
|
||||
</button>
|
||||
</Alert.Title>
|
||||
<Alert.Description>{serverError()}</Alert.Description>
|
||||
</Alert.Root>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
|
||||
<ChatScreenForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText={false}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
|
||||
<ChatScreenForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
showHelperText={false}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -36,6 +36,7 @@
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import type { DatabaseMessageExtra } from '$lib/types/database';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
|
||||
interface Props {
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
@@ -598,7 +599,7 @@
|
||||
: ''}"
|
||||
>
|
||||
{#each renderedBlocks as block (block.id)}
|
||||
<div class="markdown-block" data-block-id={block.id}>
|
||||
<div class="markdown-block" data-block-id={block.id} use:fadeInView={{ skipIfVisible: true }}>
|
||||
<!-- eslint-disable-next-line no-at-html-tags -->
|
||||
{@html block.html}
|
||||
</div>
|
||||
@@ -651,7 +652,6 @@
|
||||
/>
|
||||
|
||||
<style>
|
||||
.markdown-block,
|
||||
.markdown-block--unstable {
|
||||
display: contents;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
export const AUTO_SCROLL_INTERVAL = 100;
|
||||
export const INITIAL_SCROLL_DELAY = 50;
|
||||
export const AUTO_SCROLL_AT_BOTTOM_THRESHOLD = 10;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { AUTO_SCROLL_AT_BOTTOM_THRESHOLD, AUTO_SCROLL_INTERVAL } from '$lib/constants';
|
||||
|
||||
export interface AutoScrollOptions {
|
||||
/** Whether auto-scroll is disabled globally (e.g., from settings) */
|
||||
disabled?: boolean;
|
||||
isColumnReverse?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -12,6 +12,7 @@ export interface AutoScrollOptions {
|
||||
* - Auto-scrolls to bottom during streaming/loading
|
||||
* - Stops auto-scroll when user manually scrolls up
|
||||
* - Resumes auto-scroll when user scrolls back to bottom
|
||||
* - Supports both normal and column-reverse scroll containers
|
||||
*/
|
||||
export class AutoScrollController {
|
||||
private _autoScrollEnabled = $state(true);
|
||||
@@ -21,9 +22,14 @@ export class AutoScrollController {
|
||||
private _scrollTimeout: ReturnType<typeof setTimeout> | undefined;
|
||||
private _container: HTMLElement | undefined;
|
||||
private _disabled: boolean;
|
||||
private _isColumnReverse: boolean;
|
||||
private _mutationObserver: MutationObserver | null = null;
|
||||
private _rafPending = false;
|
||||
private _observerEnabled = false;
|
||||
|
||||
constructor(options: AutoScrollOptions = {}) {
|
||||
this._disabled = options.disabled ?? false;
|
||||
this._isColumnReverse = options.isColumnReverse ?? false;
|
||||
}
|
||||
|
||||
get autoScrollEnabled(): boolean {
|
||||
@@ -38,7 +44,12 @@ export class AutoScrollController {
|
||||
* Binds the controller to a scrollable container element.
|
||||
*/
|
||||
setContainer(container: HTMLElement | undefined): void {
|
||||
this._doStopObserving();
|
||||
this._container = container;
|
||||
|
||||
if (this._observerEnabled && container && !this._disabled) {
|
||||
this._doStartObserving();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -49,6 +60,9 @@ export class AutoScrollController {
|
||||
if (disabled) {
|
||||
this._autoScrollEnabled = false;
|
||||
this.stopInterval();
|
||||
this._doStopObserving();
|
||||
} else if (this._observerEnabled && this._container && !this._mutationObserver) {
|
||||
this._doStartObserving();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,10 +73,23 @@ export class AutoScrollController {
|
||||
if (this._disabled || !this._container) return;
|
||||
|
||||
const { scrollTop, scrollHeight, clientHeight } = this._container;
|
||||
const distanceFromBottom = scrollHeight - scrollTop - clientHeight;
|
||||
|
||||
let distanceFromBottom: number;
|
||||
let isScrollingUp: boolean;
|
||||
|
||||
if (this._isColumnReverse) {
|
||||
// column-reverse: scrollTop=0 at bottom, negative when scrolled up
|
||||
distanceFromBottom = Math.abs(scrollTop);
|
||||
isScrollingUp = scrollTop < this._lastScrollTop;
|
||||
} else {
|
||||
// normal: scrollTop=0 at top, increases when scrolled down
|
||||
distanceFromBottom = scrollHeight - clientHeight - scrollTop;
|
||||
isScrollingUp = scrollTop < this._lastScrollTop;
|
||||
}
|
||||
|
||||
const isAtBottom = distanceFromBottom < AUTO_SCROLL_AT_BOTTOM_THRESHOLD;
|
||||
|
||||
if (scrollTop < this._lastScrollTop && !isAtBottom) {
|
||||
if (isScrollingUp && !isAtBottom) {
|
||||
this._userScrolledUp = true;
|
||||
this._autoScrollEnabled = false;
|
||||
} else if (isAtBottom && this._userScrolledUp) {
|
||||
@@ -90,10 +117,12 @@ export class AutoScrollController {
|
||||
scrollToBottom(behavior: ScrollBehavior = 'smooth'): void {
|
||||
if (this._disabled || !this._container) return;
|
||||
|
||||
this._container.scrollTo({
|
||||
top: this._container.scrollHeight,
|
||||
behavior
|
||||
});
|
||||
if (this._isColumnReverse) {
|
||||
// column-reverse: scrollTop=0 is the bottom
|
||||
this._container.scrollTo({ top: 0, behavior });
|
||||
} else {
|
||||
this._container.scrollTo({ top: this._container.scrollHeight, behavior });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -150,11 +179,69 @@ export class AutoScrollController {
|
||||
*/
|
||||
destroy(): void {
|
||||
this.stopInterval();
|
||||
this._doStopObserving();
|
||||
|
||||
if (this._scrollTimeout) {
|
||||
clearTimeout(this._scrollTimeout);
|
||||
this._scrollTimeout = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a MutationObserver on the container that auto-scrolls to bottom
|
||||
* on content changes. More responsive than interval-based polling.
|
||||
*/
|
||||
startObserving(): void {
|
||||
this._observerEnabled = true;
|
||||
|
||||
if (this._container && !this._disabled && !this._mutationObserver) {
|
||||
this._doStartObserving();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops the MutationObserver.
|
||||
*/
|
||||
stopObserving(): void {
|
||||
this._observerEnabled = false;
|
||||
this._doStopObserving();
|
||||
}
|
||||
|
||||
private _doStartObserving(): void {
|
||||
if (!this._container || this._mutationObserver) return;
|
||||
|
||||
const isReverse = this._isColumnReverse;
|
||||
|
||||
this._mutationObserver = new MutationObserver(() => {
|
||||
if (!this._autoScrollEnabled || this._rafPending) return;
|
||||
this._rafPending = true;
|
||||
requestAnimationFrame(() => {
|
||||
this._rafPending = false;
|
||||
if (this._autoScrollEnabled && this._container) {
|
||||
if (isReverse) {
|
||||
// column-reverse: scrollTop=0 is the bottom
|
||||
this._container.scrollTop = 0;
|
||||
} else {
|
||||
this._container.scrollTop = this._container.scrollHeight;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
this._mutationObserver.observe(this._container, {
|
||||
childList: true,
|
||||
subtree: true,
|
||||
characterData: true
|
||||
});
|
||||
}
|
||||
|
||||
private _doStopObserving(): void {
|
||||
if (this._mutationObserver) {
|
||||
this._mutationObserver.disconnect();
|
||||
this._mutationObserver = null;
|
||||
}
|
||||
this._rafPending = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user