mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-19 15:34:08 +00:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a7a98e0fff | ||
|
|
8f8f2274ee | ||
|
|
c959b676be | ||
|
|
cd08fc3ecc | ||
|
|
cb5bb6cc05 | ||
|
|
a91d035b90 | ||
|
|
745cbcf2fe | ||
|
|
1cbd80f8cf | ||
|
|
85286f3548 | ||
|
|
d5fabe3682 | ||
|
|
8ff206097c | ||
|
|
77475530b8 | ||
|
|
3913f8730e | ||
|
|
76888d202e | ||
|
|
f1fbffb5c0 | ||
|
|
51abc96bdc | ||
|
|
07808ebb07 | ||
|
|
6d758839ff | ||
|
|
3d4053f77f | ||
|
|
dc381aa9a6 | ||
|
|
10d197409b | ||
|
|
b907255f4b | ||
|
|
28c39da7c6 | ||
|
|
106220562a | ||
|
|
a68f31edd7 |
@@ -22,6 +22,13 @@ AllowShortIfStatementsOnASingleLine: Never
|
||||
AllowShortLambdasOnASingleLine: Inline
|
||||
AllowShortLoopsOnASingleLine: false
|
||||
AlwaysBreakBeforeMultilineStrings: true
|
||||
# Treat CUDA keywords/attributes as "attribute macros" and avoid breaking lines inside them
|
||||
AttributeMacros:
|
||||
- __host__
|
||||
- __device__
|
||||
- __global__
|
||||
- __forceinline__
|
||||
- __launch_bounds__
|
||||
BinPackArguments: true
|
||||
BinPackParameters: false # OnePerLine
|
||||
BitFieldColonSpacing: Both
|
||||
|
||||
@@ -17,14 +17,11 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
|
||||
# gfx906 is deprecated
|
||||
#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
|
||||
|
||||
ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201'
|
||||
#ARG ROCM_DOCKER_ARCH=gfx1100
|
||||
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
|
||||
#ARG ROCM_DOCKER_ARCH='gfx1151'
|
||||
|
||||
# Set ROCm architectured
|
||||
# Set ROCm architectures
|
||||
ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
|
||||
# Enable ROCm
|
||||
# ENV CC=/opt/rocm/llvm/bin/clang
|
||||
# ENV CXX=/opt/rocm/llvm/bin/clang++
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y \
|
||||
@@ -39,8 +36,16 @@ WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN git clone https://github.com/rocm/rocwmma --branch develop --depth 1
|
||||
|
||||
RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
|
||||
cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
|
||||
cmake -S . -B build \
|
||||
-DGGML_HIP=ON \
|
||||
-DGGML_HIP_ROCWMMA_FATTN=ON \
|
||||
-DCMAKE_HIP_FLAGS="-I$(pwd)/rocwmma/library/include/" \
|
||||
-DAMDGPU_TARGETS="$ROCM_DOCKER_ARCH" \
|
||||
-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
|
||||
&& cmake --build build --config Release -j$(nproc)
|
||||
|
||||
RUN mkdir -p /app/lib \
|
||||
|
||||
@@ -52,3 +52,11 @@ insert_final_newline = unset
|
||||
[vendor/miniaudio/miniaudio.h]
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
[tools/server/webui/**]
|
||||
indent_style = unset
|
||||
indent_size = unset
|
||||
end_of_line = unset
|
||||
charset = unset
|
||||
trim_trailing_whitespace = unset
|
||||
insert_final_newline = unset
|
||||
|
||||
27
.github/workflows/build.yml
vendored
27
.github/workflows/build.yml
vendored
@@ -56,7 +56,7 @@ env:
|
||||
|
||||
jobs:
|
||||
macOS-latest-cmake-arm64:
|
||||
runs-on: macos-14
|
||||
runs-on: macos-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -138,7 +138,7 @@ jobs:
|
||||
ctest -L main --verbose --timeout 900
|
||||
|
||||
macOS-latest-cmake-arm64-webgpu:
|
||||
runs-on: macos-14
|
||||
runs-on: macos-latest
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -711,6 +711,7 @@ jobs:
|
||||
|
||||
macOS-latest-swift:
|
||||
runs-on: macos-latest
|
||||
needs: ios-xcode-build
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -727,6 +728,12 @@ jobs:
|
||||
key: macOS-latest-swift
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Download xcframework artifact
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
continue-on-error: true
|
||||
@@ -748,11 +755,6 @@ jobs:
|
||||
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64"
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: xcodebuild for swift package
|
||||
id: xcodebuild
|
||||
run: |
|
||||
./build-xcframework.sh
|
||||
|
||||
windows-msys2:
|
||||
runs-on: windows-2025
|
||||
|
||||
@@ -1170,8 +1172,17 @@ jobs:
|
||||
run: |
|
||||
./build-xcframework.sh
|
||||
|
||||
- name: Upload xcframework artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
retention-days: 1
|
||||
|
||||
- name: Build Xcode project
|
||||
run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build
|
||||
run: |
|
||||
xcodebuild -downloadPlatform iOS
|
||||
xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build
|
||||
|
||||
android-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
6
.github/workflows/release.yml
vendored
6
.github/workflows/release.yml
vendored
@@ -530,15 +530,13 @@ jobs:
|
||||
runs-on: windows-2022
|
||||
|
||||
env:
|
||||
# The ROCm version must correspond to the version used in the HIP SDK.
|
||||
ROCM_VERSION: "6.4.2"
|
||||
HIPSDK_INSTALLER_VERSION: "25.Q3"
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- name: "radeon"
|
||||
gpu_targets: "gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
|
||||
gpu_targets: "gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -548,7 +546,7 @@ jobs:
|
||||
- name: Clone rocWMMA repository
|
||||
id: clone_rocwmma
|
||||
run: |
|
||||
git clone https://github.com/rocm/rocwmma --branch rocm-${{ env.ROCM_VERSION }} --depth 1
|
||||
git clone https://github.com/rocm/rocwmma --branch develop --depth 1
|
||||
|
||||
- name: Cache ROCm Installation
|
||||
id: cache-rocm
|
||||
|
||||
229
.github/workflows/server.yml
vendored
229
.github/workflows/server.yml
vendored
@@ -76,51 +76,206 @@ jobs:
|
||||
run: |
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
# Setup nodejs (to be used for verifying bundled index.html)
|
||||
- uses: actions/setup-node@v4
|
||||
webui-setup:
|
||||
name: WebUI Setup
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
node-version: '22.11.0'
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: WebUI - Install dependencies
|
||||
id: webui_lint
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
cache-dependency-path: "tools/server/webui/package-lock.json"
|
||||
|
||||
- name: Cache node_modules
|
||||
uses: actions/cache@v4
|
||||
id: cache-node-modules
|
||||
with:
|
||||
path: tools/server/webui/node_modules
|
||||
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-modules-
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.cache-node-modules.outputs.cache-hit != 'true'
|
||||
run: npm ci
|
||||
working-directory: tools/server/webui
|
||||
|
||||
webui-check:
|
||||
needs: webui-setup
|
||||
name: WebUI Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Restore node_modules cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: tools/server/webui/node_modules
|
||||
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-modules-
|
||||
|
||||
- name: Run type checking
|
||||
run: npm run check
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run linting
|
||||
run: npm run lint
|
||||
working-directory: tools/server/webui
|
||||
|
||||
webui-build:
|
||||
needs: webui-check
|
||||
name: WebUI Build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Restore node_modules cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: tools/server/webui/node_modules
|
||||
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-modules-
|
||||
|
||||
- name: Build application
|
||||
run: npm run build
|
||||
working-directory: tools/server/webui
|
||||
|
||||
webui-tests:
|
||||
needs: webui-build
|
||||
name: Run WebUI tests
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
|
||||
- name: Restore node_modules cache
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: tools/server/webui/node_modules
|
||||
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-node-modules-
|
||||
|
||||
- name: Install Playwright browsers
|
||||
run: npx playwright install --with-deps
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Build Storybook
|
||||
run: npm run build-storybook
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run Client tests
|
||||
run: npm run test:client
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run Server tests
|
||||
run: npm run test:server
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run UI tests
|
||||
run: npm run test:ui
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Run E2E tests
|
||||
run: npm run test:e2e
|
||||
working-directory: tools/server/webui
|
||||
|
||||
server-build:
|
||||
needs: [webui-tests]
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
|
||||
build_type: [RelWithDebInfo]
|
||||
include:
|
||||
- build_type: Release
|
||||
sanitizer: ""
|
||||
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
|
||||
|
||||
steps:
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
cd tools/server/webui
|
||||
npm ci
|
||||
sudo apt-get update
|
||||
sudo apt-get -y install \
|
||||
build-essential \
|
||||
xxd \
|
||||
git \
|
||||
cmake \
|
||||
curl \
|
||||
wget \
|
||||
language-pack-en \
|
||||
libcurl4-openssl-dev
|
||||
|
||||
- name: WebUI - Check code format
|
||||
id: webui_format
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Tests dependencies
|
||||
id: test_dependencies
|
||||
run: |
|
||||
git config --global --add safe.directory $(realpath .)
|
||||
cd tools/server/webui
|
||||
git status
|
||||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
npm run format
|
||||
git status
|
||||
modified_files="$(git status -s)"
|
||||
echo "Modified files: ${modified_files}"
|
||||
if [ -n "${modified_files}" ]; then
|
||||
echo "Files do not follow coding style. To fix: npm run format"
|
||||
echo "${modified_files}"
|
||||
exit 1
|
||||
fi
|
||||
- name: Setup Node.js for WebUI
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
cache-dependency-path: "tools/server/webui/package-lock.json"
|
||||
|
||||
- name: Verify bundled index.html
|
||||
id: verify_server_index_html
|
||||
run: |
|
||||
git config --global --add safe.directory $(realpath .)
|
||||
cd tools/server/webui
|
||||
git status
|
||||
- name: Install WebUI dependencies
|
||||
run: npm ci
|
||||
working-directory: tools/server/webui
|
||||
|
||||
npm run build
|
||||
git status
|
||||
modified_files="$(git status -s)"
|
||||
echo "Modified files: ${modified_files}"
|
||||
if [ -n "${modified_files}" ]; then
|
||||
echo "Repository is dirty or server/webui is not built as expected"
|
||||
echo "Hint: You may need to follow Web UI build guide in server/README.md"
|
||||
echo "${modified_files}"
|
||||
exit 1
|
||||
fi
|
||||
- name: Build WebUI
|
||||
run: npm run build
|
||||
working-directory: tools/server/webui
|
||||
|
||||
- name: Build (no OpenMP)
|
||||
id: cmake_build_no_openmp
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -148,3 +148,7 @@ poetry.toml
|
||||
/run-vim.sh
|
||||
/run-chat.sh
|
||||
.ccache/
|
||||
|
||||
# Code Workspace
|
||||
*.code-workspace
|
||||
|
||||
|
||||
7
.windsurf/rules/css-architecture.md
Normal file
7
.windsurf/rules/css-architecture.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
trigger: manual
|
||||
---
|
||||
|
||||
#### Tailwind & CSS
|
||||
|
||||
- We are using Tailwind v4 which uses oklch colors so we now want to refer to the CSS vars directly, without wrapping it with any color function like `hsla/hsl`, `rgba` etc.
|
||||
48
.windsurf/rules/sveltekit-architecture.md
Normal file
48
.windsurf/rules/sveltekit-architecture.md
Normal file
@@ -0,0 +1,48 @@
|
||||
---
|
||||
trigger: manual
|
||||
---
|
||||
|
||||
# Coding rules
|
||||
|
||||
## Svelte & SvelteKit
|
||||
|
||||
### Services vs Stores Separation Pattern
|
||||
|
||||
#### `lib/services/` - Pure Business Logic
|
||||
|
||||
- **Purpose**: Stateless business logic and external communication
|
||||
- **Contains**:
|
||||
- API calls to external services (ApiService)
|
||||
- Pure business logic functions (ChatService, etc.)
|
||||
- **Rules**:
|
||||
- NO Svelte runes ($state, $derived, $effect)
|
||||
- NO reactive state management
|
||||
- Pure functions and classes only
|
||||
- Can import types but not stores
|
||||
- Focus on "how" - implementation details
|
||||
|
||||
#### `lib/stores/` - Reactive State Management
|
||||
|
||||
- **Purpose**: Svelte-specific reactive state with runes
|
||||
- **Contains**:
|
||||
- Reactive state classes with $state, $derived, $effect
|
||||
- Database operations (DatabaseStore)
|
||||
- UI-focused state management
|
||||
- Store orchestration logic
|
||||
- **Rules**:
|
||||
- USE Svelte runes for reactivity
|
||||
- Import and use services for business logic
|
||||
- NO direct database operations
|
||||
- NO direct API calls (use services)
|
||||
- Focus on "what" - reactive state for UI
|
||||
|
||||
#### Enforcement
|
||||
|
||||
- Services should be testable without Svelte
|
||||
- Stores should leverage Svelte's reactivity system
|
||||
- Clear separation: services handle data, stores handle state
|
||||
- Services can be reused across multiple stores
|
||||
|
||||
#### Misc
|
||||
|
||||
- Always use `let` for $derived state variables
|
||||
9
.windsurf/rules/tests.md
Normal file
9
.windsurf/rules/tests.md
Normal file
@@ -0,0 +1,9 @@
|
||||
---
|
||||
trigger: manual
|
||||
---
|
||||
|
||||
# Automated Tests
|
||||
|
||||
## General rules
|
||||
|
||||
- NEVER include any test code in the production code - we should always have it in a separate dedicated files
|
||||
7
.windsurf/rules/typescript-architecture.md
Normal file
7
.windsurf/rules/typescript-architecture.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
trigger: manual
|
||||
---
|
||||
|
||||
## TypeScript
|
||||
|
||||
- Add JSDocs for functions
|
||||
@@ -58,6 +58,12 @@ if (MSVC)
|
||||
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
|
||||
endif()
|
||||
|
||||
if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
|
||||
set(LLAMA_TOOLS_INSTALL_DEFAULT OFF)
|
||||
else()
|
||||
set(LLAMA_TOOLS_INSTALL_DEFAULT ${LLAMA_STANDALONE})
|
||||
endif()
|
||||
|
||||
#
|
||||
# option list
|
||||
#
|
||||
@@ -82,6 +88,7 @@ option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
|
||||
|
||||
@@ -1704,7 +1704,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.system_prompt = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN}));
|
||||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
|
||||
add_opt(common_arg(
|
||||
{"--no-perf"},
|
||||
string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
|
||||
@@ -2548,7 +2548,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--cpu-moe", "-cmoe"},
|
||||
"keep all Mixture of Experts (MoE) weights in the CPU",
|
||||
[](common_params & params) {
|
||||
params.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
|
||||
params.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
|
||||
}
|
||||
).set_env("LLAMA_ARG_CPU_MOE"));
|
||||
add_opt(common_arg(
|
||||
@@ -2561,7 +2561,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
for (int i = 0; i < value; ++i) {
|
||||
// keep strings alive and avoid leaking memory by storing them in a static vector
|
||||
static std::list<std::string> buft_overrides;
|
||||
buft_overrides.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
|
||||
buft_overrides.push_back(llm_ffn_exps_block_regex(i));
|
||||
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()});
|
||||
}
|
||||
}
|
||||
@@ -2570,7 +2570,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
{"--cpu-moe-draft", "-cmoed"},
|
||||
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
|
||||
[](common_params & params) {
|
||||
params.speculative.tensor_buft_overrides.push_back({"\\.ffn_(up|down|gate)_exps", ggml_backend_cpu_buffer_type()});
|
||||
params.speculative.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
@@ -2582,7 +2582,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
}
|
||||
for (int i = 0; i < value; ++i) {
|
||||
static std::list<std::string> buft_overrides_draft;
|
||||
buft_overrides_draft.push_back(string_format("blk\\.%d\\.ffn_(up|down|gate)_exps", i));
|
||||
buft_overrides_draft.push_back(llm_ffn_exps_block_regex(i));
|
||||
params.speculative.tensor_buft_overrides.push_back({buft_overrides_draft.back().c_str(), ggml_backend_cpu_buffer_type()});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -734,6 +734,20 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
|
||||
}
|
||||
|
||||
//
|
||||
// MoE utils
|
||||
//
|
||||
|
||||
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_exps";
|
||||
|
||||
static std::string llm_ffn_exps_block_regex(int idx) {
|
||||
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
|
||||
}
|
||||
|
||||
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
|
||||
}
|
||||
|
||||
//
|
||||
// training utils
|
||||
//
|
||||
|
||||
@@ -257,12 +257,13 @@ std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
||||
};
|
||||
|
||||
static bool is_reserved_name(const std::string & name) {
|
||||
static std::unordered_set<std::string> RESERVED_NAMES;
|
||||
if (RESERVED_NAMES.empty()) {
|
||||
RESERVED_NAMES.insert("root");
|
||||
for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
|
||||
for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first);
|
||||
}
|
||||
static const std::unordered_set<std::string> RESERVED_NAMES = [] {
|
||||
std::unordered_set<std::string> s;
|
||||
s.insert("root");
|
||||
for (const auto & p : PRIMITIVE_RULES) s.insert(p.first);
|
||||
for (const auto & p : STRING_FORMAT_RULES) s.insert(p.first);
|
||||
return s;
|
||||
}();
|
||||
return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
|
||||
}
|
||||
|
||||
|
||||
@@ -888,6 +888,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
|
||||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
|
||||
res = "llada-moe"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@@ -2390,7 +2393,10 @@ class SmolVLMModel(MmprojModel):
|
||||
return [] # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("Llama4ForConditionalGeneration")
|
||||
@ModelBase.register(
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Llama4ForCausalLM",
|
||||
)
|
||||
class Llama4Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA4
|
||||
undo_permute = False
|
||||
@@ -2408,6 +2414,10 @@ class Llama4Model(LlamaModel):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
|
||||
if "layer_types" in self.hparams:
|
||||
if all(lt == "full_attention" for lt in self.hparams["layer_types"]):
|
||||
# all layers are full attention (for MobileLLM), disable swa
|
||||
self.gguf_writer.add_sliding_window(0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
if name.startswith("language_model."):
|
||||
@@ -6006,9 +6016,34 @@ class SeedOssModel(TextModel):
|
||||
|
||||
|
||||
@ModelBase.register("Olmo2ForCausalLM")
|
||||
@ModelBase.register("Olmo3ForCausalLM")
|
||||
class Olmo2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.OLMO2
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||
self.gguf_writer.add_rope_scaling_attn_factors(rope_scaling["attention_factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
if "sliding_window" in self.hparams:
|
||||
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
|
||||
|
||||
sliding_window_pattern = []
|
||||
if "layer_types" in self.hparams:
|
||||
sliding_window_pattern = [t == "sliding_attention" for t in self.hparams["layer_types"]]
|
||||
else:
|
||||
# Olmo2 does not use sliding window attention.
|
||||
# Olmo3 defaults to using sliding window for all layers except every 4th.
|
||||
for i in range(self.hparams["num_hidden_layers"]):
|
||||
sliding_window_pattern.append((i + 1) % 4 != 0)
|
||||
|
||||
self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
|
||||
|
||||
|
||||
@ModelBase.register("OlmoeForCausalLM")
|
||||
class OlmoeModel(TextModel):
|
||||
@@ -8239,6 +8274,76 @@ class HunYuanMoEModel(TextModel):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("LLaDAMoEModel", "LLaDAMoEModelLM")
|
||||
class LLaDAMoEModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLADA_MOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
|
||||
if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
|
||||
|
||||
# number of experts used per token (top-k)
|
||||
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
|
||||
self.gguf_writer.add_mask_token_id(156895)
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self.gguf_writer.add_diffusion_shift_logits(False)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
# Copied from: Qwen2MoeModel
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
# Copied from: Qwen2MoeModel
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("HunYuanDenseV1ForCausalLM")
|
||||
class HunYuanModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
|
||||
|
||||
@@ -139,6 +139,7 @@ models = [
|
||||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
|
||||
@@ -510,19 +510,27 @@ static void diffusion_generate(llama_context * ctx,
|
||||
n_generated = params.max_length;
|
||||
}
|
||||
|
||||
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
|
||||
static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
|
||||
if (!use_chat_template) {
|
||||
return prompt;
|
||||
}
|
||||
|
||||
auto chat_templates = common_chat_templates_init(model, "");
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
common_chat_msg user_msg;
|
||||
user_msg.role = "user";
|
||||
user_msg.content = prompt;
|
||||
inputs.add_generation_prompt = true;
|
||||
common_chat_msg system_msg;
|
||||
|
||||
if (!system_prompt.empty()) {
|
||||
system_msg.role = "system";
|
||||
system_msg.content = system_prompt;
|
||||
inputs.messages.push_back(system_msg);
|
||||
}
|
||||
|
||||
common_chat_msg user_msg;
|
||||
user_msg.role = "user";
|
||||
user_msg.content = prompt;
|
||||
|
||||
inputs.messages.push_back(user_msg);
|
||||
inputs.add_generation_prompt = true;
|
||||
|
||||
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
|
||||
|
||||
@@ -579,7 +587,8 @@ int main(int argc, char ** argv) {
|
||||
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
|
||||
|
||||
std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
|
||||
|
||||
std::vector<llama_token> input_tokens = common_tokenize(vocab,
|
||||
formatted_prompt,
|
||||
@@ -596,6 +605,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
llama_token mask_token_id = llama_vocab_mask(vocab);
|
||||
|
||||
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
|
||||
|
||||
bool visual_mode = params.diffusion.visual_mode;
|
||||
|
||||
@@ -145,6 +145,20 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
if (llama_encode(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
|
||||
decoder_start_token_id = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
batch = llama_batch_get_one(&decoder_start_token_id, 1);
|
||||
}
|
||||
|
||||
// main loop
|
||||
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
@@ -526,7 +526,10 @@ struct ggml_backend_cann_context {
|
||||
*/
|
||||
aclrtStream stream(int stream) {
|
||||
if (streams[stream] == nullptr) {
|
||||
ggml_cann_set_device(device);
|
||||
// If the device is not set here, destroying the stream later may cause a mismatch
|
||||
// between the thread contexts where the stream was created and destroyed.
|
||||
// However, I printed the device_id, thread_id, and stream, and they are all consistent.
|
||||
ACL_CHECK(aclrtSetDevice(device));
|
||||
ACL_CHECK(aclrtCreateStream(&streams[stream]));
|
||||
}
|
||||
return streams[stream];
|
||||
|
||||
@@ -75,13 +75,12 @@
|
||||
* @param device The device ID to set.
|
||||
*/
|
||||
void ggml_cann_set_device(const int32_t device) {
|
||||
// TODO: uncomment these lines after empty context has fixed.
|
||||
// int current_device;
|
||||
// ACL_CHECK(aclrtGetDevice(¤t_device));
|
||||
int current_device = -1;
|
||||
aclrtGetDevice(¤t_device);
|
||||
|
||||
// if (device == current_device) {
|
||||
// return;
|
||||
// }
|
||||
if (device == current_device) {
|
||||
return;
|
||||
}
|
||||
ACL_CHECK(aclrtSetDevice(device));
|
||||
}
|
||||
|
||||
@@ -1729,6 +1728,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
||||
ggml_cann_get_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
std::cout << "lcg GGML_OP_SET_ROWS"<< std::endl;
|
||||
ggml_cann_set_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_DUP:
|
||||
|
||||
@@ -8599,7 +8599,6 @@ static void ggml_compute_forward_timestep_embedding_f32(
|
||||
}
|
||||
if (dim % 2 != 0 && ith == 0) {
|
||||
embed_data[2 * half] = 0.f;
|
||||
embed_data[dim] = 0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +75,8 @@
|
||||
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
|
||||
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
|
||||
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
|
||||
#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
|
||||
#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
|
||||
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
|
||||
|
||||
// Moore Threads
|
||||
@@ -325,6 +327,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
||||
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
|
||||
}
|
||||
|
||||
// Maximum number of bytes that can be copied in a single instruction.
|
||||
static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
|
||||
#ifdef GGML_USE_HIP
|
||||
return 16;
|
||||
#else
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
return 16;
|
||||
#else
|
||||
return 8;
|
||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
#endif // GGML_USE_HIP
|
||||
}
|
||||
|
||||
|
||||
[[noreturn]]
|
||||
static __device__ void no_device_code(
|
||||
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
|
||||
|
||||
@@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||
}
|
||||
|
||||
template<int D> // D == head size
|
||||
#if !defined(GGML_USE_HIP)
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP)
|
||||
static __global__ void flash_attn_combine_results(
|
||||
const float * __restrict__ VKQ_parts,
|
||||
const float2 * __restrict__ VKQ_meta,
|
||||
@@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results(
|
||||
float VKQ_numerator = 0.0f;
|
||||
float VKQ_denominator = 0.0f;
|
||||
for (int l = 0; l < parallel_blocks; ++l) {
|
||||
const float diff = meta[l].x - kqmax;
|
||||
float KQ_max_scale = expf(diff);
|
||||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
const float KQ_max_scale = expf(meta[l].x - kqmax);
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
@@ -836,11 +831,10 @@ void launch_fattn(
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
int parallel_blocks = 1;
|
||||
|
||||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
int parallel_blocks = max_blocks_per_sm;
|
||||
|
||||
dim3 blocks_num;
|
||||
if (stream_k) {
|
||||
@@ -862,9 +856,6 @@ void launch_fattn(
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
|
||||
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
|
||||
|
||||
@@ -2,20 +2,30 @@
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile.cuh"
|
||||
|
||||
#define FATTN_TILE_NTHREADS 256
|
||||
// kq_stride == number of KQ rows to process per iteration
|
||||
// kq_nbatch == number of K columns to load in parallel for KQ calculation
|
||||
|
||||
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
if (GGML_CUDA_CC_IS_RDNA(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
} else {
|
||||
return 64;
|
||||
}
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
@@ -49,24 +59,28 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
#ifdef RDNA
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
return 128;
|
||||
case 128:
|
||||
#if defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
#else
|
||||
return 64;
|
||||
#endif // defined(GCN) || defined(CDNA)
|
||||
case 256:
|
||||
#if defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
#else
|
||||
return 64;
|
||||
#endif // defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // RDNA
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
@@ -100,17 +114,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
#if defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
#else
|
||||
return 64;
|
||||
#endif // defined(GCN) || defined(CDNA)
|
||||
case 256:
|
||||
#if defined(GCN) || defined(CDNA)
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
#else
|
||||
return ncols <= 16 ? 64 : 256;
|
||||
#endif // defined(GCN) || defined(CDNA)
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
@@ -120,9 +125,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 256:
|
||||
return ncols <= 16 ? 64 : 128;
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
@@ -142,12 +146,27 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
||||
#ifdef GGML_USE_HIP
|
||||
__launch_bounds__(FATTN_TILE_NTHREADS, 1)
|
||||
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED_VARS(cc, ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
|
||||
#ifdef RDNA
|
||||
return 3;
|
||||
#else
|
||||
__launch_bounds__(FATTN_TILE_NTHREADS, 2)
|
||||
#endif // GGML_USE_HIP
|
||||
return ncols <= 16 ? 3 : 2;
|
||||
#endif // RDNA
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
||||
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
|
||||
static __global__ void flash_attn_tile(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
@@ -193,7 +212,7 @@ static __global__ void flash_attn_tile(
|
||||
}
|
||||
|
||||
constexpr int warp_size = 32;
|
||||
constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size;
|
||||
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
|
||||
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
|
||||
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
|
||||
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
|
||||
@@ -206,90 +225,126 @@ static __global__ void flash_attn_tile(
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
constexpr int cpy_nb = 16;
|
||||
#else
|
||||
constexpr int cpy_nb = 8;
|
||||
#endif // defined(GGML_USE_HIP) && defined(GCN)
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
__shared__ float KQ[ncols][kq_stride];
|
||||
constexpr int cpw = ncols/nwarps; // cols per warp
|
||||
|
||||
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
|
||||
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
|
||||
|
||||
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ half2 Q_tmp[ncols][D/2];
|
||||
__shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#else
|
||||
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
|
||||
|
||||
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ float Q_tmp[ncols][D];
|
||||
__shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
|
||||
float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
|
||||
|
||||
|
||||
float kqmax[ncols/nwarps];
|
||||
float KQ_max[cpw];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
}
|
||||
float kqsum[ncols/nwarps] = {0.0f};
|
||||
float KQ_sum[cpw] = {0.0f};
|
||||
|
||||
// Load Q data, convert to FP16 if fast.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
const int j = j0 + threadIdx.y*cpw;
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f);
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp_f[i1] *= scale;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale);
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
||||
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
|
||||
#else
|
||||
Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale;
|
||||
Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale;
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float kqmax_new[ncols/nwarps];
|
||||
float KQ_max_new[cpw];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/nwarps; ++j) {
|
||||
kqmax_new[j] = kqmax[j];
|
||||
for (int j = 0; j < cpw; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}};
|
||||
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
|
||||
|
||||
// KQ = K @ Q matrix multiplication:
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
|
||||
const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
|
||||
#else
|
||||
const float2 tmp_f2 = __half22float2(tmp_h2);
|
||||
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
|
||||
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
|
||||
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
half2 tmp_h2[cpy_ne_kqnb/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_kqnb/2];
|
||||
#pragma unroll
|
||||
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
|
||||
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -298,12 +353,12 @@ static __global__ void flash_attn_tile(
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
|
||||
half2 K_k[kq_stride/warp_size][cpy_ne];
|
||||
half2 Q_k[ncols/nwarps][cpy_ne];
|
||||
half2 Q_k[cpw][cpy_ne];
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
|
||||
float K_k[kq_stride/warp_size][cpy_ne];
|
||||
float Q_k[ncols/nwarps][cpy_ne];
|
||||
float Q_k[cpw][cpy_ne];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
@@ -311,29 +366,29 @@ static __global__ void flash_attn_tile(
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < cpy_ne; ++k) {
|
||||
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
|
||||
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -344,104 +399,77 @@ static __global__ void flash_attn_tile(
|
||||
}
|
||||
}
|
||||
|
||||
// Apply logit softcap, mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
|
||||
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
|
||||
|
||||
KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps];
|
||||
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
kqmax_new[j0/nwarps] = warp_reduce_max<warp_size>(kqmax_new[j0/nwarps]);
|
||||
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
|
||||
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
|
||||
|
||||
float kqsum_add = 0.0f;
|
||||
if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
|
||||
const int i = i0 + 4*threadIdx.x;
|
||||
|
||||
float4 val = *(const float4 *) &KQ[j][i];
|
||||
val.x = expf(val.x - kqmax[j0/nwarps]);
|
||||
val.y = expf(val.y - kqmax[j0/nwarps]);
|
||||
val.z = expf(val.z - kqmax[j0/nwarps]);
|
||||
val.w = expf(val.w - kqmax[j0/nwarps]);
|
||||
kqsum_add += val.x + val.y + val.z + val.w;
|
||||
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
|
||||
half tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#else
|
||||
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
|
||||
float tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
} else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
|
||||
const int i = i0 + 2*threadIdx.x;
|
||||
|
||||
float2 val = *(const float2 *) &KQ[j][i];
|
||||
val.x = expf(val.x - kqmax[j0/nwarps]);
|
||||
val.y = expf(val.y - kqmax[j0/nwarps]);
|
||||
kqsum_add += val.x + val.y;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 tmp = make_half2(val.x, val.y);
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
|
||||
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
|
||||
KQ_max[j0+j1] = KQ_max_new[j0+j1];
|
||||
|
||||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const float diff = KQ[j][i] - kqmax[j0/nwarps];
|
||||
const float val = expf(diff);
|
||||
kqsum_add += val;
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
((half *) KQ[j])[i] = val;
|
||||
#else
|
||||
KQ[j][i] = val;
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/warp_size][j1] = val;
|
||||
}
|
||||
}
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
|
||||
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
|
||||
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D;
|
||||
// VKQ = V @ KQ matrix multiplication:
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
|
||||
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
|
||||
@@ -449,65 +477,96 @@ static __global__ void flash_attn_tile(
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
|
||||
const int k_tile = k1 + threadIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i];
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
KV_tmp_h2[k_tile*(D/2) + i] = tmp;
|
||||
#else
|
||||
KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
|
||||
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
|
||||
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 V_k[(D/2)/warp_size];
|
||||
half2 KQ_k[ncols/nwarps];
|
||||
#else
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[ncols/nwarps];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
half2 KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i];
|
||||
#else
|
||||
V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
|
||||
#else
|
||||
KQ_k[j0/nwarps] = KQ[j][k0 + k1];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
half tmp[softmax_iter_j];
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
|
||||
&tmp, KQ[j][k0 + k1]);
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_k[j0+j1] = __half2half2(tmp[j1]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps];
|
||||
#else
|
||||
VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps];
|
||||
VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
|
||||
&KQ_k[j0], KQ[j][k0 + k1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
|
||||
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -519,69 +578,92 @@ static __global__ void flash_attn_tile(
|
||||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
|
||||
kqmax_new_j = warp_reduce_max<warp_size>(kqmax_new_j);
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
|
||||
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
|
||||
kqmax[j0/nwarps] = kqmax_new_j;
|
||||
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
|
||||
KQ_max[j0] = KQ_max_new_j;
|
||||
|
||||
const float val = expf(sink - kqmax[j0/nwarps]);
|
||||
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
|
||||
const float val = expf(sink - KQ_max[j0]);
|
||||
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum[j0/nwarps] += val;
|
||||
KQ_sum[j0] += val;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
|
||||
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
if (gridDim.y == 1) {
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
|
||||
}
|
||||
#else
|
||||
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
|
||||
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
// Write back results:
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||
kqsum_j = warp_reduce_sum<warp_size>(kqsum_j);
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += warp_size) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]);
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
float2 tmp[cpy_ne_D];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
|
||||
}
|
||||
#else
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size];
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
if (gridDim.y == 1) {
|
||||
dst_val.x /= kqsum_j;
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
@@ -602,15 +684,29 @@ template <int D, bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = 32;
|
||||
const int nwarps = FATTN_TILE_NTHREADS / warp_size;
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = 32;
|
||||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (D <= 128) {
|
||||
if (Q->ne[1] > 32) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
if (Q->ne[1] > 16) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
@@ -619,6 +715,7 @@ static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
|
||||
@@ -122,11 +122,14 @@ static __global__ void im2col_3d_kernel(
|
||||
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
|
||||
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
|
||||
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
|
||||
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
|
||||
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (i >= IC_KD_KH_KW) {
|
||||
return;
|
||||
}
|
||||
GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH);
|
||||
GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);
|
||||
|
||||
const int64_t iic = i / KD_KH_KW;
|
||||
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
|
||||
@@ -148,7 +151,7 @@ static __global__ void im2col_3d_kernel(
|
||||
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
||||
dst[offset_dst] = 0.0f;
|
||||
} else {
|
||||
const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
|
||||
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
|
||||
dst[offset_dst] = src[offset_src];
|
||||
}
|
||||
}
|
||||
@@ -159,6 +162,7 @@ template <typename T>
|
||||
static void im2col_3d_cuda(const float * src, T* dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
const int64_t OH_OW = OH*OW;
|
||||
const int64_t KD_KH_KW = KD*KH*KW;
|
||||
@@ -179,23 +183,30 @@ static void im2col_3d_cuda(const float * src, T* dst,
|
||||
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
|
||||
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
|
||||
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
|
||||
stride_q, stride_z, stride_y, stride_x,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2);
|
||||
}
|
||||
|
||||
static void im2col_3d_cuda_f16(const float * src, half * dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
|
||||
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
|
||||
stride_q, stride_z, stride_y, stride_x,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
|
||||
static void im2col_3d_cuda_f32(const float * src, float * dst,
|
||||
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
|
||||
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
|
||||
int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
|
||||
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
|
||||
|
||||
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
|
||||
stride_q, stride_z, stride_y, stride_x,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -235,9 +246,19 @@ void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||
const int64_t OH = ne2;
|
||||
const int64_t OW = ne1;
|
||||
|
||||
const size_t es = ggml_element_size(src1);
|
||||
const int64_t stride_x = src1->nb[0] / es;
|
||||
const int64_t stride_y = src1->nb[1] / es;
|
||||
const int64_t stride_z = src1->nb[2] / es;
|
||||
const int64_t stride_q = src1->nb[3] / es;
|
||||
|
||||
if(dst->type == GGML_TYPE_F16) {
|
||||
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
|
||||
stride_q, stride_z, stride_y, stride_x,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
} else {
|
||||
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
|
||||
stride_q, stride_z, stride_y, stride_x,
|
||||
s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,31 +57,33 @@ static __global__ void mul_mat_f(
|
||||
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
|
||||
|
||||
if constexpr (has_ids) {
|
||||
__shared__ int has_any;
|
||||
if (threadIdx.y == 0) {
|
||||
int local_has_any = 0;
|
||||
for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
|
||||
int slot = -1;
|
||||
for (int k = 0; k < nchannels_dst; ++k) {
|
||||
const int idv = ids[j*stride_row_id + k*stride_col_id];
|
||||
if (idv == expert_idx) {
|
||||
slot = k;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (j < cols_per_block) {
|
||||
local_has_any |= (slot >= 0);
|
||||
slot_map[j] = slot;
|
||||
int found = 0;
|
||||
|
||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
const int32_t * __restrict__ id_row = ids + j*stride_row_id;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
slot_map[j] = -1;
|
||||
}
|
||||
|
||||
for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
|
||||
int match = id_row[k*stride_col_id] == expert_idx;
|
||||
|
||||
if (match) {
|
||||
slot_map[j] = k;
|
||||
found = 1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
has_any = warp_reduce_any(local_has_any);
|
||||
}
|
||||
__syncthreads();
|
||||
if (has_any == 0) {
|
||||
|
||||
if (!__syncthreads_or(found)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
|
||||
tile_A A[ntA][warp_size / tile_A::J];
|
||||
#pragma unroll
|
||||
@@ -106,14 +108,7 @@ static __global__ void mul_mat_f(
|
||||
if constexpr (!has_ids) {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
|
||||
} else {
|
||||
float val = 0.0f;
|
||||
if (j < cols_per_block) {
|
||||
const int slot = slot_map[j];
|
||||
if (slot >= 0) {
|
||||
val = y[slot*stride_channel_y + j*stride_col_y + col];
|
||||
}
|
||||
}
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
@@ -125,14 +120,7 @@ static __global__ void mul_mat_f(
|
||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
} else {
|
||||
float2 tmp = make_float2(0.0f, 0.0f);
|
||||
if (j < cols_per_block) {
|
||||
const int slot = slot_map[j];
|
||||
if (slot >= 0) {
|
||||
const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
|
||||
tmp = y2_slot[j*stride_col_y + col];
|
||||
}
|
||||
}
|
||||
float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
}
|
||||
@@ -221,7 +209,7 @@ static inline void mul_mat_f_switch_ids(
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
||||
if (ids) {
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} else {
|
||||
|
||||
@@ -7,11 +7,11 @@ static __global__ void timestep_embedding_f32(const float * timesteps, float * d
|
||||
int j = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
float * embed_data = (float *)((char *)dst + i*nb1);
|
||||
|
||||
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
|
||||
embed_data[dim] = 0.f;
|
||||
int half = dim / 2;
|
||||
if (dim % 2 != 0 && j == half) {
|
||||
embed_data[2 * half] = 0.f;
|
||||
}
|
||||
|
||||
int half = dim / 2;
|
||||
if (j >= half) {
|
||||
return;
|
||||
}
|
||||
|
||||
34
ggml/src/ggml-cuda/vendors/hip.h
vendored
34
ggml/src/ggml-cuda/vendors/hip.h
vendored
@@ -158,41 +158,41 @@
|
||||
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
|
||||
#define GCN
|
||||
#endif
|
||||
|
||||
#if defined(__gfx900__) || defined(__gfx906__)
|
||||
#define GCN5
|
||||
#endif
|
||||
#endif // defined(__gfx900__) || defined(__gfx906__)
|
||||
|
||||
#if defined(__gfx803__)
|
||||
#define GCN4
|
||||
#endif
|
||||
#endif // defined(__gfx803__)
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
|
||||
#define CDNA // For the entire family
|
||||
#endif
|
||||
#if defined(GCN5) || defined(GCN4)
|
||||
#define GCN
|
||||
#endif // defined(GCN5) || defined(GCN4)
|
||||
|
||||
#if defined(__gfx942__)
|
||||
#define CDNA3
|
||||
#endif
|
||||
#endif // defined(__gfx942__)
|
||||
|
||||
#if defined(__gfx90a__)
|
||||
#define CDNA2
|
||||
#endif
|
||||
#endif // defined(__gfx90a__)
|
||||
|
||||
#if defined(__gfx908__)
|
||||
#define CDNA1
|
||||
#endif
|
||||
#endif // defined(__gfx908__)
|
||||
|
||||
#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
|
||||
#define CDNA // For the entire family
|
||||
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
|
||||
|
||||
#if defined(__GFX12__)
|
||||
#define RDNA4
|
||||
#endif
|
||||
#endif // defined(__GFX12__)
|
||||
|
||||
#if defined(__GFX11__)
|
||||
#define RDNA3
|
||||
#endif
|
||||
#endif // defined(__GFX11__)
|
||||
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
||||
@@ -201,7 +201,11 @@
|
||||
|
||||
#if defined(__gfx1010__) || defined(__gfx1012__)
|
||||
#define RDNA1
|
||||
#endif
|
||||
#endif // defined(__gfx1010__) || defined(__gfx1012__)
|
||||
|
||||
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
|
||||
#define RDNA // For the entire family
|
||||
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
|
||||
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
|
||||
@@ -4167,7 +4167,7 @@ kernel void kernel_timestep_embedding_f32(
|
||||
}
|
||||
|
||||
if (args.dim % 2 != 0 && tpitg.x == 0) {
|
||||
embed_data[args.dim] = 0.f;
|
||||
embed_data[2 * half_] = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ kernel void kernel_timestep_embedding(
|
||||
local_half_dim = logical_dim / 2;
|
||||
local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes);
|
||||
|
||||
if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) {
|
||||
local_embed_data_ptr[logical_dim] = 0.0f;
|
||||
if (logical_dim % 2 != 0 && local_j == local_half_dim) {
|
||||
local_embed_data_ptr[2 * local_half_dim] = 0.0f;
|
||||
}
|
||||
|
||||
if (local_j >= local_half_dim) {
|
||||
|
||||
@@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
|
||||
@@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_op_sub(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_count_equal(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
|
||||
ggml_sycl_op_mul(ctx, dst);
|
||||
|
||||
@@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float op_count_equal(const float a, const float b) {
|
||||
return (a == b) ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
static __dpct_inline__ float op_mul(const float a, const float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
@@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
case GGML_OP_SUB:
|
||||
ggml_sycl_sub(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
ggml_sycl_count_equal(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ACC:
|
||||
ggml_sycl_acc(ctx, dst);
|
||||
break;
|
||||
@@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD1:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_COUNT_EQUAL:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_REPEAT:
|
||||
|
||||
@@ -21,11 +21,12 @@ static void timestep_embedding_f32(
|
||||
int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
float * embed_data = (float *)((char *)dst + i*nb1);
|
||||
|
||||
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
|
||||
embed_data[dim] = 0.f;
|
||||
int half = dim / 2;
|
||||
|
||||
if (dim % 2 != 0 && j == half) {
|
||||
embed_data[2 * half] = 0.f;
|
||||
}
|
||||
|
||||
int half = dim / 2;
|
||||
if (j >= half) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4423,8 +4423,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
|
||||
static bool ggml_vk_instance_validation_ext_available();
|
||||
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||
|
||||
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
||||
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
|
||||
|
||||
static void ggml_vk_instance_init() {
|
||||
if (vk_instance_initialized) {
|
||||
@@ -4540,7 +4540,7 @@ static void ggml_vk_instance_init() {
|
||||
new_driver.pNext = &new_id;
|
||||
devices[i].getProperties2(&new_props);
|
||||
|
||||
if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) {
|
||||
if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {
|
||||
// Check if there are two physical devices corresponding to the same GPU
|
||||
auto old_device = std::find_if(
|
||||
vk_instance.device_indices.begin(),
|
||||
@@ -12738,6 +12738,20 @@ static bool ggml_vk_instance_debug_utils_ext_available(
|
||||
UNUSED(instance_extensions);
|
||||
}
|
||||
|
||||
static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
|
||||
VkPhysicalDeviceFeatures2 device_features2;
|
||||
device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
|
||||
|
||||
VkPhysicalDeviceVulkan11Features vk11_features;
|
||||
vk11_features.pNext = nullptr;
|
||||
vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
|
||||
device_features2.pNext = &vk11_features;
|
||||
|
||||
vkGetPhysicalDeviceFeatures2(vkdev, &device_features2);
|
||||
|
||||
return vk11_features.storageBuffer16BitAccess;
|
||||
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||
switch (props.vendorID) {
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
|
||||
@@ -24,11 +24,12 @@ void main() {
|
||||
const uint j = gl_GlobalInvocationID.x;
|
||||
const uint d_offset = i * p.nb1;
|
||||
|
||||
if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
|
||||
data_d[d_offset + p.dim] = 0.f;
|
||||
const uint half_dim = p.dim / 2;
|
||||
|
||||
if (p.dim % 2 != 0 && j == half_dim) {
|
||||
data_d[d_offset + 2 * half_dim] = 0.f;
|
||||
}
|
||||
|
||||
const uint half_dim = p.dim / 2;
|
||||
if (j >= half_dim) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4923,12 +4923,8 @@ struct ggml_tensor * ggml_timestep_embedding(
|
||||
struct ggml_tensor * timesteps,
|
||||
int dim,
|
||||
int max_period) {
|
||||
int actual_dim = dim;
|
||||
if (dim % 2 != 0) {
|
||||
actual_dim = dim + 1;
|
||||
}
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
|
||||
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, dim);
|
||||
ggml_set_op_params_i32(result, 1, max_period);
|
||||
|
||||
@@ -399,6 +399,7 @@ class MODEL_ARCH(IntEnum):
|
||||
DREAM = auto()
|
||||
SMALLTHINKER = auto()
|
||||
LLADA = auto()
|
||||
LLADA_MOE = auto()
|
||||
SEED_OSS = auto()
|
||||
|
||||
|
||||
@@ -735,6 +736,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.DREAM: "dream",
|
||||
MODEL_ARCH.SMALLTHINKER: "smallthinker",
|
||||
MODEL_ARCH.LLADA: "llada",
|
||||
MODEL_ARCH.LLADA_MOE: "llada-moe",
|
||||
MODEL_ARCH.SEED_OSS: "seed_oss",
|
||||
}
|
||||
|
||||
@@ -2693,6 +2695,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.LLADA_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
||||
@@ -96,6 +96,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
|
||||
{ LLM_ARCH_LLADA, "llada" },
|
||||
{ LLM_ARCH_LLADA_MOE, "llada-moe" },
|
||||
{ LLM_ARCH_SEED_OSS, "seed_oss" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
@@ -2147,6 +2148,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLADA_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_SEED_OSS,
|
||||
{
|
||||
@@ -2427,6 +2448,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -100,6 +100,7 @@ enum llm_arch {
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
LLM_ARCH_LLADA,
|
||||
LLM_ARCH_LLADA_MOE,
|
||||
LLM_ARCH_SEED_OSS,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -149,7 +149,7 @@ struct llama_hparams {
|
||||
bool causal_attn = true;
|
||||
bool use_alibi = false;
|
||||
bool attn_soft_cap = false;
|
||||
bool use_kq_norm = true;
|
||||
bool use_kq_norm = false;
|
||||
|
||||
// for Classifiers
|
||||
uint32_t n_cls_out = 1;
|
||||
|
||||
@@ -36,6 +36,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_80M: return "80M";
|
||||
case LLM_TYPE_109M: return "109M";
|
||||
case LLM_TYPE_137M: return "137M";
|
||||
case LLM_TYPE_140M: return "140M";
|
||||
case LLM_TYPE_160M: return "160M";
|
||||
case LLM_TYPE_190M: return "190M";
|
||||
case LLM_TYPE_220M: return "220M";
|
||||
@@ -44,6 +45,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_270M: return "270M";
|
||||
case LLM_TYPE_335M: return "335M";
|
||||
case LLM_TYPE_350M: return "350M";
|
||||
case LLM_TYPE_360M: return "360M";
|
||||
case LLM_TYPE_410M: return "410M";
|
||||
case LLM_TYPE_450M: return "450M";
|
||||
case LLM_TYPE_475M: return "475M";
|
||||
@@ -51,6 +53,7 @@ const char * llm_type_name(llm_type type) {
|
||||
case LLM_TYPE_700M: return "700M";
|
||||
case LLM_TYPE_770M: return "770M";
|
||||
case LLM_TYPE_780M: return "780M";
|
||||
case LLM_TYPE_950M: return "950M";
|
||||
case LLM_TYPE_0_3B: return "0.3B";
|
||||
case LLM_TYPE_0_5B: return "0.5B";
|
||||
case LLM_TYPE_0_6B: return "0.6B";
|
||||
@@ -622,19 +625,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
if (found_swa && hparams.n_swa == 0) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192;
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
}
|
||||
|
||||
switch (hparams.n_expert) {
|
||||
case 0: {
|
||||
// MobileLLM (no MoE)
|
||||
switch (hparams.n_embd) {
|
||||
case 2048: type = LLM_TYPE_140M; break;
|
||||
case 4096: type = LLM_TYPE_360M; break;
|
||||
case 6144: type = LLM_TYPE_950M; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case 16: type = LLM_TYPE_17B_16E; break;
|
||||
case 128: type = LLM_TYPE_17B_128E; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
if (type == LLM_TYPE_17B_128E) {
|
||||
hparams.use_kq_norm = false;
|
||||
}
|
||||
hparams.use_kq_norm = type != LLM_TYPE_17B_128E;
|
||||
} break;
|
||||
case LLM_ARCH_ARCEE:
|
||||
{
|
||||
@@ -936,6 +952,18 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
hparams.causal_attn = false;
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
// diffusion language model uses non-causal attention
|
||||
hparams.causal_attn = false;
|
||||
switch (hparams.n_layer) {
|
||||
case 16: type = LLM_TYPE_A1_7B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_QWEN2MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
|
||||
@@ -1338,6 +1366,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
|
||||
if (found_swa && hparams.n_swa > 0) {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.set_swa_pattern(4);
|
||||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 16: type = LLM_TYPE_1B; break;
|
||||
case 32: type = LLM_TYPE_7B; break;
|
||||
@@ -2387,6 +2423,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
}
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe");
|
||||
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe");
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_LLAMA4:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -2400,9 +2470,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
|
||||
bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0;
|
||||
|
||||
auto & layer = layers[i];
|
||||
|
||||
@@ -6274,6 +6343,14 @@ struct llm_build_llama : public llm_graph_context {
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
if (hparams.use_kq_norm) {
|
||||
// Llama4TextL2Norm
|
||||
Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
|
||||
Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
}
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
@@ -6381,7 +6458,8 @@ struct llm_build_llama_iswa : public llm_graph_context {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
|
||||
const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
|
||||
(il + 1) % hparams.n_no_rope_layer_step != 0;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
@@ -12187,6 +12265,7 @@ struct llm_build_olmo : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
template <bool iswa>
|
||||
struct llm_build_olmo2 : public llm_graph_context {
|
||||
llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
@@ -12202,7 +12281,14 @@ struct llm_build_olmo2 : public llm_graph_context {
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
|
||||
inp_attn_type * inp_attn = nullptr;
|
||||
|
||||
if constexpr (iswa) {
|
||||
inp_attn = build_attn_inp_kv_iswa();
|
||||
} else {
|
||||
inp_attn = build_attn_inp_kv();
|
||||
}
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
@@ -12235,17 +12321,36 @@ struct llm_build_olmo2 : public llm_graph_context {
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
if (is_swa) {
|
||||
// For sliding window layers, Olmo3 use regular rope with no yarn rope scaling.
|
||||
// This is achieved here by setting freq_scale and attn_factor to 1.
|
||||
// We also set ext_factor to 0 to avoid a few unnecessary computations.
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, 1.0,
|
||||
0.0, 1.0, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, 1.0,
|
||||
0.0, 1.0, beta_fast, beta_slow
|
||||
);
|
||||
} else {
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
}
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
@@ -12444,6 +12549,132 @@ struct llm_build_olmoe : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_llada_moe : public llm_graph_context {
|
||||
llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_no_cache();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self_attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// MoE branch
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_openelm : public llm_graph_context {
|
||||
llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
@@ -18636,6 +18867,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||
//case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
|
||||
case LLM_ARCH_DREAM:
|
||||
case LLM_ARCH_LLADA:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
{
|
||||
res = nullptr;
|
||||
} break;
|
||||
@@ -18773,7 +19005,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
} break;
|
||||
case LLM_ARCH_LLAMA4:
|
||||
{
|
||||
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
|
||||
if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
|
||||
llm = std::make_unique<llm_build_llama>(*this, params);
|
||||
} else {
|
||||
llm = std::make_unique<llm_build_llama_iswa>(*this, params);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DECI:
|
||||
{
|
||||
@@ -18841,6 +19077,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
llm = std::make_unique<llm_build_llada>(*this, params);
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_llada_moe>(*this, params);
|
||||
}
|
||||
break;
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
{
|
||||
llm = std::make_unique<llm_build_qwen2vl>(*this, params);
|
||||
@@ -18953,7 +19194,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
} break;
|
||||
case LLM_ARCH_OLMO2:
|
||||
{
|
||||
llm = std::make_unique<llm_build_olmo2>(*this, params);
|
||||
if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
|
||||
llm = std::make_unique<llm_build_olmo2<true>>(*this, params);
|
||||
} else {
|
||||
llm = std::make_unique<llm_build_olmo2<false>>(*this, params);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OLMOE:
|
||||
{
|
||||
@@ -19307,6 +19552,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_QWEN2MOE:
|
||||
case LLM_ARCH_QWEN3:
|
||||
case LLM_ARCH_QWEN3MOE:
|
||||
case LLM_ARCH_LLADA_MOE:
|
||||
case LLM_ARCH_OLMO2:
|
||||
case LLM_ARCH_OLMOE:
|
||||
case LLM_ARCH_PHI2:
|
||||
|
||||
@@ -28,6 +28,7 @@ enum llm_type {
|
||||
LLM_TYPE_80M,
|
||||
LLM_TYPE_109M,
|
||||
LLM_TYPE_137M,
|
||||
LLM_TYPE_140M,
|
||||
LLM_TYPE_160M,
|
||||
LLM_TYPE_190M,
|
||||
LLM_TYPE_220M,
|
||||
@@ -36,6 +37,7 @@ enum llm_type {
|
||||
LLM_TYPE_270M,
|
||||
LLM_TYPE_335M,
|
||||
LLM_TYPE_350M,
|
||||
LLM_TYPE_360M,
|
||||
LLM_TYPE_410M,
|
||||
LLM_TYPE_450M,
|
||||
LLM_TYPE_475M,
|
||||
@@ -43,6 +45,7 @@ enum llm_type {
|
||||
LLM_TYPE_700M,
|
||||
LLM_TYPE_770M,
|
||||
LLM_TYPE_780M,
|
||||
LLM_TYPE_950M,
|
||||
LLM_TYPE_0_3B,
|
||||
LLM_TYPE_0_5B,
|
||||
LLM_TYPE_0_6B,
|
||||
|
||||
@@ -725,7 +725,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
// attention layers have a non-zero number of kv heads
|
||||
int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
|
||||
if (llama_model_has_encoder(&model)) {
|
||||
n_attn_layer *= 3;
|
||||
// now n_attn_layer is the number of attention layers in the encoder
|
||||
// for each decoder block, there are 2 attention layers
|
||||
n_attn_layer += 2 * model.hparams.dec_n_layer;
|
||||
}
|
||||
GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
|
||||
}
|
||||
|
||||
@@ -1962,7 +1962,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "bailingmoe") {
|
||||
tokenizer_pre == "bailingmoe" ||
|
||||
tokenizer_pre == "llada-moe") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
set(TARGET llama-batched-bench)
|
||||
add_executable(${TARGET} batched-bench.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
set(TARGET llama-cvector-generator)
|
||||
add_executable(${TARGET} cvector-generator.cpp pca.hpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
set(TARGET llama-export-lora)
|
||||
add_executable(${TARGET} export-lora.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
set(TARGET llama-gguf-split)
|
||||
add_executable(${TARGET} gguf-split.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
set(TARGET llama-imatrix)
|
||||
add_executable(${TARGET} imatrix.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
set(TARGET llama-bench)
|
||||
add_executable(${TARGET} llama-bench.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -250,6 +250,7 @@ struct cmd_params {
|
||||
std::vector<bool> cpu_strict;
|
||||
std::vector<int> poll;
|
||||
std::vector<int> n_gpu_layers;
|
||||
std::vector<int> n_cpu_moe;
|
||||
std::vector<std::string> rpc_servers;
|
||||
std::vector<llama_split_mode> split_mode;
|
||||
std::vector<int> main_gpu;
|
||||
@@ -286,6 +287,7 @@ static const cmd_params cmd_params_defaults = {
|
||||
/* cpu_strict */ { false },
|
||||
/* poll */ { 50 },
|
||||
/* n_gpu_layers */ { 99 },
|
||||
/* n_cpu_moe */ { 0 },
|
||||
/* rpc_servers */ { "" },
|
||||
/* split_mode */ { LLAMA_SPLIT_MODE_LAYER },
|
||||
/* main_gpu */ { 0 },
|
||||
@@ -353,6 +355,8 @@ static void print_usage(int /* argc */, char ** argv) {
|
||||
printf(" --poll <0...100> (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str());
|
||||
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n",
|
||||
join(cmd_params_defaults.n_gpu_layers, ",").c_str());
|
||||
printf(" -ncmoe, --n-cpu-moe <n> (default: %s)\n",
|
||||
join(cmd_params_defaults.n_cpu_moe, ",").c_str());
|
||||
if (llama_supports_rpc()) {
|
||||
printf(" -rpc, --rpc <rpc_servers> (default: %s)\n",
|
||||
join(cmd_params_defaults.rpc_servers, ",").c_str());
|
||||
@@ -564,6 +568,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
}
|
||||
auto p = parse_int_range(argv[i]);
|
||||
params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
|
||||
} else if (arg == "-ncmoe" || arg == "--n-cpu-moe") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
auto p = parse_int_range(argv[i]);
|
||||
params.n_cpu_moe.insert(params.n_cpu_moe.end(), p.begin(), p.end());
|
||||
} else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@@ -841,6 +852,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
|
||||
if (params.n_gpu_layers.empty()) {
|
||||
params.n_gpu_layers = cmd_params_defaults.n_gpu_layers;
|
||||
}
|
||||
if (params.n_cpu_moe.empty()) {
|
||||
params.n_cpu_moe = cmd_params_defaults.n_cpu_moe;
|
||||
}
|
||||
if (params.rpc_servers.empty()) {
|
||||
params.rpc_servers = cmd_params_defaults.rpc_servers;
|
||||
}
|
||||
@@ -901,6 +915,7 @@ struct cmd_params_instance {
|
||||
bool cpu_strict;
|
||||
int poll;
|
||||
int n_gpu_layers;
|
||||
int n_cpu_moe;
|
||||
std::string rpc_servers_str;
|
||||
llama_split_mode split_mode;
|
||||
int main_gpu;
|
||||
@@ -973,20 +988,50 @@ struct cmd_params_instance {
|
||||
mparams.tensor_split = tensor_split.data();
|
||||
mparams.use_mmap = use_mmap;
|
||||
|
||||
if (tensor_buft_overrides.empty()) {
|
||||
mparams.tensor_buft_overrides = nullptr;
|
||||
if (n_cpu_moe <= 0) {
|
||||
if (tensor_buft_overrides.empty()) {
|
||||
mparams.tensor_buft_overrides = nullptr;
|
||||
} else {
|
||||
GGML_ASSERT(tensor_buft_overrides.back().pattern == nullptr &&
|
||||
"Tensor buffer overrides not terminated with empty pattern");
|
||||
mparams.tensor_buft_overrides = tensor_buft_overrides.data();
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
|
||||
mparams.tensor_buft_overrides = tensor_buft_overrides.data();
|
||||
static std::vector<llama_model_tensor_buft_override> merged;
|
||||
static std::vector<std::string> patterns;
|
||||
|
||||
merged.clear();
|
||||
patterns.clear();
|
||||
|
||||
auto first = tensor_buft_overrides.begin();
|
||||
auto last = tensor_buft_overrides.end();
|
||||
if (first != last && (last - 1)->pattern == nullptr) {
|
||||
--last;
|
||||
}
|
||||
merged.insert(merged.end(), first, last);
|
||||
|
||||
patterns.reserve((size_t) n_cpu_moe);
|
||||
merged.reserve(merged.size() + (size_t) n_cpu_moe + 1);
|
||||
|
||||
for (int i = 0; i < n_cpu_moe; ++i) {
|
||||
patterns.push_back(llm_ffn_exps_block_regex(i));
|
||||
merged.push_back({ patterns.back().c_str(),
|
||||
ggml_backend_cpu_buffer_type() });
|
||||
}
|
||||
|
||||
merged.push_back({ nullptr, nullptr });
|
||||
|
||||
mparams.tensor_buft_overrides = merged.data();
|
||||
}
|
||||
|
||||
return mparams;
|
||||
}
|
||||
|
||||
bool equal_mparams(const cmd_params_instance & other) const {
|
||||
return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
|
||||
split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
|
||||
tensor_split == other.tensor_split && vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides);
|
||||
return model == other.model && n_gpu_layers == other.n_gpu_layers && n_cpu_moe == other.n_cpu_moe &&
|
||||
rpc_servers_str == other.rpc_servers_str && split_mode == other.split_mode &&
|
||||
main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split &&
|
||||
vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides);
|
||||
}
|
||||
|
||||
llama_context_params to_llama_cparams() const {
|
||||
@@ -1014,6 +1059,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
// clang-format off
|
||||
for (const auto & m : params.model)
|
||||
for (const auto & nl : params.n_gpu_layers)
|
||||
for (const auto & ncmoe : params.n_cpu_moe)
|
||||
for (const auto & rpc : params.rpc_servers)
|
||||
for (const auto & sm : params.split_mode)
|
||||
for (const auto & mg : params.main_gpu)
|
||||
@@ -1051,6 +1097,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .cpu_strict = */ cs,
|
||||
/* .poll = */ pl,
|
||||
/* .n_gpu_layers = */ nl,
|
||||
/* .n_cpu_moe = */ ncmoe,
|
||||
/* .rpc_servers = */ rpc,
|
||||
/* .split_mode = */ sm,
|
||||
/* .main_gpu = */ mg,
|
||||
@@ -1083,6 +1130,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .cpu_strict = */ cs,
|
||||
/* .poll = */ pl,
|
||||
/* .n_gpu_layers = */ nl,
|
||||
/* .n_cpu_moe = */ ncmoe,
|
||||
/* .rpc_servers = */ rpc,
|
||||
/* .split_mode = */ sm,
|
||||
/* .main_gpu = */ mg,
|
||||
@@ -1115,6 +1163,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
|
||||
/* .cpu_strict = */ cs,
|
||||
/* .poll = */ pl,
|
||||
/* .n_gpu_layers = */ nl,
|
||||
/* .n_cpu_moe = */ ncmoe,
|
||||
/* .rpc_servers = */ rpc,
|
||||
/* .split_mode = */ sm,
|
||||
/* .main_gpu = */ mg,
|
||||
@@ -1152,6 +1201,7 @@ struct test {
|
||||
ggml_type type_k;
|
||||
ggml_type type_v;
|
||||
int n_gpu_layers;
|
||||
int n_cpu_moe;
|
||||
llama_split_mode split_mode;
|
||||
int main_gpu;
|
||||
bool no_kv_offload;
|
||||
@@ -1186,6 +1236,7 @@ struct test {
|
||||
type_k = inst.type_k;
|
||||
type_v = inst.type_v;
|
||||
n_gpu_layers = inst.n_gpu_layers;
|
||||
n_cpu_moe = inst.n_cpu_moe;
|
||||
split_mode = inst.split_mode;
|
||||
main_gpu = inst.main_gpu;
|
||||
no_kv_offload = inst.no_kv_offload;
|
||||
@@ -1236,12 +1287,14 @@ struct test {
|
||||
|
||||
static const std::vector<std::string> & get_fields() {
|
||||
static const std::vector<std::string> fields = {
|
||||
"build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
|
||||
"model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
|
||||
"cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
|
||||
"split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
|
||||
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", "test_time",
|
||||
"avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
|
||||
"build_commit", "build_number", "cpu_info", "gpu_info", "backends",
|
||||
"model_filename", "model_type", "model_size", "model_n_params", "n_batch",
|
||||
"n_ubatch", "n_threads", "cpu_mask", "cpu_strict", "poll",
|
||||
"type_k", "type_v", "n_gpu_layers", "n_cpu_moe", "split_mode",
|
||||
"main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
|
||||
"use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen",
|
||||
"n_depth", "test_time", "avg_ns", "stddev_ns", "avg_ts",
|
||||
"stddev_ts"
|
||||
};
|
||||
return fields;
|
||||
}
|
||||
@@ -1251,8 +1304,8 @@ struct test {
|
||||
static field_type get_field_type(const std::string & field) {
|
||||
if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || field == "n_threads" ||
|
||||
field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" ||
|
||||
field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "n_depth" ||
|
||||
field == "avg_ns" || field == "stddev_ns" || field == "no_op_offload") {
|
||||
field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "n_depth" || field == "avg_ns" ||
|
||||
field == "stddev_ns" || field == "no_op_offload" || field == "n_cpu_moe") {
|
||||
return INT;
|
||||
}
|
||||
if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" ||
|
||||
@@ -1320,6 +1373,7 @@ struct test {
|
||||
ggml_type_name(type_k),
|
||||
ggml_type_name(type_v),
|
||||
std::to_string(n_gpu_layers),
|
||||
std::to_string(n_cpu_moe),
|
||||
split_mode_str(split_mode),
|
||||
std::to_string(main_gpu),
|
||||
std::to_string(no_kv_offload),
|
||||
@@ -1568,6 +1622,9 @@ struct markdown_printer : public printer {
|
||||
if (!is_cpu_backend) {
|
||||
fields.emplace_back("n_gpu_layers");
|
||||
}
|
||||
if (params.n_cpu_moe.size() > 1) {
|
||||
fields.emplace_back("n_cpu_moe");
|
||||
}
|
||||
if (params.n_threads.size() > 1 || params.n_threads != cmd_params_defaults.n_threads || is_cpu_backend) {
|
||||
fields.emplace_back("n_threads");
|
||||
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
set(TARGET llama-cli)
|
||||
add_executable(${TARGET} main.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -55,7 +55,7 @@ add_executable(llama-qwen2vl-cli deprecation-warning.cpp)
|
||||
set(TARGET llama-mtmd-cli)
|
||||
add_executable (${TARGET} mtmd-cli.cpp)
|
||||
set_target_properties (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
|
||||
if(NOT CMAKE_SYSTEM_NAME STREQUAL "iOS")
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
target_link_libraries (${TARGET} PRIVATE common mtmd Threads::Threads)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
set(TARGET llama-perplexity)
|
||||
add_executable(${TARGET} perplexity.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -1931,7 +1931,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
LOG("Maximum KLD: %10.6f\n", kld_values.back());
|
||||
LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f));
|
||||
LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
|
||||
LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f));
|
||||
LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f));
|
||||
LOG("Median KLD: %10.6f\n", kld_median);
|
||||
LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f));
|
||||
LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f));
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
set(TARGET llama-quantize)
|
||||
add_executable(${TARGET} quantize.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_include_directories(${TARGET} PRIVATE ../../common)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
|
||||
@@ -10,6 +10,8 @@ if (LLAMA_CURL)
|
||||
set(LLAMA_RUN_EXTRA_LIBS ${LLAMA_RUN_EXTRA_LIBS} ${CURL_LIBRARIES})
|
||||
endif ()
|
||||
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
endif()
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} ${LLAMA_RUN_EXTRA_LIBS})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
@@ -407,39 +407,22 @@ class HttpClient {
|
||||
}
|
||||
|
||||
std::string output_file_partial;
|
||||
curl = curl_easy_init();
|
||||
if (!curl) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
progress_data data;
|
||||
File out;
|
||||
if (!output_file.empty()) {
|
||||
output_file_partial = output_file + ".partial";
|
||||
if (!out.open(output_file_partial, "ab")) {
|
||||
printe("Failed to open file for writing\n");
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (out.lock()) {
|
||||
printe("Failed to exclusively lock file\n");
|
||||
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
set_write_options(response_str, out);
|
||||
data.file_size = set_resume_point(output_file_partial);
|
||||
set_progress_options(progress, data);
|
||||
set_headers(headers);
|
||||
CURLcode res = perform(url);
|
||||
if (res != CURLE_OK){
|
||||
printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
|
||||
if (download(url, headers, output_file_partial, progress, response_str)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!output_file.empty()) {
|
||||
std::filesystem::rename(output_file_partial, output_file);
|
||||
try {
|
||||
std::filesystem::rename(output_file_partial, output_file);
|
||||
} catch (const std::filesystem::filesystem_error & e) {
|
||||
printe("Failed to rename '%s' to '%s': %s\n", output_file_partial.c_str(), output_file.c_str(), e.what());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
@@ -459,6 +442,42 @@ class HttpClient {
|
||||
CURL * curl = nullptr;
|
||||
struct curl_slist * chunk = nullptr;
|
||||
|
||||
int download(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
|
||||
const bool progress, std::string * response_str = nullptr) {
|
||||
curl = curl_easy_init();
|
||||
if (!curl) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
progress_data data;
|
||||
File out;
|
||||
if (!output_file.empty()) {
|
||||
if (!out.open(output_file, "ab")) {
|
||||
printe("Failed to open file for writing\n");
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (out.lock()) {
|
||||
printe("Failed to exclusively lock file\n");
|
||||
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
set_write_options(response_str, out);
|
||||
data.file_size = set_resume_point(output_file);
|
||||
set_progress_options(progress, data);
|
||||
set_headers(headers);
|
||||
CURLcode res = perform(url);
|
||||
if (res != CURLE_OK){
|
||||
printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res));
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void set_write_options(std::string * response_str, const File & out) {
|
||||
if (response_str) {
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data);
|
||||
@@ -507,6 +526,9 @@ class HttpClient {
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https");
|
||||
curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
|
||||
#ifdef _WIN32
|
||||
curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
|
||||
#endif
|
||||
return curl_easy_perform(curl);
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
@@ -5261,6 +5261,42 @@ int main(int argc, char ** argv) {
|
||||
svr->Get (params.api_prefix + "/slots", handle_slots);
|
||||
svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);
|
||||
|
||||
// SPA fallback route - serve index.html for any route that doesn't match API endpoints
|
||||
// This enables client-side routing for dynamic routes like /chat/[id]
|
||||
if (params.webui && params.public_path.empty()) {
|
||||
// Only add fallback when using embedded static files
|
||||
svr->Get(".*", [](const httplib::Request & req, httplib::Response & res) {
|
||||
// Skip API routes - they should have been handled above
|
||||
if (req.path.find("/v1/") != std::string::npos ||
|
||||
req.path.find("/health") != std::string::npos ||
|
||||
req.path.find("/metrics") != std::string::npos ||
|
||||
req.path.find("/props") != std::string::npos ||
|
||||
req.path.find("/models") != std::string::npos ||
|
||||
req.path.find("/api/tags") != std::string::npos ||
|
||||
req.path.find("/completions") != std::string::npos ||
|
||||
req.path.find("/chat/completions") != std::string::npos ||
|
||||
req.path.find("/embeddings") != std::string::npos ||
|
||||
req.path.find("/tokenize") != std::string::npos ||
|
||||
req.path.find("/detokenize") != std::string::npos ||
|
||||
req.path.find("/lora-adapters") != std::string::npos ||
|
||||
req.path.find("/slots") != std::string::npos) {
|
||||
return false; // Let other handlers process API routes
|
||||
}
|
||||
|
||||
// Serve index.html for all other routes (SPA fallback)
|
||||
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
|
||||
res.set_content("Error: gzip is not supported by this browser", "text/plain");
|
||||
} else {
|
||||
res.set_header("Content-Encoding", "gzip");
|
||||
// COEP and COOP headers, required by pyodide (python interpreter)
|
||||
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
|
||||
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
|
||||
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Start the server
|
||||
//
|
||||
|
||||
@@ -92,7 +92,7 @@ def test_no_webui():
|
||||
url = f"http://{server.server_host}:{server.server_port}"
|
||||
res = requests.get(url)
|
||||
assert res.status_code == 200
|
||||
assert "<html>" in res.text
|
||||
assert "<!doctype html>" in res.text
|
||||
server.stop()
|
||||
|
||||
# with --no-webui
|
||||
|
||||
45
tools/server/webui/.gitignore
vendored
45
tools/server/webui/.gitignore
vendored
@@ -1,24 +1,27 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
test-results
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
# Output
|
||||
.output
|
||||
.vercel
|
||||
.netlify
|
||||
.wrangler
|
||||
/.svelte-kit
|
||||
/build
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
Thumbs.db
|
||||
|
||||
# Env
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
!.env.test
|
||||
|
||||
# Vite
|
||||
vite.config.js.timestamp-*
|
||||
vite.config.ts.timestamp-*
|
||||
|
||||
*storybook.log
|
||||
storybook-static
|
||||
|
||||
1
tools/server/webui/.npmrc
Normal file
1
tools/server/webui/.npmrc
Normal file
@@ -0,0 +1 @@
|
||||
engine-strict=true
|
||||
@@ -1,10 +1,9 @@
|
||||
**/.vscode
|
||||
**/.github
|
||||
**/.git
|
||||
**/.svn
|
||||
**/.hg
|
||||
**/node_modules
|
||||
**/dist
|
||||
**/build
|
||||
# Package Managers
|
||||
package-lock.json
|
||||
pnpm-lock.yaml
|
||||
yarn.lock
|
||||
bun.lock
|
||||
bun.lockb
|
||||
|
||||
*.config.js
|
||||
# Miscellaneous
|
||||
/static/
|
||||
|
||||
16
tools/server/webui/.prettierrc
Normal file
16
tools/server/webui/.prettierrc
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"useTabs": true,
|
||||
"singleQuote": true,
|
||||
"trailingComma": "none",
|
||||
"printWidth": 100,
|
||||
"plugins": ["prettier-plugin-svelte", "prettier-plugin-tailwindcss"],
|
||||
"overrides": [
|
||||
{
|
||||
"files": "*.svelte",
|
||||
"options": {
|
||||
"parser": "svelte"
|
||||
}
|
||||
}
|
||||
],
|
||||
"tailwindStylesheet": "./src/app.css"
|
||||
}
|
||||
36
tools/server/webui/.storybook/ModeWatcherDecorator.svelte
Normal file
36
tools/server/webui/.storybook/ModeWatcherDecorator.svelte
Normal file
@@ -0,0 +1,36 @@
|
||||
<script lang="ts">
|
||||
import { ModeWatcher } from 'mode-watcher';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
children?: any;
|
||||
}
|
||||
|
||||
let { children }: Props = $props();
|
||||
|
||||
onMount(() => {
|
||||
const root = document.documentElement;
|
||||
const theme = localStorage.getItem('mode-watcher-mode') || 'system';
|
||||
|
||||
if (theme === 'dark') {
|
||||
root.classList.add('dark');
|
||||
} else if (theme === 'light') {
|
||||
root.classList.remove('dark');
|
||||
} else {
|
||||
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
|
||||
if (prefersDark) {
|
||||
root.classList.add('dark');
|
||||
} else {
|
||||
root.classList.remove('dark');
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
<ModeWatcher />
|
||||
|
||||
{#if children}
|
||||
{@const Component = children}
|
||||
|
||||
<Component />
|
||||
{/if}
|
||||
@@ -0,0 +1,13 @@
|
||||
<script lang="ts">
|
||||
import * as Tooltip from '../src/lib/components/ui/tooltip';
|
||||
|
||||
interface Props {
|
||||
children: any;
|
||||
}
|
||||
|
||||
let { children }: Props = $props();
|
||||
</script>
|
||||
|
||||
<Tooltip.Provider>
|
||||
{@render children()}
|
||||
</Tooltip.Provider>
|
||||
17
tools/server/webui/.storybook/main.ts
Normal file
17
tools/server/webui/.storybook/main.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import type { StorybookConfig } from '@storybook/sveltekit';
|
||||
|
||||
const config: StorybookConfig = {
|
||||
stories: ['../src/**/*.mdx', '../src/**/*.stories.@(js|ts|svelte)'],
|
||||
addons: [
|
||||
'@storybook/addon-svelte-csf',
|
||||
'@chromatic-com/storybook',
|
||||
'@storybook/addon-docs',
|
||||
'@storybook/addon-a11y',
|
||||
'@storybook/addon-vitest'
|
||||
],
|
||||
framework: {
|
||||
name: '@storybook/sveltekit',
|
||||
options: {}
|
||||
}
|
||||
};
|
||||
export default config;
|
||||
34
tools/server/webui/.storybook/preview.ts
Normal file
34
tools/server/webui/.storybook/preview.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import type { Preview } from '@storybook/sveltekit';
|
||||
import '../src/app.css';
|
||||
import ModeWatcherDecorator from './ModeWatcherDecorator.svelte';
|
||||
import TooltipProviderDecorator from './TooltipProviderDecorator.svelte';
|
||||
|
||||
const preview: Preview = {
|
||||
parameters: {
|
||||
controls: {
|
||||
matchers: {
|
||||
color: /(background|color)$/i,
|
||||
date: /Date$/i
|
||||
}
|
||||
},
|
||||
backgrounds: {
|
||||
disable: true
|
||||
}
|
||||
},
|
||||
decorators: [
|
||||
(story) => ({
|
||||
Component: ModeWatcherDecorator,
|
||||
props: {
|
||||
children: story
|
||||
}
|
||||
}),
|
||||
(story) => ({
|
||||
Component: TooltipProviderDecorator,
|
||||
props: {
|
||||
children: story
|
||||
}
|
||||
})
|
||||
]
|
||||
};
|
||||
|
||||
export default preview;
|
||||
11
tools/server/webui/.storybook/vitest.setup.ts
Normal file
11
tools/server/webui/.storybook/vitest.setup.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { setProjectAnnotations } from '@storybook/sveltekit';
|
||||
import * as previewAnnotations from './preview';
|
||||
import { beforeAll } from 'vitest';
|
||||
|
||||
const project = setProjectAnnotations([previewAnnotations]);
|
||||
|
||||
beforeAll(async () => {
|
||||
if (project.beforeAll) {
|
||||
await project.beforeAll();
|
||||
}
|
||||
});
|
||||
66
tools/server/webui/README.md
Normal file
66
tools/server/webui/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# llama.cpp Web UI
|
||||
|
||||
A modern, feature-rich web interface for llama.cpp built with SvelteKit. This UI provides an intuitive chat interface with advanced file handling, conversation management, and comprehensive model interaction capabilities.
|
||||
|
||||
## Features
|
||||
|
||||
- **Modern Chat Interface** - Clean, responsive design with dark/light mode
|
||||
- **File Attachments** - Support for images, text files, PDFs, and audio with rich previews and drag-and-drop support
|
||||
- **Conversation Management** - Create, edit, branch, and search conversations
|
||||
- **Advanced Markdown** - Code highlighting, math formulas (KaTeX), and content blocks
|
||||
- **Reasoning Content** - Support for models with thinking blocks
|
||||
- **Keyboard Shortcuts** - Keyboard navigation (Shift+Ctrl/Cmd+O for new chat, Shift+Ctrl/Cmdt+E for edit conversation, Shift+Ctrl/Cmdt+D for delete conversation, Ctrl/Cmd+K for search, Ctrl/Cmd+V for paste, Ctrl/Cmd+B for opening/collapsing sidebar)
|
||||
- **Request Tracking** - Monitor processing with slots endpoint integration
|
||||
- **UI Testing** - Storybook component library with automated tests
|
||||
|
||||
## Development
|
||||
|
||||
Install dependencies:
|
||||
|
||||
```bash
|
||||
npm install
|
||||
```
|
||||
|
||||
Start the development server + Storybook:
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
This will start both the SvelteKit dev server and Storybook on port 6006.
|
||||
|
||||
## Building
|
||||
|
||||
Create a production build:
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
The build outputs static files to `../public` directory for deployment with llama.cpp server.
|
||||
|
||||
## Testing
|
||||
|
||||
Run the test suite:
|
||||
|
||||
```bash
|
||||
# E2E tests
|
||||
npm run test:e2e
|
||||
|
||||
# Unit tests
|
||||
npm run test:unit
|
||||
|
||||
# UI tests
|
||||
npm run test:ui
|
||||
|
||||
# All tests
|
||||
npm run test
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
- **Framework**: SvelteKit with Svelte 5 runes
|
||||
- **Components**: ShadCN UI + bits-ui design system
|
||||
- **Database**: IndexedDB with Dexie for local storage
|
||||
- **Build**: Static adapter for deployment with llama.cpp server
|
||||
- **Testing**: Playwright (E2E) + Vitest (unit) + Storybook (components)
|
||||
16
tools/server/webui/components.json
Normal file
16
tools/server/webui/components.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"$schema": "https://shadcn-svelte.com/schema.json",
|
||||
"tailwind": {
|
||||
"css": "src/app.css",
|
||||
"baseColor": "neutral"
|
||||
},
|
||||
"aliases": {
|
||||
"components": "$lib/components",
|
||||
"utils": "$lib/components/ui/utils",
|
||||
"ui": "$lib/components/ui",
|
||||
"hooks": "$lib/hooks",
|
||||
"lib": "$lib"
|
||||
},
|
||||
"typescript": true,
|
||||
"registry": "https://shadcn-svelte.com/registry"
|
||||
}
|
||||
6
tools/server/webui/e2e/demo.test.ts
Normal file
6
tools/server/webui/e2e/demo.test.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { expect, test } from '@playwright/test';
|
||||
|
||||
test('home page has expected h1', async ({ page }) => {
|
||||
await page.goto('/');
|
||||
await expect(page.locator('h1')).toBeVisible();
|
||||
});
|
||||
@@ -1,26 +1,49 @@
|
||||
import js from '@eslint/js'
|
||||
import globals from 'globals'
|
||||
import reactHooks from 'eslint-plugin-react-hooks'
|
||||
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||
import tseslint from 'typescript-eslint'
|
||||
// For more info, see https://github.com/storybookjs/eslint-plugin-storybook#configuration-flat-config-format
|
||||
import storybook from 'eslint-plugin-storybook';
|
||||
|
||||
export default tseslint.config(
|
||||
{ ignores: ['dist'] },
|
||||
{
|
||||
extends: [js.configs.recommended, ...tseslint.configs.recommended],
|
||||
files: ['**/*.{ts,tsx}'],
|
||||
languageOptions: {
|
||||
ecmaVersion: 2020,
|
||||
globals: globals.browser,
|
||||
},
|
||||
plugins: {
|
||||
'react-hooks': reactHooks,
|
||||
'react-refresh': reactRefresh,
|
||||
},
|
||||
rules: {
|
||||
...reactHooks.configs.recommended.rules,
|
||||
'react-refresh/only-export-components': 'off',
|
||||
'@typescript-eslint/no-unused-vars': 'off',
|
||||
},
|
||||
},
|
||||
)
|
||||
import prettier from 'eslint-config-prettier';
|
||||
import { includeIgnoreFile } from '@eslint/compat';
|
||||
import js from '@eslint/js';
|
||||
import svelte from 'eslint-plugin-svelte';
|
||||
import globals from 'globals';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
import ts from 'typescript-eslint';
|
||||
import svelteConfig from './svelte.config.js';
|
||||
|
||||
const gitignorePath = fileURLToPath(new URL('./.gitignore', import.meta.url));
|
||||
|
||||
export default ts.config(
|
||||
includeIgnoreFile(gitignorePath),
|
||||
js.configs.recommended,
|
||||
...ts.configs.recommended,
|
||||
...svelte.configs.recommended,
|
||||
prettier,
|
||||
...svelte.configs.prettier,
|
||||
{
|
||||
languageOptions: {
|
||||
globals: { ...globals.browser, ...globals.node }
|
||||
},
|
||||
rules: {
|
||||
// typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects.
|
||||
// see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors
|
||||
'no-undef': 'off',
|
||||
'svelte/no-at-html-tags': 'off'
|
||||
}
|
||||
},
|
||||
{
|
||||
files: ['**/*.svelte', '**/*.svelte.ts', '**/*.svelte.js'],
|
||||
languageOptions: {
|
||||
parserOptions: {
|
||||
projectService: true,
|
||||
extraFileExtensions: ['.svelte'],
|
||||
parser: ts.parser,
|
||||
svelteConfig
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
// Exclude Storybook files from main ESLint rules
|
||||
ignores: ['.storybook/**/*']
|
||||
},
|
||||
storybook.configs['flat/recommended']
|
||||
);
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta
|
||||
name="viewport"
|
||||
content="width=device-width, initial-scale=1, maximum-scale=1"
|
||||
/>
|
||||
<meta name="color-scheme" content="light dark" />
|
||||
<title>🦙 llama.cpp - chat</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.tsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
15102
tools/server/webui/package-lock.json
generated
15102
tools/server/webui/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -1,66 +1,90 @@
|
||||
{
|
||||
"name": "webui",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "npm run format && tsc -b && vite build",
|
||||
"format": "eslint . && prettier --write .",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@heroicons/react": "^2.2.0",
|
||||
"@sec-ant/readable-stream": "^0.6.0",
|
||||
"@tailwindcss/postcss": "^4.1.1",
|
||||
"@tailwindcss/vite": "^4.1.1",
|
||||
"@vscode/markdown-it-katex": "^1.1.1",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"daisyui": "^5.0.12",
|
||||
"dexie": "^4.0.11",
|
||||
"highlight.js": "^11.10.0",
|
||||
"katex": "^0.16.15",
|
||||
"pdfjs-dist": "^5.2.133",
|
||||
"postcss": "^8.4.49",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"react-dropzone": "^14.3.8",
|
||||
"react-hot-toast": "^2.5.2",
|
||||
"react-markdown": "^9.0.3",
|
||||
"react-router": "^7.1.5",
|
||||
"rehype-highlight": "^7.0.2",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0",
|
||||
"tailwindcss": "^4.1.1",
|
||||
"textlinestream": "^1.1.1",
|
||||
"vite-plugin-singlefile": "^2.0.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.17.0",
|
||||
"@types/markdown-it": "^14.1.2",
|
||||
"@types/node": "^22.13.1",
|
||||
"@types/react": "^18.3.18",
|
||||
"@types/react-dom": "^18.3.5",
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"eslint": "^9.17.0",
|
||||
"eslint-plugin-react-hooks": "^5.0.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.16",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^15.14.0",
|
||||
"prettier": "^3.4.2",
|
||||
"sass-embedded": "^1.83.4",
|
||||
"typescript": "~5.6.2",
|
||||
"typescript-eslint": "^8.18.2",
|
||||
"vite": "^6.0.5"
|
||||
},
|
||||
"prettier": {
|
||||
"trailingComma": "es5",
|
||||
"tabWidth": 2,
|
||||
"semi": true,
|
||||
"singleQuote": true,
|
||||
"bracketSameLine": false
|
||||
}
|
||||
"name": "webui",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite dev --host 0.0.0.0 & storybook dev -p 6006 --ci",
|
||||
"build": "vite build && ./scripts/post-build.sh",
|
||||
"preview": "vite preview",
|
||||
"prepare": "svelte-kit sync || echo ''",
|
||||
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
|
||||
"check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
|
||||
"reset": "rm -rf .svelte-kit node_modules",
|
||||
"format": "prettier --write .",
|
||||
"lint": "prettier --check . && eslint .",
|
||||
"test": "npm run test:ui -- --run && npm run test:client -- --run && npm run test:server -- --run && npm run test:e2e",
|
||||
"test:e2e": "playwright test",
|
||||
"test:client": "vitest --project=client",
|
||||
"test:server": "vitest --project=server",
|
||||
"test:ui": "vitest --project=ui",
|
||||
"test:unit": "vitest",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@chromatic-com/storybook": "^4.0.1",
|
||||
"@eslint/compat": "^1.2.5",
|
||||
"@eslint/js": "^9.18.0",
|
||||
"@internationalized/date": "^3.8.2",
|
||||
"@lucide/svelte": "^0.515.0",
|
||||
"@playwright/test": "^1.49.1",
|
||||
"@storybook/addon-a11y": "^9.0.17",
|
||||
"@storybook/addon-docs": "^9.0.17",
|
||||
"@storybook/addon-svelte-csf": "^5.0.7",
|
||||
"@storybook/addon-vitest": "^9.0.17",
|
||||
"@storybook/sveltekit": "^9.0.17",
|
||||
"@sveltejs/adapter-static": "^3.0.8",
|
||||
"@sveltejs/kit": "^2.22.0",
|
||||
"@sveltejs/vite-plugin-svelte": "^6.0.0",
|
||||
"@tailwindcss/forms": "^0.5.9",
|
||||
"@tailwindcss/typography": "^0.5.15",
|
||||
"@tailwindcss/vite": "^4.0.0",
|
||||
"@types/node": "^22",
|
||||
"@vitest/browser": "^3.2.3",
|
||||
"bits-ui": "^2.8.11",
|
||||
"clsx": "^2.1.1",
|
||||
"dexie": "^4.0.11",
|
||||
"eslint": "^9.18.0",
|
||||
"eslint-config-prettier": "^10.0.1",
|
||||
"eslint-plugin-storybook": "^9.0.17",
|
||||
"eslint-plugin-svelte": "^3.0.0",
|
||||
"fflate": "^0.8.2",
|
||||
"globals": "^16.0.0",
|
||||
"mdsvex": "^0.12.3",
|
||||
"playwright": "^1.53.0",
|
||||
"prettier": "^3.4.2",
|
||||
"prettier-plugin-svelte": "^3.3.3",
|
||||
"prettier-plugin-tailwindcss": "^0.6.11",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"storybook": "^9.0.17",
|
||||
"svelte": "^5.0.0",
|
||||
"svelte-check": "^4.0.0",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"tailwind-variants": "^1.0.0",
|
||||
"tailwindcss": "^4.0.0",
|
||||
"tw-animate-css": "^1.3.5",
|
||||
"typescript": "^5.0.0",
|
||||
"typescript-eslint": "^8.20.0",
|
||||
"uuid": "^13.0.0",
|
||||
"vite": "^7.0.4",
|
||||
"vite-plugin-devtools-json": "^0.2.0",
|
||||
"vitest": "^3.2.3",
|
||||
"vitest-browser-svelte": "^0.1.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"highlight.js": "^11.11.1",
|
||||
"mode-watcher": "^1.1.0",
|
||||
"pdfjs-dist": "^5.4.54",
|
||||
"rehype-highlight": "^7.0.2",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark": "^15.0.1",
|
||||
"remark-breaks": "^4.0.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-html": "^16.0.1",
|
||||
"remark-rehype": "^11.1.2",
|
||||
"svelte-sonner": "^1.0.5",
|
||||
"unist-util-visit": "^5.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
9
tools/server/webui/playwright.config.ts
Normal file
9
tools/server/webui/playwright.config.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
import { defineConfig } from '@playwright/test';
|
||||
|
||||
export default defineConfig({
|
||||
webServer: {
|
||||
command: 'npm run build && npx http-server ../public -p 8181',
|
||||
port: 8181
|
||||
},
|
||||
testDir: 'e2e'
|
||||
});
|
||||
@@ -1,5 +0,0 @@
|
||||
export default {
|
||||
plugins: {
|
||||
"@tailwindcss/postcss": {},
|
||||
},
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
{
|
||||
"demo": true,
|
||||
"id": "conv-1734086746930",
|
||||
"lastModified": 1734087548943,
|
||||
"messages": [
|
||||
{
|
||||
"id": 1734086764521,
|
||||
"role": "user",
|
||||
"content": "this is a demo conversation, used in dev mode"
|
||||
},
|
||||
{
|
||||
"id": 1734087548327,
|
||||
"role": "assistant",
|
||||
"content": "This is the formula:\n\n$\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}$\n\nGiven an input vector \\(\\mathbf{x} = [x_1, x_2, \\ldots, x_n]\\)\n\n\\[\ny_i = \\frac{e^{x_i}}{\\sum_{j=1}^n e^{x_j}}\n\\]\n\n$2x + y = z$\n\nCode block latex:\n```latex\n\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}\n```\n\nTest dollar sign: $1234 $4567\n\nInvalid latex syntax: $E = mc^$ and $$E = mc^$$",
|
||||
"timings": {
|
||||
"prompt_n": 1,
|
||||
"prompt_ms": 28.923,
|
||||
"predicted_n": 25,
|
||||
"predicted_ms": 573.016
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1734087548328,
|
||||
"role": "user",
|
||||
"content": "this is a demo conversation, used in dev mode"
|
||||
},
|
||||
{
|
||||
"id": 1734087548329,
|
||||
"role": "assistant",
|
||||
"content": "Code block:\n```js\nconsole.log('hello world')\n```\n```sh\nls -la /dev\n```"
|
||||
}
|
||||
]
|
||||
}
|
||||
123
tools/server/webui/scripts/install-git-hooks.sh
Executable file
123
tools/server/webui/scripts/install-git-hooks.sh
Executable file
@@ -0,0 +1,123 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Script to install pre-commit and post-commit hooks for webui
|
||||
# Pre-commit: formats code and builds, stashes unstaged changes
|
||||
# Post-commit: automatically unstashes changes
|
||||
|
||||
REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
PRE_COMMIT_HOOK="$REPO_ROOT/.git/hooks/pre-commit"
|
||||
POST_COMMIT_HOOK="$REPO_ROOT/.git/hooks/post-commit"
|
||||
|
||||
echo "Installing pre-commit and post-commit hooks for webui..."
|
||||
|
||||
# Create the pre-commit hook
|
||||
cat > "$PRE_COMMIT_HOOK" << 'EOF'
|
||||
#!/bin/bash
|
||||
|
||||
# Check if there are any changes in the webui directory
|
||||
if git diff --cached --name-only | grep -q "^tools/server/webui/"; then
|
||||
echo "Formatting webui code..."
|
||||
|
||||
# Change to webui directory and run format
|
||||
cd tools/server/webui
|
||||
|
||||
# Check if npm is available and package.json exists
|
||||
if [ ! -f "package.json" ]; then
|
||||
echo "Error: package.json not found in tools/server/webui"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Stash any unstaged changes to avoid conflicts during format/build
|
||||
echo "Stashing unstaged changes..."
|
||||
git stash push --keep-index --include-untracked -m "Pre-commit hook: stashed unstaged changes"
|
||||
STASH_CREATED=$?
|
||||
|
||||
# Run the format command
|
||||
npm run format
|
||||
|
||||
# Check if format command succeeded
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run format failed"
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "You can restore your unstaged changes with: git stash pop"
|
||||
fi
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the check command
|
||||
npm run check
|
||||
|
||||
# Check if check command succeeded
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run check failed"
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "You can restore your unstaged changes with: git stash pop"
|
||||
fi
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Run the build command
|
||||
npm run build
|
||||
|
||||
# Check if build command succeeded
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Error: npm run build failed"
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "You can restore your unstaged changes with: git stash pop"
|
||||
fi
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Go back to repo root to add build output
|
||||
cd ../../..
|
||||
|
||||
# Add the build output to staging area
|
||||
git add tools/server/public/index.html.gz
|
||||
|
||||
if [ $STASH_CREATED -eq 0 ]; then
|
||||
echo "✅ Build completed. Your unstaged changes have been stashed."
|
||||
echo "They will be automatically restored after the commit."
|
||||
# Create a marker file to indicate stash was created by pre-commit hook
|
||||
touch .git/WEBUI_STASH_MARKER
|
||||
fi
|
||||
|
||||
echo "Webui code formatted successfully"
|
||||
fi
|
||||
|
||||
exit 0
|
||||
EOF
|
||||
|
||||
# Create the post-commit hook
|
||||
cat > "$POST_COMMIT_HOOK" << 'EOF'
|
||||
#!/bin/bash
|
||||
|
||||
# Check if we have a stash marker from the pre-commit hook
|
||||
if [ -f .git/WEBUI_STASH_MARKER ]; then
|
||||
echo "Restoring your unstaged changes..."
|
||||
git stash pop
|
||||
rm -f .git/WEBUI_STASH_MARKER
|
||||
echo "✅ Your unstaged changes have been restored."
|
||||
fi
|
||||
|
||||
exit 0
|
||||
EOF
|
||||
|
||||
# Make both hooks executable
|
||||
chmod +x "$PRE_COMMIT_HOOK"
|
||||
chmod +x "$POST_COMMIT_HOOK"
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "✅ Pre-commit and post-commit hooks installed successfully!"
|
||||
echo " Pre-commit: $PRE_COMMIT_HOOK"
|
||||
echo " Post-commit: $POST_COMMIT_HOOK"
|
||||
echo ""
|
||||
echo "The hooks will automatically:"
|
||||
echo " • Format and build webui code before commits"
|
||||
echo " • Stash unstaged changes during the process"
|
||||
echo " • Restore your unstaged changes after the commit"
|
||||
echo ""
|
||||
echo "To test the hooks, make a change to a file in the webui directory and commit it."
|
||||
else
|
||||
echo "❌ Failed to make hooks executable"
|
||||
exit 1
|
||||
fi
|
||||
3
tools/server/webui/scripts/post-build.sh
Executable file
3
tools/server/webui/scripts/post-build.sh
Executable file
@@ -0,0 +1,3 @@
|
||||
rm -rf ../public/_app;
|
||||
rm ../public/favicon.svg;
|
||||
rm ../public/index.html;
|
||||
@@ -1,52 +0,0 @@
|
||||
import { HashRouter, Outlet, Route, Routes } from 'react-router';
|
||||
import Header from './components/Header';
|
||||
import Sidebar from './components/Sidebar';
|
||||
import { AppContextProvider, useAppContext } from './utils/app.context';
|
||||
import ChatScreen from './components/ChatScreen';
|
||||
import SettingDialog from './components/SettingDialog';
|
||||
import { Toaster } from 'react-hot-toast';
|
||||
import { ModalProvider } from './components/ModalProvider';
|
||||
|
||||
function App() {
|
||||
return (
|
||||
<ModalProvider>
|
||||
<HashRouter>
|
||||
<div className="flex flex-row drawer lg:drawer-open">
|
||||
<AppContextProvider>
|
||||
<Routes>
|
||||
<Route element={<AppLayout />}>
|
||||
<Route path="/chat/:convId" element={<ChatScreen />} />
|
||||
<Route path="*" element={<ChatScreen />} />
|
||||
</Route>
|
||||
</Routes>
|
||||
</AppContextProvider>
|
||||
</div>
|
||||
</HashRouter>
|
||||
</ModalProvider>
|
||||
);
|
||||
}
|
||||
|
||||
function AppLayout() {
|
||||
const { showSettings, setShowSettings } = useAppContext();
|
||||
return (
|
||||
<>
|
||||
<Sidebar />
|
||||
<main
|
||||
className="drawer-content grow flex flex-col h-screen mx-auto px-4 overflow-auto bg-base-100"
|
||||
id="main-scroll"
|
||||
>
|
||||
<Header />
|
||||
<Outlet />
|
||||
</main>
|
||||
{
|
||||
<SettingDialog
|
||||
show={showSettings}
|
||||
onClose={() => setShowSettings(false)}
|
||||
/>
|
||||
}
|
||||
<Toaster />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
@@ -1,96 +0,0 @@
|
||||
import daisyuiThemes from 'daisyui/theme/object';
|
||||
import { isNumeric } from './utils/misc';
|
||||
|
||||
export const isDev = import.meta.env.MODE === 'development';
|
||||
|
||||
// constants
|
||||
export const BASE_URL = new URL('.', document.baseURI).href
|
||||
.toString()
|
||||
.replace(/\/$/, '');
|
||||
|
||||
export const CONFIG_DEFAULT = {
|
||||
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value.
|
||||
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
|
||||
apiKey: '',
|
||||
systemMessage: '',
|
||||
showTokensPerSecond: false,
|
||||
showThoughtInProgress: false,
|
||||
excludeThoughtOnReq: true,
|
||||
pasteLongTextToFileLen: 2500,
|
||||
pdfAsImage: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'edkypmxt',
|
||||
temperature: 0.8,
|
||||
dynatemp_range: 0.0,
|
||||
dynatemp_exponent: 1.0,
|
||||
top_k: 40,
|
||||
top_p: 0.95,
|
||||
min_p: 0.05,
|
||||
xtc_probability: 0.0,
|
||||
xtc_threshold: 0.1,
|
||||
typical_p: 1.0,
|
||||
repeat_last_n: 64,
|
||||
repeat_penalty: 1.0,
|
||||
presence_penalty: 0.0,
|
||||
frequency_penalty: 0.0,
|
||||
dry_multiplier: 0.0,
|
||||
dry_base: 1.75,
|
||||
dry_allowed_length: 2,
|
||||
dry_penalty_last_n: -1,
|
||||
max_tokens: -1,
|
||||
custom: '', // custom json-stringified object
|
||||
// experimental features
|
||||
pyIntepreterEnabled: false,
|
||||
};
|
||||
export const CONFIG_INFO: Record<string, string> = {
|
||||
apiKey: 'Set the API Key if you are using --api-key option for the server.',
|
||||
systemMessage: 'The starting message that defines how model should behave.',
|
||||
pasteLongTextToFileLen:
|
||||
'On pasting long text, it will be converted to a file. You can control the file length by setting the value of this parameter. Value 0 means disable.',
|
||||
samplers:
|
||||
'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
|
||||
temperature:
|
||||
'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.',
|
||||
dynatemp_range:
|
||||
'Addon for the temperature sampler. The added value to the range of dynamic temperature, which adjusts probabilities by entropy of tokens.',
|
||||
dynatemp_exponent:
|
||||
'Addon for the temperature sampler. Smoothes out the probability redistribution based on the most probable token.',
|
||||
top_k: 'Keeps only k top tokens.',
|
||||
top_p:
|
||||
'Limits tokens to those that together have a cumulative probability of at least p',
|
||||
min_p:
|
||||
'Limits tokens based on the minimum probability for a token to be considered, relative to the probability of the most likely token.',
|
||||
xtc_probability:
|
||||
'XTC sampler cuts out top tokens; this parameter controls the chance of cutting tokens at all. 0 disables XTC.',
|
||||
xtc_threshold:
|
||||
'XTC sampler cuts out top tokens; this parameter controls the token probability that is required to cut that token.',
|
||||
typical_p:
|
||||
'Sorts and limits tokens based on the difference between log-probability and entropy.',
|
||||
repeat_last_n: 'Last n tokens to consider for penalizing repetition',
|
||||
repeat_penalty:
|
||||
'Controls the repetition of token sequences in the generated text',
|
||||
presence_penalty:
|
||||
'Limits tokens based on whether they appear in the output or not.',
|
||||
frequency_penalty:
|
||||
'Limits tokens based on how often they appear in the output.',
|
||||
dry_multiplier:
|
||||
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling multiplier.',
|
||||
dry_base:
|
||||
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling base value.',
|
||||
dry_allowed_length:
|
||||
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the allowed length for DRY sampling.',
|
||||
dry_penalty_last_n:
|
||||
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets DRY penalty for the last n tokens.',
|
||||
max_tokens: 'The maximum number of token per output.',
|
||||
custom: '', // custom json-stringified object
|
||||
};
|
||||
// config keys having numeric value (i.e. temperature, top_k, top_p, etc)
|
||||
export const CONFIG_NUMERIC_KEYS = Object.entries(CONFIG_DEFAULT)
|
||||
.filter((e) => isNumeric(e[1]))
|
||||
.map((e) => e[0]);
|
||||
// list of themes supported by daisyui
|
||||
export const THEMES = ['light', 'dark']
|
||||
// make sure light & dark are always at the beginning
|
||||
.concat(
|
||||
Object.keys(daisyuiThemes).filter((t) => t !== 'light' && t !== 'dark')
|
||||
);
|
||||
123
tools/server/webui/src/app.css
Normal file
123
tools/server/webui/src/app.css
Normal file
@@ -0,0 +1,123 @@
|
||||
@import 'tailwindcss';
|
||||
|
||||
@import 'tw-animate-css';
|
||||
|
||||
@custom-variant dark (&:is(.dark *));
|
||||
|
||||
:root {
|
||||
--radius: 0.625rem;
|
||||
--background: oklch(1 0 0);
|
||||
--foreground: oklch(0.145 0 0);
|
||||
--card: oklch(1 0 0);
|
||||
--card-foreground: oklch(0.145 0 0);
|
||||
--popover: oklch(1 0 0);
|
||||
--popover-foreground: oklch(0.145 0 0);
|
||||
--primary: oklch(0.205 0 0);
|
||||
--primary-foreground: oklch(0.985 0 0);
|
||||
--secondary: oklch(0.97 0 0);
|
||||
--secondary-foreground: oklch(0.205 0 0);
|
||||
--muted: oklch(0.97 0 0);
|
||||
--muted-foreground: oklch(0.556 0 0);
|
||||
--accent: oklch(0.97 0 0);
|
||||
--accent-foreground: oklch(0.205 0 0);
|
||||
--destructive: oklch(0.577 0.245 27.325);
|
||||
--border: oklch(0.875 0 0);
|
||||
--input: oklch(0.92 0 0);
|
||||
--ring: oklch(0.708 0 0);
|
||||
--chart-1: oklch(0.646 0.222 41.116);
|
||||
--chart-2: oklch(0.6 0.118 184.704);
|
||||
--chart-3: oklch(0.398 0.07 227.392);
|
||||
--chart-4: oklch(0.828 0.189 84.429);
|
||||
--chart-5: oklch(0.769 0.188 70.08);
|
||||
--sidebar: oklch(0.985 0 0);
|
||||
--sidebar-foreground: oklch(0.145 0 0);
|
||||
--sidebar-primary: oklch(0.205 0 0);
|
||||
--sidebar-primary-foreground: oklch(0.985 0 0);
|
||||
--sidebar-accent: oklch(0.97 0 0);
|
||||
--sidebar-accent-foreground: oklch(0.205 0 0);
|
||||
--sidebar-border: oklch(0.922 0 0);
|
||||
--sidebar-ring: oklch(0.708 0 0);
|
||||
--code-background: oklch(0.225 0 0);
|
||||
--code-foreground: oklch(0.875 0 0);
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background: oklch(0.16 0 0);
|
||||
--foreground: oklch(0.985 0 0);
|
||||
--card: oklch(0.205 0 0);
|
||||
--card-foreground: oklch(0.985 0 0);
|
||||
--popover: oklch(0.205 0 0);
|
||||
--popover-foreground: oklch(0.985 0 0);
|
||||
--primary: oklch(0.922 0 0);
|
||||
--primary-foreground: oklch(0.205 0 0);
|
||||
--secondary: oklch(0.269 0 0);
|
||||
--secondary-foreground: oklch(0.985 0 0);
|
||||
--muted: oklch(0.269 0 0);
|
||||
--muted-foreground: oklch(0.708 0 0);
|
||||
--accent: oklch(0.269 0 0);
|
||||
--accent-foreground: oklch(0.985 0 0);
|
||||
--destructive: oklch(0.704 0.191 22.216);
|
||||
--border: oklch(1 0 0 / 30%);
|
||||
--input: oklch(1 0 0 / 30%);
|
||||
--ring: oklch(0.556 0 0);
|
||||
--chart-1: oklch(0.488 0.243 264.376);
|
||||
--chart-2: oklch(0.696 0.17 162.48);
|
||||
--chart-3: oklch(0.769 0.188 70.08);
|
||||
--chart-4: oklch(0.627 0.265 303.9);
|
||||
--chart-5: oklch(0.645 0.246 16.439);
|
||||
--sidebar: oklch(0.205 0 0);
|
||||
--sidebar-foreground: oklch(0.985 0 0);
|
||||
--sidebar-primary: oklch(0.488 0.243 264.376);
|
||||
--sidebar-primary-foreground: oklch(0.985 0 0);
|
||||
--sidebar-accent: oklch(0.269 0 0);
|
||||
--sidebar-accent-foreground: oklch(0.985 0 0);
|
||||
--sidebar-border: oklch(1 0 0 / 10%);
|
||||
--sidebar-ring: oklch(0.556 0 0);
|
||||
}
|
||||
|
||||
@theme inline {
|
||||
--radius-sm: calc(var(--radius) - 4px);
|
||||
--radius-md: calc(var(--radius) - 2px);
|
||||
--radius-lg: var(--radius);
|
||||
--radius-xl: calc(var(--radius) + 4px);
|
||||
--color-background: var(--background);
|
||||
--color-foreground: var(--foreground);
|
||||
--color-card: var(--card);
|
||||
--color-card-foreground: var(--card-foreground);
|
||||
--color-popover: var(--popover);
|
||||
--color-popover-foreground: var(--popover-foreground);
|
||||
--color-primary: var(--primary);
|
||||
--color-primary-foreground: var(--primary-foreground);
|
||||
--color-secondary: var(--secondary);
|
||||
--color-secondary-foreground: var(--secondary-foreground);
|
||||
--color-muted: var(--muted);
|
||||
--color-muted-foreground: var(--muted-foreground);
|
||||
--color-accent: var(--accent);
|
||||
--color-accent-foreground: var(--accent-foreground);
|
||||
--color-destructive: var(--destructive);
|
||||
--color-border: var(--border);
|
||||
--color-input: var(--input);
|
||||
--color-ring: var(--ring);
|
||||
--color-chart-1: var(--chart-1);
|
||||
--color-chart-2: var(--chart-2);
|
||||
--color-chart-3: var(--chart-3);
|
||||
--color-chart-4: var(--chart-4);
|
||||
--color-chart-5: var(--chart-5);
|
||||
--color-sidebar: var(--sidebar);
|
||||
--color-sidebar-foreground: var(--sidebar-foreground);
|
||||
--color-sidebar-primary: var(--sidebar-primary);
|
||||
--color-sidebar-primary-foreground: var(--sidebar-primary-foreground);
|
||||
--color-sidebar-accent: var(--sidebar-accent);
|
||||
--color-sidebar-accent-foreground: var(--sidebar-accent-foreground);
|
||||
--color-sidebar-border: var(--sidebar-border);
|
||||
--color-sidebar-ring: var(--sidebar-ring);
|
||||
}
|
||||
|
||||
@layer base {
|
||||
* {
|
||||
@apply border-border outline-ring/50;
|
||||
}
|
||||
body {
|
||||
@apply bg-background text-foreground;
|
||||
}
|
||||
}
|
||||
81
tools/server/webui/src/app.d.ts
vendored
Normal file
81
tools/server/webui/src/app.d.ts
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
// See https://svelte.dev/docs/kit/types#app.d.ts
|
||||
// for information about these interfaces
|
||||
|
||||
// Import chat types from dedicated module
|
||||
|
||||
import type {
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatCompletionResponse,
|
||||
ApiChatCompletionStreamChunk,
|
||||
ApiChatMessageData,
|
||||
ApiChatMessageContentPart,
|
||||
ApiContextSizeError,
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiProcessingState
|
||||
} from '$lib/types/api';
|
||||
|
||||
import type {
|
||||
ChatMessageType,
|
||||
ChatRole,
|
||||
ChatUploadedFile,
|
||||
ChatMessageSiblingInfo,
|
||||
ChatMessagePromptProgress,
|
||||
ChatMessageTimings
|
||||
} from '$lib/types/chat';
|
||||
|
||||
import type {
|
||||
DatabaseConversation,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraTextFile,
|
||||
DatabaseMessageExtraPdfFile
|
||||
} from '$lib/types/database';
|
||||
|
||||
import type {
|
||||
SettingsConfigValue,
|
||||
SettingsFieldConfig,
|
||||
SettingsConfigType
|
||||
} from '$lib/types/settings';
|
||||
|
||||
declare global {
|
||||
// namespace App {
|
||||
// interface Error {}
|
||||
// interface Locals {}
|
||||
// interface PageData {}
|
||||
// interface PageState {}
|
||||
// interface Platform {}
|
||||
// }
|
||||
|
||||
export {
|
||||
ApiChatCompletionRequest,
|
||||
ApiChatCompletionResponse,
|
||||
ApiChatCompletionStreamChunk,
|
||||
ApiChatMessageData,
|
||||
ApiChatMessageContentPart,
|
||||
ApiContextSizeError,
|
||||
ApiErrorResponse,
|
||||
ApiLlamaCppServerProps,
|
||||
ApiProcessingState,
|
||||
ChatMessageData,
|
||||
ChatMessagePromptProgress,
|
||||
ChatMessageSiblingInfo,
|
||||
ChatMessageTimings,
|
||||
ChatMessageType,
|
||||
ChatRole,
|
||||
ChatUploadedFile,
|
||||
DatabaseConversation,
|
||||
DatabaseMessage,
|
||||
DatabaseMessageExtra,
|
||||
DatabaseMessageExtraAudioFile,
|
||||
DatabaseMessageExtraImageFile,
|
||||
DatabaseMessageExtraTextFile,
|
||||
DatabaseMessageExtraPdfFile,
|
||||
SettingsConfigValue,
|
||||
SettingsFieldConfig,
|
||||
SettingsConfigType,
|
||||
SettingsChatServiceOptions
|
||||
};
|
||||
}
|
||||
12
tools/server/webui/src/app.html
Normal file
12
tools/server/webui/src/app.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<link rel="icon" href="%sveltekit.assets%/favicon.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
%sveltekit.head%
|
||||
</head>
|
||||
<body data-sveltekit-preload-data="hover">
|
||||
<div style="display: contents">%sveltekit.body%</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,195 +0,0 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { OpenInNewTab, XCloseButton } from '../utils/common';
|
||||
import { CanvasType } from '../utils/types';
|
||||
import { PlayIcon, StopIcon } from '@heroicons/react/24/outline';
|
||||
import StorageUtils from '../utils/storage';
|
||||
|
||||
const canInterrupt = typeof SharedArrayBuffer === 'function';
|
||||
|
||||
// adapted from https://pyodide.org/en/stable/usage/webworker.html
|
||||
const WORKER_CODE = `
|
||||
importScripts("https://cdn.jsdelivr.net/pyodide/v0.27.2/full/pyodide.js");
|
||||
|
||||
let stdOutAndErr = [];
|
||||
|
||||
let pyodideReadyPromise = loadPyodide({
|
||||
stdout: (data) => stdOutAndErr.push(data),
|
||||
stderr: (data) => stdOutAndErr.push(data),
|
||||
});
|
||||
|
||||
let alreadySetBuff = false;
|
||||
|
||||
self.onmessage = async (event) => {
|
||||
stdOutAndErr = [];
|
||||
|
||||
// make sure loading is done
|
||||
const pyodide = await pyodideReadyPromise;
|
||||
const { id, python, context, interruptBuffer } = event.data;
|
||||
|
||||
if (interruptBuffer && !alreadySetBuff) {
|
||||
pyodide.setInterruptBuffer(interruptBuffer);
|
||||
alreadySetBuff = true;
|
||||
}
|
||||
|
||||
// Now load any packages we need, run the code, and send the result back.
|
||||
await pyodide.loadPackagesFromImports(python);
|
||||
|
||||
// make a Python dictionary with the data from content
|
||||
const dict = pyodide.globals.get("dict");
|
||||
const globals = dict(Object.entries(context));
|
||||
try {
|
||||
self.postMessage({ id, running: true });
|
||||
// Execute the python code in this context
|
||||
const result = pyodide.runPython(python, { globals });
|
||||
self.postMessage({ result, id, stdOutAndErr });
|
||||
} catch (error) {
|
||||
self.postMessage({ error: error.message, id });
|
||||
}
|
||||
interruptBuffer[0] = 0;
|
||||
};
|
||||
`;
|
||||
|
||||
let worker: Worker;
|
||||
const interruptBuffer = canInterrupt
|
||||
? new Uint8Array(new SharedArrayBuffer(1))
|
||||
: null;
|
||||
|
||||
const startWorker = () => {
|
||||
if (!worker) {
|
||||
worker = new Worker(
|
||||
URL.createObjectURL(new Blob([WORKER_CODE], { type: 'text/javascript' }))
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
if (StorageUtils.getConfig().pyIntepreterEnabled) {
|
||||
startWorker();
|
||||
}
|
||||
|
||||
const runCodeInWorker = (
|
||||
pyCode: string,
|
||||
callbackRunning: () => void
|
||||
): {
|
||||
donePromise: Promise<string>;
|
||||
interrupt: () => void;
|
||||
} => {
|
||||
startWorker();
|
||||
const id = Math.random() * 1e8;
|
||||
const context = {};
|
||||
if (interruptBuffer) {
|
||||
interruptBuffer[0] = 0;
|
||||
}
|
||||
|
||||
const donePromise = new Promise<string>((resolve) => {
|
||||
worker.onmessage = (event) => {
|
||||
const { error, stdOutAndErr, running } = event.data;
|
||||
if (id !== event.data.id) return;
|
||||
if (running) {
|
||||
callbackRunning();
|
||||
return;
|
||||
} else if (error) {
|
||||
resolve(error.toString());
|
||||
} else {
|
||||
resolve(stdOutAndErr.join('\n'));
|
||||
}
|
||||
};
|
||||
worker.postMessage({ id, python: pyCode, context, interruptBuffer });
|
||||
});
|
||||
|
||||
const interrupt = () => {
|
||||
console.log('Interrupting...');
|
||||
console.trace();
|
||||
if (interruptBuffer) {
|
||||
interruptBuffer[0] = 2;
|
||||
}
|
||||
};
|
||||
|
||||
return { donePromise, interrupt };
|
||||
};
|
||||
|
||||
export default function CanvasPyInterpreter() {
|
||||
const { canvasData, setCanvasData } = useAppContext();
|
||||
|
||||
const [code, setCode] = useState(canvasData?.content ?? ''); // copy to avoid direct mutation
|
||||
const [running, setRunning] = useState(false);
|
||||
const [output, setOutput] = useState('');
|
||||
const [interruptFn, setInterruptFn] = useState<() => void>();
|
||||
const [showStopBtn, setShowStopBtn] = useState(false);
|
||||
|
||||
const runCode = async (pycode: string) => {
|
||||
interruptFn?.();
|
||||
setRunning(true);
|
||||
setOutput('Loading Pyodide...');
|
||||
const { donePromise, interrupt } = runCodeInWorker(pycode, () => {
|
||||
setOutput('Running...');
|
||||
setShowStopBtn(canInterrupt);
|
||||
});
|
||||
setInterruptFn(() => interrupt);
|
||||
const out = await donePromise;
|
||||
setOutput(out);
|
||||
setRunning(false);
|
||||
setShowStopBtn(false);
|
||||
};
|
||||
|
||||
// run code on mount
|
||||
useEffect(() => {
|
||||
setCode(canvasData?.content ?? '');
|
||||
runCode(canvasData?.content ?? '');
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [canvasData?.content]);
|
||||
|
||||
if (canvasData?.type !== CanvasType.PY_INTERPRETER) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="card bg-base-200 w-full h-full shadow-xl">
|
||||
<div className="card-body">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<span className="text-lg font-bold">Python Interpreter</span>
|
||||
<XCloseButton
|
||||
className="bg-base-100"
|
||||
onClick={() => setCanvasData(null)}
|
||||
/>
|
||||
</div>
|
||||
<div className="grid grid-rows-3 gap-4 h-full">
|
||||
<textarea
|
||||
className="textarea textarea-bordered w-full h-full font-mono"
|
||||
value={code}
|
||||
onChange={(e) => setCode(e.target.value)}
|
||||
></textarea>
|
||||
<div className="font-mono flex flex-col row-span-2">
|
||||
<div className="flex items-center mb-2">
|
||||
<button
|
||||
className="btn btn-sm bg-base-100"
|
||||
onClick={() => runCode(code)}
|
||||
disabled={running}
|
||||
>
|
||||
<PlayIcon className="h-6 w-6" /> Run
|
||||
</button>
|
||||
{showStopBtn && (
|
||||
<button
|
||||
className="btn btn-sm bg-base-100 ml-2"
|
||||
onClick={() => interruptFn?.()}
|
||||
>
|
||||
<StopIcon className="h-6 w-6" /> Stop
|
||||
</button>
|
||||
)}
|
||||
<span className="grow text-right text-xs">
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/11762">
|
||||
Report a bug
|
||||
</OpenInNewTab>
|
||||
</span>
|
||||
</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-full dark-color"
|
||||
value={output}
|
||||
readOnly
|
||||
></textarea>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
import {
|
||||
DocumentTextIcon,
|
||||
SpeakerWaveIcon,
|
||||
XMarkIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { MessageExtra } from '../utils/types';
|
||||
import { useState } from 'react';
|
||||
import { classNames } from '../utils/misc';
|
||||
|
||||
export default function ChatInputExtraContextItem({
|
||||
items,
|
||||
removeItem,
|
||||
clickToShow,
|
||||
}: {
|
||||
items?: MessageExtra[];
|
||||
removeItem?: (index: number) => void;
|
||||
clickToShow?: boolean;
|
||||
}) {
|
||||
const [show, setShow] = useState(-1);
|
||||
const showingItem = show >= 0 ? items?.[show] : undefined;
|
||||
|
||||
if (!items) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-row gap-4 overflow-x-auto py-2 px-1 mb-1"
|
||||
role="group"
|
||||
aria-description="Selected files"
|
||||
>
|
||||
{items.map((item, i) => (
|
||||
<div
|
||||
className="indicator"
|
||||
key={i}
|
||||
onClick={() => clickToShow && setShow(i)}
|
||||
tabIndex={0}
|
||||
aria-description={
|
||||
clickToShow ? `Click to show: ${item.name}` : undefined
|
||||
}
|
||||
role={clickToShow ? 'button' : 'menuitem'}
|
||||
>
|
||||
{removeItem && (
|
||||
<div className="indicator-item indicator-top">
|
||||
<button
|
||||
aria-label="Remove file"
|
||||
className="btn btn-neutral btn-sm w-4 h-4 p-0 rounded-full"
|
||||
onClick={() => removeItem(i)}
|
||||
>
|
||||
<XMarkIcon className="h-3 w-3" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
className={classNames({
|
||||
'flex flex-row rounded-md shadow-sm items-center m-0 p-0': true,
|
||||
'cursor-pointer hover:shadow-md': !!clickToShow,
|
||||
})}
|
||||
>
|
||||
{item.type === 'imageFile' ? (
|
||||
<>
|
||||
<img
|
||||
src={item.base64Url}
|
||||
alt={`Preview image for ${item.name}`}
|
||||
className="w-14 h-14 object-cover rounded-md"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div
|
||||
className="w-14 h-14 flex items-center justify-center"
|
||||
aria-description="Document icon"
|
||||
>
|
||||
{item.type === 'audioFile' ? (
|
||||
<SpeakerWaveIcon className="h-8 w-8 text-gray-500" />
|
||||
) : (
|
||||
<DocumentTextIcon className="h-8 w-8 text-gray-500" />
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="text-xs pr-4">
|
||||
<b>{item.name ?? 'Extra content'}</b>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{showingItem && (
|
||||
<dialog
|
||||
className="modal modal-open"
|
||||
aria-description={`Preview ${showingItem.name}`}
|
||||
>
|
||||
<div className="modal-box">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<b>{showingItem.name ?? 'Extra content'}</b>
|
||||
<button
|
||||
className="btn btn-ghost btn-sm"
|
||||
aria-label="Close preview dialog"
|
||||
>
|
||||
<XMarkIcon className="h-5 w-5" onClick={() => setShow(-1)} />
|
||||
</button>
|
||||
</div>
|
||||
{showingItem.type === 'imageFile' ? (
|
||||
<img
|
||||
src={showingItem.base64Url}
|
||||
alt={`Preview image for ${showingItem.name}`}
|
||||
/>
|
||||
) : showingItem.type === 'audioFile' ? (
|
||||
<audio
|
||||
controls
|
||||
className="w-full"
|
||||
aria-description={`Audio file ${showingItem.name}`}
|
||||
>
|
||||
<source
|
||||
src={`data:${showingItem.mimeType};base64,${showingItem.base64Data}`}
|
||||
type={showingItem.mimeType}
|
||||
aria-description={`Audio file ${showingItem.name}`}
|
||||
/>
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
) : (
|
||||
<div className="overflow-x-auto">
|
||||
<pre className="whitespace-pre-wrap break-words text-sm">
|
||||
{showingItem.content}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="modal-backdrop" onClick={() => setShow(-1)}></div>
|
||||
</dialog>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
import { useMemo, useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { Message, PendingMessage } from '../utils/types';
|
||||
import { classNames } from '../utils/misc';
|
||||
import MarkdownDisplay, { CopyButton } from './MarkdownDisplay';
|
||||
import {
|
||||
ArrowPathIcon,
|
||||
ChevronLeftIcon,
|
||||
ChevronRightIcon,
|
||||
PencilSquareIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import ChatInputExtraContextItem from './ChatInputExtraContextItem';
|
||||
import { BtnWithTooltips } from '../utils/common';
|
||||
|
||||
interface SplitMessage {
|
||||
content: PendingMessage['content'];
|
||||
thought?: string;
|
||||
isThinking?: boolean;
|
||||
}
|
||||
|
||||
export default function ChatMessage({
|
||||
msg,
|
||||
siblingLeafNodeIds,
|
||||
siblingCurrIdx,
|
||||
id,
|
||||
onRegenerateMessage,
|
||||
onEditMessage,
|
||||
onChangeSibling,
|
||||
isPending,
|
||||
}: {
|
||||
msg: Message | PendingMessage;
|
||||
siblingLeafNodeIds: Message['id'][];
|
||||
siblingCurrIdx: number;
|
||||
id?: string;
|
||||
onRegenerateMessage(msg: Message): void;
|
||||
onEditMessage(msg: Message, content: string): void;
|
||||
onChangeSibling(sibling: Message['id']): void;
|
||||
isPending?: boolean;
|
||||
}) {
|
||||
const { viewingChat, config } = useAppContext();
|
||||
const [editingContent, setEditingContent] = useState<string | null>(null);
|
||||
const timings = useMemo(
|
||||
() =>
|
||||
msg.timings
|
||||
? {
|
||||
...msg.timings,
|
||||
prompt_per_second:
|
||||
(msg.timings.prompt_n / msg.timings.prompt_ms) * 1000,
|
||||
predicted_per_second:
|
||||
(msg.timings.predicted_n / msg.timings.predicted_ms) * 1000,
|
||||
}
|
||||
: null,
|
||||
[msg.timings]
|
||||
);
|
||||
const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1];
|
||||
const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1];
|
||||
|
||||
// for reasoning model, we split the message into content and thought
|
||||
// TODO: implement this as remark/rehype plugin in the future
|
||||
const { content, thought, isThinking }: SplitMessage = useMemo(() => {
|
||||
if (msg.content === null || msg.role !== 'assistant') {
|
||||
return { content: msg.content };
|
||||
}
|
||||
const REGEX_THINK_OPEN = /<think>|<\|channel\|>analysis<\|message\|>/;
|
||||
const REGEX_THINK_CLOSE = /<\/think>|<\|end\|>/;
|
||||
let actualContent = '';
|
||||
let thought = '';
|
||||
let isThinking = false;
|
||||
let thinkSplit = msg.content.split(REGEX_THINK_OPEN, 2);
|
||||
actualContent += thinkSplit[0];
|
||||
while (thinkSplit[1] !== undefined) {
|
||||
// <think> tag found
|
||||
thinkSplit = thinkSplit[1].split(REGEX_THINK_CLOSE, 2);
|
||||
thought += thinkSplit[0];
|
||||
isThinking = true;
|
||||
if (thinkSplit[1] !== undefined) {
|
||||
// </think> closing tag found
|
||||
isThinking = false;
|
||||
thinkSplit = thinkSplit[1].split(REGEX_THINK_OPEN, 2);
|
||||
actualContent += thinkSplit[0];
|
||||
}
|
||||
}
|
||||
return { content: actualContent, thought, isThinking };
|
||||
}, [msg]);
|
||||
|
||||
if (!viewingChat) return null;
|
||||
|
||||
const isUser = msg.role === 'user';
|
||||
|
||||
return (
|
||||
<div
|
||||
className="group"
|
||||
id={id}
|
||||
role="group"
|
||||
aria-description={`Message from ${msg.role}`}
|
||||
>
|
||||
<div
|
||||
className={classNames({
|
||||
chat: true,
|
||||
'chat-start': !isUser,
|
||||
'chat-end': isUser,
|
||||
})}
|
||||
>
|
||||
{msg.extra && msg.extra.length > 0 && (
|
||||
<ChatInputExtraContextItem items={msg.extra} clickToShow />
|
||||
)}
|
||||
|
||||
<div
|
||||
className={classNames({
|
||||
'chat-bubble markdown': true,
|
||||
'chat-bubble bg-transparent': !isUser,
|
||||
})}
|
||||
>
|
||||
{/* textarea for editing message */}
|
||||
{editingContent !== null && (
|
||||
<>
|
||||
<textarea
|
||||
dir="auto"
|
||||
className="textarea textarea-bordered bg-base-100 text-base-content max-w-2xl w-[calc(90vw-8em)] h-24"
|
||||
value={editingContent}
|
||||
onChange={(e) => setEditingContent(e.target.value)}
|
||||
></textarea>
|
||||
<br />
|
||||
<button
|
||||
className="btn btn-ghost mt-2 mr-2"
|
||||
onClick={() => setEditingContent(null)}
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
className="btn mt-2"
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
setEditingContent(null);
|
||||
onEditMessage(msg as Message, editingContent);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
{/* not editing content, render message */}
|
||||
{editingContent === null && (
|
||||
<>
|
||||
{content === null ? (
|
||||
<>
|
||||
{/* show loading dots for pending message */}
|
||||
<span className="loading loading-dots loading-md"></span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
{/* render message as markdown */}
|
||||
<div dir="auto" tabIndex={0}>
|
||||
{thought && (
|
||||
<ThoughtProcess
|
||||
isThinking={!!isThinking && !!isPending}
|
||||
content={thought}
|
||||
open={config.showThoughtInProgress}
|
||||
/>
|
||||
)}
|
||||
|
||||
<MarkdownDisplay
|
||||
content={content}
|
||||
isGenerating={isPending}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{/* render timings if enabled */}
|
||||
{timings && config.showTokensPerSecond && (
|
||||
<div className="dropdown dropdown-hover dropdown-top mt-2">
|
||||
<div
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
className="cursor-pointer font-semibold text-sm opacity-60"
|
||||
>
|
||||
Speed: {timings.predicted_per_second.toFixed(1)} t/s
|
||||
</div>
|
||||
<div className="dropdown-content bg-base-100 z-10 w-64 p-2 shadow mt-4">
|
||||
<b>Prompt</b>
|
||||
<br />- Tokens: {timings.prompt_n}
|
||||
<br />- Time: {timings.prompt_ms} ms
|
||||
<br />- Speed: {timings.prompt_per_second.toFixed(1)} t/s
|
||||
<br />
|
||||
<b>Generation</b>
|
||||
<br />- Tokens: {timings.predicted_n}
|
||||
<br />- Time: {timings.predicted_ms} ms
|
||||
<br />- Speed: {timings.predicted_per_second.toFixed(1)} t/s
|
||||
<br />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* actions for each message */}
|
||||
{msg.content !== null && (
|
||||
<div
|
||||
className={classNames({
|
||||
'flex items-center gap-2 mx-4 mt-2 mb-2': true,
|
||||
'flex-row-reverse': msg.role === 'user',
|
||||
})}
|
||||
>
|
||||
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
|
||||
<div
|
||||
className="flex gap-1 items-center opacity-60 text-sm"
|
||||
role="navigation"
|
||||
aria-description={`Message version ${siblingCurrIdx + 1} of ${siblingLeafNodeIds.length}`}
|
||||
>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !prevSibling,
|
||||
})}
|
||||
onClick={() => prevSibling && onChangeSibling(prevSibling)}
|
||||
aria-label="Previous message version"
|
||||
>
|
||||
<ChevronLeftIcon className="h-4 w-4" />
|
||||
</button>
|
||||
<span>
|
||||
{siblingCurrIdx + 1} / {siblingLeafNodeIds.length}
|
||||
</span>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-ghost p-1': true,
|
||||
'opacity-20': !nextSibling,
|
||||
})}
|
||||
onClick={() => nextSibling && onChangeSibling(nextSibling)}
|
||||
aria-label="Next message version"
|
||||
>
|
||||
<ChevronRightIcon className="h-4 w-4" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* user message */}
|
||||
{msg.role === 'user' && (
|
||||
<BtnWithTooltips
|
||||
className="btn-mini w-8 h-8"
|
||||
onClick={() => setEditingContent(msg.content)}
|
||||
disabled={msg.content === null}
|
||||
tooltipsContent="Edit message"
|
||||
>
|
||||
<PencilSquareIcon className="h-4 w-4" />
|
||||
</BtnWithTooltips>
|
||||
)}
|
||||
{/* assistant message */}
|
||||
{msg.role === 'assistant' && (
|
||||
<>
|
||||
{!isPending && (
|
||||
<BtnWithTooltips
|
||||
className="btn-mini w-8 h-8"
|
||||
onClick={() => {
|
||||
if (msg.content !== null) {
|
||||
onRegenerateMessage(msg as Message);
|
||||
}
|
||||
}}
|
||||
disabled={msg.content === null}
|
||||
tooltipsContent="Regenerate response"
|
||||
>
|
||||
<ArrowPathIcon className="h-4 w-4" />
|
||||
</BtnWithTooltips>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<CopyButton className="btn-mini w-8 h-8" content={msg.content} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ThoughtProcess({
|
||||
isThinking,
|
||||
content,
|
||||
open,
|
||||
}: {
|
||||
isThinking: boolean;
|
||||
content: string;
|
||||
open: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
role="button"
|
||||
aria-label="Toggle thought process display"
|
||||
tabIndex={0}
|
||||
className={classNames({
|
||||
'collapse bg-none': true,
|
||||
})}
|
||||
>
|
||||
<input type="checkbox" defaultChecked={open} />
|
||||
<div className="collapse-title px-0">
|
||||
<div className="btn rounded-xl">
|
||||
{isThinking ? (
|
||||
<span>
|
||||
<span
|
||||
className="loading loading-spinner loading-md mr-2"
|
||||
style={{ verticalAlign: 'middle' }}
|
||||
></span>
|
||||
Thinking
|
||||
</span>
|
||||
) : (
|
||||
<>Thought Process</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className="collapse-content text-base-content/70 text-sm p-1"
|
||||
tabIndex={0}
|
||||
aria-description="Thought process content"
|
||||
>
|
||||
<div className="border-l-2 border-base-content/20 pl-4 mb-4">
|
||||
<MarkdownDisplay content={content} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,459 +0,0 @@
|
||||
import { ClipboardEvent, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
|
||||
import ChatMessage from './ChatMessage';
|
||||
import { CanvasType, Message, PendingMessage } from '../utils/types';
|
||||
import { classNames, cleanCurrentUrl } from '../utils/misc';
|
||||
import CanvasPyInterpreter from './CanvasPyInterpreter';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useVSCodeContext } from '../utils/llama-vscode';
|
||||
import { useChatTextarea, ChatTextareaApi } from './useChatTextarea.ts';
|
||||
import {
|
||||
ArrowUpIcon,
|
||||
StopIcon,
|
||||
PaperClipIcon,
|
||||
} from '@heroicons/react/24/solid';
|
||||
import {
|
||||
ChatExtraContextApi,
|
||||
useChatExtraContext,
|
||||
} from './useChatExtraContext.tsx';
|
||||
import Dropzone from 'react-dropzone';
|
||||
import toast from 'react-hot-toast';
|
||||
import ChatInputExtraContextItem from './ChatInputExtraContextItem.tsx';
|
||||
import { scrollToBottom, useChatScroll } from './useChatScroll.tsx';
|
||||
|
||||
/**
|
||||
* A message display is a message node with additional information for rendering.
|
||||
* For example, siblings of the message node are stored as their last node (aka leaf node).
|
||||
*/
|
||||
export interface MessageDisplay {
|
||||
msg: Message | PendingMessage;
|
||||
siblingLeafNodeIds: Message['id'][];
|
||||
siblingCurrIdx: number;
|
||||
isPending?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* If the current URL contains "?m=...", prefill the message input with the value.
|
||||
* If the current URL contains "?q=...", prefill and SEND the message.
|
||||
*/
|
||||
const prefilledMsg = {
|
||||
content() {
|
||||
const url = new URL(window.location.href);
|
||||
return url.searchParams.get('m') ?? url.searchParams.get('q') ?? '';
|
||||
},
|
||||
shouldSend() {
|
||||
const url = new URL(window.location.href);
|
||||
return url.searchParams.has('q');
|
||||
},
|
||||
clear() {
|
||||
cleanCurrentUrl(['m', 'q']);
|
||||
},
|
||||
};
|
||||
|
||||
function getListMessageDisplay(
|
||||
msgs: Readonly<Message[]>,
|
||||
leafNodeId: Message['id']
|
||||
): MessageDisplay[] {
|
||||
const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
|
||||
const res: MessageDisplay[] = [];
|
||||
const nodeMap = new Map<Message['id'], Message>();
|
||||
for (const msg of msgs) {
|
||||
nodeMap.set(msg.id, msg);
|
||||
}
|
||||
// find leaf node from a message node
|
||||
const findLeafNode = (msgId: Message['id']): Message['id'] => {
|
||||
let currNode: Message | undefined = nodeMap.get(msgId);
|
||||
while (currNode) {
|
||||
if (currNode.children.length === 0) break;
|
||||
currNode = nodeMap.get(currNode.children.at(-1) ?? -1);
|
||||
}
|
||||
return currNode?.id ?? -1;
|
||||
};
|
||||
// traverse the current nodes
|
||||
for (const msg of currNodes) {
|
||||
const parentNode = nodeMap.get(msg.parent ?? -1);
|
||||
if (!parentNode) continue;
|
||||
const siblings = parentNode.children;
|
||||
if (msg.type !== 'root') {
|
||||
res.push({
|
||||
msg,
|
||||
siblingLeafNodeIds: siblings.map(findLeafNode),
|
||||
siblingCurrIdx: siblings.indexOf(msg.id),
|
||||
});
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
export default function ChatScreen() {
|
||||
const {
|
||||
viewingChat,
|
||||
sendMessage,
|
||||
isGenerating,
|
||||
stopGenerating,
|
||||
pendingMessages,
|
||||
canvasData,
|
||||
replaceMessageAndGenerate,
|
||||
} = useAppContext();
|
||||
|
||||
const textarea: ChatTextareaApi = useChatTextarea(prefilledMsg.content());
|
||||
const extraContext = useChatExtraContext();
|
||||
useVSCodeContext(textarea, extraContext);
|
||||
|
||||
const msgListRef = useRef<HTMLDivElement>(null);
|
||||
useChatScroll(msgListRef);
|
||||
|
||||
// keep track of leaf node for rendering
|
||||
const [currNodeId, setCurrNodeId] = useState<number>(-1);
|
||||
const messages: MessageDisplay[] = useMemo(() => {
|
||||
if (!viewingChat) return [];
|
||||
else return getListMessageDisplay(viewingChat.messages, currNodeId);
|
||||
}, [currNodeId, viewingChat]);
|
||||
|
||||
const currConvId = viewingChat?.conv.id ?? null;
|
||||
const pendingMsg: PendingMessage | undefined =
|
||||
pendingMessages[currConvId ?? ''];
|
||||
|
||||
useEffect(() => {
|
||||
// reset to latest node when conversation changes
|
||||
setCurrNodeId(-1);
|
||||
// scroll to bottom when conversation changes
|
||||
scrollToBottom(false, 1);
|
||||
}, [currConvId]);
|
||||
|
||||
const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
|
||||
if (currLeafNodeId) {
|
||||
setCurrNodeId(currLeafNodeId);
|
||||
}
|
||||
// useChatScroll will handle the auto scroll
|
||||
};
|
||||
|
||||
const sendNewMessage = async () => {
|
||||
const lastInpMsg = textarea.value();
|
||||
if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? '')) {
|
||||
toast.error('Please enter a message');
|
||||
return;
|
||||
}
|
||||
textarea.setValue('');
|
||||
scrollToBottom(false);
|
||||
setCurrNodeId(-1);
|
||||
// get the last message node
|
||||
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
|
||||
if (
|
||||
!(await sendMessage(
|
||||
currConvId,
|
||||
lastMsgNodeId,
|
||||
lastInpMsg,
|
||||
extraContext.items,
|
||||
onChunk
|
||||
))
|
||||
) {
|
||||
// restore the input message if failed
|
||||
textarea.setValue(lastInpMsg);
|
||||
}
|
||||
// OK
|
||||
extraContext.clearItems();
|
||||
};
|
||||
|
||||
// for vscode context
|
||||
textarea.refOnSubmit.current = sendNewMessage;
|
||||
|
||||
const handleEditMessage = async (msg: Message, content: string) => {
|
||||
if (!viewingChat) return;
|
||||
setCurrNodeId(msg.id);
|
||||
scrollToBottom(false);
|
||||
await replaceMessageAndGenerate(
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
content,
|
||||
msg.extra,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
scrollToBottom(false);
|
||||
};
|
||||
|
||||
const handleRegenerateMessage = async (msg: Message) => {
|
||||
if (!viewingChat) return;
|
||||
setCurrNodeId(msg.parent);
|
||||
scrollToBottom(false);
|
||||
await replaceMessageAndGenerate(
|
||||
viewingChat.conv.id,
|
||||
msg.parent,
|
||||
null,
|
||||
msg.extra,
|
||||
onChunk
|
||||
);
|
||||
setCurrNodeId(-1);
|
||||
scrollToBottom(false);
|
||||
};
|
||||
|
||||
const hasCanvas = !!canvasData;
|
||||
|
||||
useEffect(() => {
|
||||
if (prefilledMsg.shouldSend()) {
|
||||
// send the prefilled message if needed
|
||||
sendNewMessage();
|
||||
} else {
|
||||
// otherwise, focus on the input
|
||||
textarea.focus();
|
||||
}
|
||||
prefilledMsg.clear();
|
||||
// no need to keep track of sendNewMessage
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [textarea.ref]);
|
||||
|
||||
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
|
||||
const pendingMsgDisplay: MessageDisplay[] =
|
||||
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
|
||||
? [
|
||||
{
|
||||
msg: pendingMsg,
|
||||
siblingLeafNodeIds: [],
|
||||
siblingCurrIdx: 0,
|
||||
isPending: true,
|
||||
},
|
||||
]
|
||||
: [];
|
||||
|
||||
return (
|
||||
<div
|
||||
className={classNames({
|
||||
'grid lg:gap-8 grow transition-[300ms]': true,
|
||||
'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
|
||||
'grid-cols-[1fr_0fr]': !hasCanvas,
|
||||
})}
|
||||
>
|
||||
<div
|
||||
className={classNames({
|
||||
'flex flex-col w-full max-w-[900px] mx-auto': true,
|
||||
'hidden lg:flex': hasCanvas, // adapted for mobile
|
||||
flex: !hasCanvas,
|
||||
})}
|
||||
>
|
||||
{/* chat messages */}
|
||||
<div id="messages-list" className="grow" ref={msgListRef}>
|
||||
<div className="mt-auto flex flex-col items-center">
|
||||
{/* placeholder to shift the message to the bottom */}
|
||||
{viewingChat ? (
|
||||
''
|
||||
) : (
|
||||
<>
|
||||
<div className="mb-4">Send a message to start</div>
|
||||
<ServerInfo />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
{[...messages, ...pendingMsgDisplay].map((msg) => (
|
||||
<ChatMessage
|
||||
key={msg.msg.id}
|
||||
msg={msg.msg}
|
||||
siblingLeafNodeIds={msg.siblingLeafNodeIds}
|
||||
siblingCurrIdx={msg.siblingCurrIdx}
|
||||
onRegenerateMessage={handleRegenerateMessage}
|
||||
onEditMessage={handleEditMessage}
|
||||
onChangeSibling={setCurrNodeId}
|
||||
isPending={msg.isPending}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* chat input */}
|
||||
<ChatInput
|
||||
textarea={textarea}
|
||||
extraContext={extraContext}
|
||||
onSend={sendNewMessage}
|
||||
onStop={() => stopGenerating(currConvId ?? '')}
|
||||
isGenerating={isGenerating(currConvId ?? '')}
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
|
||||
{canvasData?.type === CanvasType.PY_INTERPRETER && (
|
||||
<CanvasPyInterpreter />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ServerInfo() {
|
||||
const { serverProps } = useAppContext();
|
||||
const modalities = [];
|
||||
if (serverProps?.modalities?.audio) {
|
||||
modalities.push('audio');
|
||||
}
|
||||
if (serverProps?.modalities?.vision) {
|
||||
modalities.push('vision');
|
||||
}
|
||||
return (
|
||||
<div
|
||||
className="card card-sm shadow-sm border-1 border-base-content/20 text-base-content/70 mb-6"
|
||||
tabIndex={0}
|
||||
aria-description="Server information"
|
||||
>
|
||||
<div className="card-body">
|
||||
<b>Server Info</b>
|
||||
<p>
|
||||
<b>Model</b>: {serverProps?.model_path?.split(/(\\|\/)/).pop()}
|
||||
<br />
|
||||
<b>Build</b>: {serverProps?.build_info}
|
||||
<br />
|
||||
{modalities.length > 0 ? (
|
||||
<>
|
||||
<b>Supported modalities:</b> {modalities.join(', ')}
|
||||
</>
|
||||
) : (
|
||||
''
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ChatInput({
|
||||
textarea,
|
||||
extraContext,
|
||||
onSend,
|
||||
onStop,
|
||||
isGenerating,
|
||||
}: {
|
||||
textarea: ChatTextareaApi;
|
||||
extraContext: ChatExtraContextApi;
|
||||
onSend: () => void;
|
||||
onStop: () => void;
|
||||
isGenerating: boolean;
|
||||
}) {
|
||||
const { config } = useAppContext();
|
||||
const [isDrag, setIsDrag] = useState(false);
|
||||
|
||||
return (
|
||||
<div
|
||||
role="group"
|
||||
aria-label="Chat input"
|
||||
className={classNames({
|
||||
'flex items-end pt-8 pb-6 sticky bottom-0 bg-base-100': true,
|
||||
'opacity-50': isDrag, // simply visual feedback to inform user that the file will be accepted
|
||||
})}
|
||||
>
|
||||
<Dropzone
|
||||
noClick
|
||||
onDrop={(files: File[]) => {
|
||||
setIsDrag(false);
|
||||
extraContext.onFileAdded(files);
|
||||
}}
|
||||
onDragEnter={() => setIsDrag(true)}
|
||||
onDragLeave={() => setIsDrag(false)}
|
||||
multiple={true}
|
||||
>
|
||||
{({ getRootProps, getInputProps }) => (
|
||||
<div
|
||||
className="flex flex-col rounded-xl border-1 border-base-content/30 p-3 w-full"
|
||||
// when a file is pasted to the input, we handle it here
|
||||
// if a text is pasted, and if it is long text, we will convert it to a file
|
||||
onPasteCapture={(e: ClipboardEvent<HTMLInputElement>) => {
|
||||
const text = e.clipboardData.getData('text/plain');
|
||||
if (
|
||||
text.length > 0 &&
|
||||
config.pasteLongTextToFileLen > 0 &&
|
||||
text.length > config.pasteLongTextToFileLen
|
||||
) {
|
||||
// if the text is too long, we will convert it to a file
|
||||
extraContext.addItems([
|
||||
{
|
||||
type: 'context',
|
||||
name: 'Pasted Content',
|
||||
content: text,
|
||||
},
|
||||
]);
|
||||
e.preventDefault();
|
||||
return;
|
||||
}
|
||||
|
||||
// if a file is pasted, we will handle it here
|
||||
const files = Array.from(e.clipboardData.items)
|
||||
.filter((item) => item.kind === 'file')
|
||||
.map((item) => item.getAsFile())
|
||||
.filter((file) => file !== null);
|
||||
|
||||
if (files.length > 0) {
|
||||
e.preventDefault();
|
||||
extraContext.onFileAdded(files);
|
||||
}
|
||||
}}
|
||||
{...getRootProps()}
|
||||
>
|
||||
{!isGenerating && (
|
||||
<ChatInputExtraContextItem
|
||||
items={extraContext.items}
|
||||
removeItem={extraContext.removeItem}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="flex flex-row w-full">
|
||||
<textarea
|
||||
// Default (mobile): Enable vertical resize, overflow auto for scrolling if needed
|
||||
// Large screens (lg:): Disable manual resize, apply max-height for autosize limit
|
||||
className="text-md outline-none border-none w-full resize-vertical lg:resize-none lg:max-h-48 lg:overflow-y-auto" // Adjust lg:max-h-48 as needed (e.g., lg:max-h-60)
|
||||
placeholder="Type a message (Shift+Enter to add a new line)"
|
||||
ref={textarea.ref}
|
||||
onInput={textarea.onInput} // Hook's input handler (will only resize height on lg+ screens)
|
||||
onKeyDown={(e) => {
|
||||
if (e.nativeEvent.isComposing || e.keyCode === 229) return;
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
onSend();
|
||||
}
|
||||
}}
|
||||
id="msg-input"
|
||||
dir="auto"
|
||||
// Set a base height of 2 rows for mobile views
|
||||
// On lg+ screens, the hook will calculate and set the initial height anyway
|
||||
rows={2}
|
||||
></textarea>
|
||||
|
||||
{/* buttons area */}
|
||||
<div className="flex flex-row gap-2 ml-2">
|
||||
<label
|
||||
htmlFor="file-upload"
|
||||
className={classNames({
|
||||
'btn w-8 h-8 p-0 rounded-full': true,
|
||||
'btn-disabled': isGenerating,
|
||||
})}
|
||||
aria-label="Upload file"
|
||||
tabIndex={0}
|
||||
role="button"
|
||||
>
|
||||
<PaperClipIcon className="h-5 w-5" />
|
||||
</label>
|
||||
<input
|
||||
id="file-upload"
|
||||
type="file"
|
||||
disabled={isGenerating}
|
||||
{...getInputProps()}
|
||||
hidden
|
||||
/>
|
||||
{isGenerating ? (
|
||||
<button
|
||||
className="btn btn-neutral w-8 h-8 p-0 rounded-full"
|
||||
onClick={onStop}
|
||||
>
|
||||
<StopIcon className="h-5 w-5" />
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
className="btn btn-primary w-8 h-8 p-0 rounded-full"
|
||||
onClick={onSend}
|
||||
aria-label="Send message"
|
||||
>
|
||||
<ArrowUpIcon className="h-5 w-5" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</Dropzone>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
import { useEffect, useState } from 'react';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { classNames } from '../utils/misc';
|
||||
import daisyuiThemes from 'daisyui/theme/object';
|
||||
import { THEMES } from '../Config';
|
||||
import {
|
||||
Cog8ToothIcon,
|
||||
MoonIcon,
|
||||
Bars3Icon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
|
||||
export default function Header() {
|
||||
const [selectedTheme, setSelectedTheme] = useState(StorageUtils.getTheme());
|
||||
const { setShowSettings } = useAppContext();
|
||||
|
||||
const setTheme = (theme: string) => {
|
||||
StorageUtils.setTheme(theme);
|
||||
setSelectedTheme(theme);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
document.body.setAttribute('data-theme', selectedTheme);
|
||||
document.body.setAttribute(
|
||||
'data-color-scheme',
|
||||
daisyuiThemes[selectedTheme]?.['color-scheme'] ?? 'auto'
|
||||
);
|
||||
}, [selectedTheme]);
|
||||
|
||||
return (
|
||||
<div className="flex flex-row items-center pt-6 pb-6 sticky top-0 z-10 bg-base-100">
|
||||
{/* open sidebar button */}
|
||||
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden">
|
||||
<Bars3Icon className="h-5 w-5" />
|
||||
</label>
|
||||
|
||||
<div className="grow text-2xl font-bold ml-2">llama.cpp</div>
|
||||
|
||||
{/* action buttons (top right) */}
|
||||
<div className="flex items-center">
|
||||
<div
|
||||
className="tooltip tooltip-bottom"
|
||||
data-tip="Settings"
|
||||
onClick={() => setShowSettings(true)}
|
||||
>
|
||||
<button className="btn" aria-hidden={true}>
|
||||
{/* settings button */}
|
||||
<Cog8ToothIcon className="w-5 h-5" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* theme controller is copied from https://daisyui.com/components/theme-controller/ */}
|
||||
<div className="tooltip tooltip-bottom" data-tip="Themes">
|
||||
<div className="dropdown dropdown-end dropdown-bottom">
|
||||
<div tabIndex={0} role="button" className="btn m-1">
|
||||
<MoonIcon className="w-5 h-5" />
|
||||
</div>
|
||||
<ul
|
||||
tabIndex={0}
|
||||
className="dropdown-content bg-base-300 rounded-box z-[1] w-52 p-2 shadow-2xl h-80 overflow-y-auto"
|
||||
>
|
||||
<li>
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-sm btn-block btn-ghost justify-start': true,
|
||||
'btn-active': selectedTheme === 'auto',
|
||||
})}
|
||||
onClick={() => setTheme('auto')}
|
||||
>
|
||||
auto
|
||||
</button>
|
||||
</li>
|
||||
{THEMES.map((theme) => (
|
||||
<li key={theme}>
|
||||
<input
|
||||
type="radio"
|
||||
name="theme-dropdown"
|
||||
className="theme-controller btn btn-sm btn-block btn-ghost justify-start"
|
||||
aria-label={theme}
|
||||
value={theme}
|
||||
checked={selectedTheme === theme}
|
||||
onChange={(e) => e.target.checked && setTheme(theme)}
|
||||
/>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,317 +0,0 @@
|
||||
import React, { useMemo, useState } from 'react';
|
||||
import Markdown, { ExtraProps } from 'react-markdown';
|
||||
import remarkGfm from 'remark-gfm';
|
||||
import rehypeHightlight from 'rehype-highlight';
|
||||
import rehypeKatex from 'rehype-katex';
|
||||
import remarkMath from 'remark-math';
|
||||
import remarkBreaks from 'remark-breaks';
|
||||
import 'katex/dist/katex.min.css';
|
||||
import { classNames, copyStr } from '../utils/misc';
|
||||
import { ElementContent, Root } from 'hast';
|
||||
import { visit } from 'unist-util-visit';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { CanvasType } from '../utils/types';
|
||||
import { BtnWithTooltips } from '../utils/common';
|
||||
import { DocumentDuplicateIcon, PlayIcon } from '@heroicons/react/24/outline';
|
||||
|
||||
export default function MarkdownDisplay({
|
||||
content,
|
||||
isGenerating,
|
||||
}: {
|
||||
content: string;
|
||||
isGenerating?: boolean;
|
||||
}) {
|
||||
const preprocessedContent = useMemo(
|
||||
() => preprocessLaTeX(content),
|
||||
[content]
|
||||
);
|
||||
return (
|
||||
<Markdown
|
||||
remarkPlugins={[remarkGfm, remarkMath, remarkBreaks]}
|
||||
rehypePlugins={[rehypeHightlight, rehypeKatex, rehypeCustomCopyButton]}
|
||||
components={{
|
||||
button: (props) => (
|
||||
<CodeBlockButtons
|
||||
{...props}
|
||||
isGenerating={isGenerating}
|
||||
origContent={preprocessedContent}
|
||||
/>
|
||||
),
|
||||
// note: do not use "pre", "p" or other basic html elements here, it will cause the node to re-render when the message is being generated (this should be a bug with react-markdown, not sure how to fix it)
|
||||
}}
|
||||
>
|
||||
{preprocessedContent}
|
||||
</Markdown>
|
||||
);
|
||||
}
|
||||
|
||||
const CodeBlockButtons: React.ElementType<
|
||||
React.ClassAttributes<HTMLButtonElement> &
|
||||
React.HTMLAttributes<HTMLButtonElement> &
|
||||
ExtraProps & { origContent: string; isGenerating?: boolean }
|
||||
> = ({ node, origContent, isGenerating }) => {
|
||||
const { config } = useAppContext();
|
||||
const startOffset = node?.position?.start.offset ?? 0;
|
||||
const endOffset = node?.position?.end.offset ?? 0;
|
||||
|
||||
const copiedContent = useMemo(
|
||||
() =>
|
||||
origContent
|
||||
.substring(startOffset, endOffset)
|
||||
.replace(/^```[^\n]+\n/g, '')
|
||||
.replace(/```$/g, ''),
|
||||
[origContent, startOffset, endOffset]
|
||||
);
|
||||
|
||||
const codeLanguage = useMemo(
|
||||
() =>
|
||||
origContent
|
||||
.substring(startOffset, startOffset + 10)
|
||||
.match(/^```([^\n]+)\n/)?.[1] ?? '',
|
||||
[origContent, startOffset]
|
||||
);
|
||||
|
||||
const canRunCode =
|
||||
!isGenerating &&
|
||||
config.pyIntepreterEnabled &&
|
||||
codeLanguage.startsWith('py');
|
||||
|
||||
return (
|
||||
<div
|
||||
className={classNames({
|
||||
'text-right sticky top-[7em] mb-2 mr-2 h-0': true,
|
||||
'display-none': !node?.position,
|
||||
})}
|
||||
>
|
||||
<CopyButton
|
||||
className="badge btn-mini btn-soft shadow-sm"
|
||||
content={copiedContent}
|
||||
/>
|
||||
{canRunCode && (
|
||||
<RunPyCodeButton
|
||||
className="badge btn-mini shadow-sm ml-2"
|
||||
content={copiedContent}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const CopyButton = ({
|
||||
content,
|
||||
className,
|
||||
}: {
|
||||
content: string;
|
||||
className?: string;
|
||||
}) => {
|
||||
const [copied, setCopied] = useState(false);
|
||||
return (
|
||||
<BtnWithTooltips
|
||||
className={className}
|
||||
onClick={() => {
|
||||
copyStr(content);
|
||||
setCopied(true);
|
||||
}}
|
||||
onMouseLeave={() => setCopied(false)}
|
||||
tooltipsContent={copied ? 'Copied!' : 'Copy'}
|
||||
>
|
||||
<DocumentDuplicateIcon className="h-4 w-4" />
|
||||
</BtnWithTooltips>
|
||||
);
|
||||
};
|
||||
|
||||
export const RunPyCodeButton = ({
|
||||
content,
|
||||
className,
|
||||
}: {
|
||||
content: string;
|
||||
className?: string;
|
||||
}) => {
|
||||
const { setCanvasData } = useAppContext();
|
||||
return (
|
||||
<>
|
||||
<BtnWithTooltips
|
||||
className={className}
|
||||
onClick={() =>
|
||||
setCanvasData({
|
||||
type: CanvasType.PY_INTERPRETER,
|
||||
content,
|
||||
})
|
||||
}
|
||||
tooltipsContent="Run code"
|
||||
>
|
||||
<PlayIcon className="h-4 w-4" />
|
||||
</BtnWithTooltips>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* This injects the "button" element before each "pre" element.
|
||||
* The actual button will be replaced with a react component in the MarkdownDisplay.
|
||||
* We don't replace "pre" node directly because it will cause the node to re-render, which causes this bug: https://github.com/ggerganov/llama.cpp/issues/9608
|
||||
*/
|
||||
function rehypeCustomCopyButton() {
|
||||
return function (tree: Root) {
|
||||
visit(tree, 'element', function (node) {
|
||||
if (node.tagName === 'pre' && !node.properties.visited) {
|
||||
const preNode = { ...node };
|
||||
// replace current node
|
||||
preNode.properties.visited = 'true';
|
||||
node.tagName = 'div';
|
||||
node.properties = {};
|
||||
// add node for button
|
||||
const btnNode: ElementContent = {
|
||||
type: 'element',
|
||||
tagName: 'button',
|
||||
properties: {},
|
||||
children: [],
|
||||
position: node.position,
|
||||
};
|
||||
node.children = [btnNode, preNode];
|
||||
}
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* The part below is copied and adapted from:
|
||||
* https://github.com/danny-avila/LibreChat/blob/main/client/src/utils/latex.ts
|
||||
* (MIT License)
|
||||
*/
|
||||
|
||||
// Regex to check if the processed content contains any potential LaTeX patterns
|
||||
const containsLatexRegex =
|
||||
/\\\(.*?\\\)|\\\[.*?\\\]|\$.*?\$|\\begin\{equation\}.*?\\end\{equation\}/;
|
||||
|
||||
// Regex for inline and block LaTeX expressions
|
||||
const inlineLatex = new RegExp(/\\\((.+?)\\\)/, 'g');
|
||||
const blockLatex = new RegExp(/\\\[(.*?[^\\])\\\]/, 'gs');
|
||||
|
||||
// Function to restore code blocks
|
||||
const restoreCodeBlocks = (content: string, codeBlocks: string[]) => {
|
||||
return content.replace(
|
||||
/<<CODE_BLOCK_(\d+)>>/g,
|
||||
(_, index) => codeBlocks[index]
|
||||
);
|
||||
};
|
||||
|
||||
// Regex to identify code blocks and inline code
|
||||
const codeBlockRegex = /(```[\s\S]*?```|`.*?`)/g;
|
||||
|
||||
export const processLaTeX = (_content: string) => {
|
||||
let content = _content;
|
||||
// Temporarily replace code blocks and inline code with placeholders
|
||||
const codeBlocks: string[] = [];
|
||||
let index = 0;
|
||||
content = content.replace(codeBlockRegex, (match) => {
|
||||
codeBlocks[index] = match;
|
||||
return `<<CODE_BLOCK_${index++}>>`;
|
||||
});
|
||||
|
||||
// Escape dollar signs followed by a digit or space and digit
|
||||
let processedContent = content.replace(/(\$)(?=\s?\d)/g, '\\$');
|
||||
|
||||
// If no LaTeX patterns are found, restore code blocks and return the processed content
|
||||
if (!containsLatexRegex.test(processedContent)) {
|
||||
return restoreCodeBlocks(processedContent, codeBlocks);
|
||||
}
|
||||
|
||||
// Convert LaTeX expressions to a markdown compatible format
|
||||
processedContent = processedContent
|
||||
.replace(inlineLatex, (_: string, equation: string) => `$${equation}$`) // Convert inline LaTeX
|
||||
.replace(blockLatex, (_: string, equation: string) => `$$${equation}$$`); // Convert block LaTeX
|
||||
|
||||
// Restore code blocks
|
||||
return restoreCodeBlocks(processedContent, codeBlocks);
|
||||
};
|
||||
|
||||
/**
|
||||
* Preprocesses LaTeX content by replacing delimiters and escaping certain characters.
|
||||
*
|
||||
* @param content The input string containing LaTeX expressions.
|
||||
* @returns The processed string with replaced delimiters and escaped characters.
|
||||
*/
|
||||
export function preprocessLaTeX(content: string): string {
|
||||
// Step 1: Protect code blocks
|
||||
const codeBlocks: string[] = [];
|
||||
content = content.replace(/(```[\s\S]*?```|`[^`\n]+`)/g, (_, code) => {
|
||||
codeBlocks.push(code);
|
||||
return `<<CODE_BLOCK_${codeBlocks.length - 1}>>`;
|
||||
});
|
||||
|
||||
// Step 2: Protect existing LaTeX expressions
|
||||
const latexExpressions: string[] = [];
|
||||
|
||||
// Protect block math ($$...$$), \[...\], and \(...\) as before.
|
||||
content = content.replace(
|
||||
/(\$\$[\s\S]*?\$\$|\\\[[\s\S]*?\\\]|\\\(.*?\\\))/g,
|
||||
(match) => {
|
||||
latexExpressions.push(match);
|
||||
return `<<LATEX_${latexExpressions.length - 1}>>`;
|
||||
}
|
||||
);
|
||||
|
||||
// Protect inline math ($...$) only if it does NOT match a currency pattern.
|
||||
// We assume a currency pattern is one where the inner content is purely numeric (with optional decimals).
|
||||
content = content.replace(/\$([^$]+)\$/g, (match, inner) => {
|
||||
if (/^\s*\d+(?:\.\d+)?\s*$/.test(inner)) {
|
||||
// This looks like a currency value (e.g. "$123" or "$12.34"),
|
||||
// so don't protect it.
|
||||
return match;
|
||||
} else {
|
||||
// Otherwise, treat it as a LaTeX expression.
|
||||
latexExpressions.push(match);
|
||||
return `<<LATEX_${latexExpressions.length - 1}>>`;
|
||||
}
|
||||
});
|
||||
|
||||
// Step 3: Escape dollar signs that are likely currency indicators.
|
||||
// (Now that inline math is protected, this will only escape dollars not already protected)
|
||||
content = content.replace(/\$(?=\d)/g, '\\$');
|
||||
|
||||
// Step 4: Restore LaTeX expressions
|
||||
content = content.replace(
|
||||
/<<LATEX_(\d+)>>/g,
|
||||
(_, index) => latexExpressions[parseInt(index)]
|
||||
);
|
||||
|
||||
// Step 5: Restore code blocks
|
||||
content = content.replace(
|
||||
/<<CODE_BLOCK_(\d+)>>/g,
|
||||
(_, index) => codeBlocks[parseInt(index)]
|
||||
);
|
||||
|
||||
// Step 6: Apply additional escaping functions
|
||||
content = escapeBrackets(content);
|
||||
content = escapeMhchem(content);
|
||||
|
||||
return content;
|
||||
}
|
||||
|
||||
export function escapeBrackets(text: string): string {
|
||||
const pattern =
|
||||
/(```[\S\s]*?```|`.*?`)|\\\[([\S\s]*?[^\\])\\]|\\\((.*?)\\\)/g;
|
||||
return text.replace(
|
||||
pattern,
|
||||
(
|
||||
match: string,
|
||||
codeBlock: string | undefined,
|
||||
squareBracket: string | undefined,
|
||||
roundBracket: string | undefined
|
||||
): string => {
|
||||
if (codeBlock != null) {
|
||||
return codeBlock;
|
||||
} else if (squareBracket != null) {
|
||||
return `$$${squareBracket}$$`;
|
||||
} else if (roundBracket != null) {
|
||||
return `$${roundBracket}$`;
|
||||
}
|
||||
return match;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export function escapeMhchem(text: string) {
|
||||
return text.replaceAll('$\\ce{', '$\\\\ce{').replaceAll('$\\pu{', '$\\\\pu{');
|
||||
}
|
||||
@@ -1,151 +0,0 @@
|
||||
import React, { createContext, useState, useContext } from 'react';
|
||||
|
||||
type ModalContextType = {
|
||||
showConfirm: (message: string) => Promise<boolean>;
|
||||
showPrompt: (
|
||||
message: string,
|
||||
defaultValue?: string
|
||||
) => Promise<string | undefined>;
|
||||
showAlert: (message: string) => Promise<void>;
|
||||
};
|
||||
const ModalContext = createContext<ModalContextType>(null!);
|
||||
|
||||
interface ModalState<T> {
|
||||
isOpen: boolean;
|
||||
message: string;
|
||||
defaultValue?: string;
|
||||
resolve: ((value: T) => void) | null;
|
||||
}
|
||||
|
||||
export function ModalProvider({ children }: { children: React.ReactNode }) {
|
||||
const [confirmState, setConfirmState] = useState<ModalState<boolean>>({
|
||||
isOpen: false,
|
||||
message: '',
|
||||
resolve: null,
|
||||
});
|
||||
const [promptState, setPromptState] = useState<
|
||||
ModalState<string | undefined>
|
||||
>({ isOpen: false, message: '', resolve: null });
|
||||
const [alertState, setAlertState] = useState<ModalState<void>>({
|
||||
isOpen: false,
|
||||
message: '',
|
||||
resolve: null,
|
||||
});
|
||||
const inputRef = React.useRef<HTMLInputElement>(null);
|
||||
|
||||
const showConfirm = (message: string): Promise<boolean> => {
|
||||
return new Promise((resolve) => {
|
||||
setConfirmState({ isOpen: true, message, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const showPrompt = (
|
||||
message: string,
|
||||
defaultValue?: string
|
||||
): Promise<string | undefined> => {
|
||||
return new Promise((resolve) => {
|
||||
setPromptState({ isOpen: true, message, defaultValue, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const showAlert = (message: string): Promise<void> => {
|
||||
return new Promise((resolve) => {
|
||||
setAlertState({ isOpen: true, message, resolve });
|
||||
});
|
||||
};
|
||||
|
||||
const handleConfirm = (result: boolean) => {
|
||||
confirmState.resolve?.(result);
|
||||
setConfirmState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
const handlePrompt = (result?: string) => {
|
||||
promptState.resolve?.(result);
|
||||
setPromptState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
const handleAlertClose = () => {
|
||||
alertState.resolve?.();
|
||||
setAlertState({ isOpen: false, message: '', resolve: null });
|
||||
};
|
||||
|
||||
return (
|
||||
<ModalContext.Provider value={{ showConfirm, showPrompt, showAlert }}>
|
||||
{children}
|
||||
|
||||
{/* Confirm Modal */}
|
||||
{confirmState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{confirmState.message}</h3>
|
||||
<div className="modal-action">
|
||||
<button
|
||||
className="btn btn-ghost"
|
||||
onClick={() => handleConfirm(false)}
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
className="btn btn-error"
|
||||
onClick={() => handleConfirm(true)}
|
||||
>
|
||||
Confirm
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
|
||||
{/* Prompt Modal */}
|
||||
{promptState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{promptState.message}</h3>
|
||||
<input
|
||||
type="text"
|
||||
className="input input-bordered w-full mt-2"
|
||||
defaultValue={promptState.defaultValue}
|
||||
ref={inputRef}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
handlePrompt((e.target as HTMLInputElement).value);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<div className="modal-action">
|
||||
<button className="btn btn-ghost" onClick={() => handlePrompt()}>
|
||||
Cancel
|
||||
</button>
|
||||
<button
|
||||
className="btn btn-primary"
|
||||
onClick={() => handlePrompt(inputRef.current?.value)}
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
|
||||
{/* Alert Modal */}
|
||||
{alertState.isOpen && (
|
||||
<dialog className="modal modal-open z-[1100]">
|
||||
<div className="modal-box">
|
||||
<h3 className="font-bold text-lg">{alertState.message}</h3>
|
||||
<div className="modal-action">
|
||||
<button className="btn" onClick={handleAlertClose}>
|
||||
OK
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
)}
|
||||
</ModalContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useModals() {
|
||||
const context = useContext(ModalContext);
|
||||
if (!context) throw new Error('useModals must be used within ModalProvider');
|
||||
return context;
|
||||
}
|
||||
@@ -1,553 +0,0 @@
|
||||
import { useState } from 'react';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import { CONFIG_DEFAULT, CONFIG_INFO } from '../Config';
|
||||
import { isDev } from '../Config';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { classNames, isBoolean, isNumeric, isString } from '../utils/misc';
|
||||
import {
|
||||
BeakerIcon,
|
||||
ChatBubbleOvalLeftEllipsisIcon,
|
||||
Cog6ToothIcon,
|
||||
FunnelIcon,
|
||||
HandRaisedIcon,
|
||||
SquaresPlusIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { OpenInNewTab } from '../utils/common';
|
||||
import { useModals } from './ModalProvider';
|
||||
|
||||
type SettKey = keyof typeof CONFIG_DEFAULT;
|
||||
|
||||
const BASIC_KEYS: SettKey[] = [
|
||||
'temperature',
|
||||
'top_k',
|
||||
'top_p',
|
||||
'min_p',
|
||||
'max_tokens',
|
||||
];
|
||||
const SAMPLER_KEYS: SettKey[] = [
|
||||
'dynatemp_range',
|
||||
'dynatemp_exponent',
|
||||
'typical_p',
|
||||
'xtc_probability',
|
||||
'xtc_threshold',
|
||||
];
|
||||
const PENALTY_KEYS: SettKey[] = [
|
||||
'repeat_last_n',
|
||||
'repeat_penalty',
|
||||
'presence_penalty',
|
||||
'frequency_penalty',
|
||||
'dry_multiplier',
|
||||
'dry_base',
|
||||
'dry_allowed_length',
|
||||
'dry_penalty_last_n',
|
||||
];
|
||||
|
||||
enum SettingInputType {
|
||||
SHORT_INPUT,
|
||||
LONG_INPUT,
|
||||
CHECKBOX,
|
||||
CUSTOM,
|
||||
}
|
||||
|
||||
interface SettingFieldInput {
|
||||
type: Exclude<SettingInputType, SettingInputType.CUSTOM>;
|
||||
label: string | React.ReactElement;
|
||||
help?: string | React.ReactElement;
|
||||
key: SettKey;
|
||||
}
|
||||
|
||||
interface SettingFieldCustom {
|
||||
type: SettingInputType.CUSTOM;
|
||||
key: SettKey;
|
||||
component:
|
||||
| string
|
||||
| React.FC<{
|
||||
value: string | boolean | number;
|
||||
onChange: (value: string) => void;
|
||||
}>;
|
||||
}
|
||||
|
||||
interface SettingSection {
|
||||
title: React.ReactElement;
|
||||
fields: (SettingFieldInput | SettingFieldCustom)[];
|
||||
}
|
||||
|
||||
const ICON_CLASSNAME = 'w-4 h-4 mr-1 inline';
|
||||
|
||||
const SETTING_SECTIONS: SettingSection[] = [
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<Cog6ToothIcon className={ICON_CLASSNAME} />
|
||||
General
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'API Key',
|
||||
key: 'apiKey',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.LONG_INPUT,
|
||||
label: 'System Message (will be disabled if left empty)',
|
||||
key: 'systemMessage',
|
||||
},
|
||||
...BASIC_KEYS.map(
|
||||
(key) =>
|
||||
({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
}) as SettingFieldInput
|
||||
),
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'Paste length to file',
|
||||
key: 'pasteLongTextToFileLen',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Parse PDF as image instead of text',
|
||||
key: 'pdfAsImage',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<FunnelIcon className={ICON_CLASSNAME} />
|
||||
Samplers
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: 'Samplers queue',
|
||||
key: 'samplers',
|
||||
},
|
||||
...SAMPLER_KEYS.map(
|
||||
(key) =>
|
||||
({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
}) as SettingFieldInput
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<HandRaisedIcon className={ICON_CLASSNAME} />
|
||||
Penalties
|
||||
</>
|
||||
),
|
||||
fields: PENALTY_KEYS.map((key) => ({
|
||||
type: SettingInputType.SHORT_INPUT,
|
||||
label: key,
|
||||
key,
|
||||
})),
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<ChatBubbleOvalLeftEllipsisIcon className={ICON_CLASSNAME} />
|
||||
Reasoning
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Expand thought process by default when generating messages',
|
||||
key: 'showThoughtInProgress',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label:
|
||||
'Exclude thought process when sending requests to API (Recommended for DeepSeek-R1)',
|
||||
key: 'excludeThoughtOnReq',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<SquaresPlusIcon className={ICON_CLASSNAME} />
|
||||
Advanced
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CUSTOM,
|
||||
key: 'custom', // dummy key, won't be used
|
||||
component: () => {
|
||||
const debugImportDemoConv = async () => {
|
||||
const res = await fetch('/demo-conversation.json');
|
||||
const demoConv = await res.json();
|
||||
StorageUtils.remove(demoConv.id);
|
||||
for (const msg of demoConv.messages) {
|
||||
StorageUtils.appendMsg(demoConv.id, msg);
|
||||
}
|
||||
};
|
||||
return (
|
||||
<button className="btn" onClick={debugImportDemoConv}>
|
||||
(debug) Import demo conversation
|
||||
</button>
|
||||
);
|
||||
},
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: 'Show tokens per second',
|
||||
key: 'showTokensPerSecond',
|
||||
},
|
||||
{
|
||||
type: SettingInputType.LONG_INPUT,
|
||||
label: (
|
||||
<>
|
||||
Custom JSON config (For more info, refer to{' '}
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/blob/master/tools/server/README.md">
|
||||
server documentation
|
||||
</OpenInNewTab>
|
||||
)
|
||||
</>
|
||||
),
|
||||
key: 'custom',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
title: (
|
||||
<>
|
||||
<BeakerIcon className={ICON_CLASSNAME} />
|
||||
Experimental
|
||||
</>
|
||||
),
|
||||
fields: [
|
||||
{
|
||||
type: SettingInputType.CUSTOM,
|
||||
key: 'custom', // dummy key, won't be used
|
||||
component: () => (
|
||||
<>
|
||||
<p className="mb-8">
|
||||
Experimental features are not guaranteed to work correctly.
|
||||
<br />
|
||||
<br />
|
||||
If you encounter any problems, create a{' '}
|
||||
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/new?template=019-bug-misc.yml">
|
||||
Bug (misc.)
|
||||
</OpenInNewTab>{' '}
|
||||
report on Github. Please also specify <b>webui/experimental</b> on
|
||||
the report title and include screenshots.
|
||||
<br />
|
||||
<br />
|
||||
Some features may require packages downloaded from CDN, so they
|
||||
need internet connection.
|
||||
</p>
|
||||
</>
|
||||
),
|
||||
},
|
||||
{
|
||||
type: SettingInputType.CHECKBOX,
|
||||
label: (
|
||||
<>
|
||||
<b>Enable Python interpreter</b>
|
||||
<br />
|
||||
<small className="text-xs">
|
||||
This feature uses{' '}
|
||||
<OpenInNewTab href="https://pyodide.org">pyodide</OpenInNewTab>,
|
||||
downloaded from CDN. To use this feature, ask the LLM to generate
|
||||
Python code inside a Markdown code block. You will see a "Run"
|
||||
button on the code block, near the "Copy" button.
|
||||
</small>
|
||||
</>
|
||||
),
|
||||
key: 'pyIntepreterEnabled',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export default function SettingDialog({
|
||||
show,
|
||||
onClose,
|
||||
}: {
|
||||
show: boolean;
|
||||
onClose: () => void;
|
||||
}) {
|
||||
const { config, saveConfig } = useAppContext();
|
||||
const [sectionIdx, setSectionIdx] = useState(0);
|
||||
|
||||
// clone the config object to prevent direct mutation
|
||||
const [localConfig, setLocalConfig] = useState<typeof CONFIG_DEFAULT>(
|
||||
JSON.parse(JSON.stringify(config))
|
||||
);
|
||||
const { showConfirm, showAlert } = useModals();
|
||||
|
||||
const resetConfig = async () => {
|
||||
if (await showConfirm('Are you sure you want to reset all settings?')) {
|
||||
setLocalConfig(CONFIG_DEFAULT);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
// copy the local config to prevent direct mutation
|
||||
const newConfig: typeof CONFIG_DEFAULT = JSON.parse(
|
||||
JSON.stringify(localConfig)
|
||||
);
|
||||
// validate the config
|
||||
for (const key in newConfig) {
|
||||
const value = newConfig[key as SettKey];
|
||||
const mustBeBoolean = isBoolean(CONFIG_DEFAULT[key as SettKey]);
|
||||
const mustBeString = isString(CONFIG_DEFAULT[key as SettKey]);
|
||||
const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]);
|
||||
if (mustBeString) {
|
||||
if (!isString(value)) {
|
||||
await showAlert(`Value for ${key} must be string`);
|
||||
return;
|
||||
}
|
||||
} else if (mustBeNumeric) {
|
||||
const trimmedValue = value.toString().trim();
|
||||
const numVal = Number(trimmedValue);
|
||||
if (isNaN(numVal) || !isNumeric(numVal) || trimmedValue.length === 0) {
|
||||
await showAlert(`Value for ${key} must be numeric`);
|
||||
return;
|
||||
}
|
||||
// force conversion to number
|
||||
// @ts-expect-error this is safe
|
||||
newConfig[key] = numVal;
|
||||
} else if (mustBeBoolean) {
|
||||
if (!isBoolean(value)) {
|
||||
await showAlert(`Value for ${key} must be boolean`);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
console.error(`Unknown default type for key ${key}`);
|
||||
}
|
||||
}
|
||||
if (isDev) console.log('Saving config', newConfig);
|
||||
saveConfig(newConfig);
|
||||
onClose();
|
||||
};
|
||||
|
||||
const onChange = (key: SettKey) => (value: string | boolean) => {
|
||||
// note: we do not perform validation here, because we may get incomplete value as user is still typing it
|
||||
setLocalConfig({ ...localConfig, [key]: value });
|
||||
};
|
||||
|
||||
return (
|
||||
<dialog
|
||||
className={classNames({ modal: true, 'modal-open': show })}
|
||||
aria-label="Settings dialog"
|
||||
>
|
||||
<div className="modal-box w-11/12 max-w-3xl">
|
||||
<h3 className="text-lg font-bold mb-6">Settings</h3>
|
||||
<div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]">
|
||||
{/* Left panel, showing sections - Desktop version */}
|
||||
<div
|
||||
className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200"
|
||||
role="complementary"
|
||||
aria-description="Settings sections"
|
||||
tabIndex={0}
|
||||
>
|
||||
{SETTING_SECTIONS.map((section, idx) => (
|
||||
<button
|
||||
key={idx}
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start font-normal w-44 mb-1': true,
|
||||
'btn-active': sectionIdx === idx,
|
||||
})}
|
||||
onClick={() => setSectionIdx(idx)}
|
||||
dir="auto"
|
||||
>
|
||||
{section.title}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Left panel, showing sections - Mobile version */}
|
||||
{/* This menu is skipped on a11y, otherwise it's repeated the desktop version */}
|
||||
<div
|
||||
className="md:hidden flex flex-row gap-2 mb-4"
|
||||
aria-disabled={true}
|
||||
>
|
||||
<details className="dropdown">
|
||||
<summary className="btn bt-sm w-full m-1">
|
||||
{SETTING_SECTIONS[sectionIdx].title}
|
||||
</summary>
|
||||
<ul className="menu dropdown-content bg-base-100 rounded-box z-[1] w-52 p-2 shadow">
|
||||
{SETTING_SECTIONS.map((section, idx) => (
|
||||
<div
|
||||
key={idx}
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start font-normal': true,
|
||||
'btn-active': sectionIdx === idx,
|
||||
})}
|
||||
onClick={() => setSectionIdx(idx)}
|
||||
dir="auto"
|
||||
>
|
||||
{section.title}
|
||||
</div>
|
||||
))}
|
||||
</ul>
|
||||
</details>
|
||||
</div>
|
||||
|
||||
{/* Right panel, showing setting fields */}
|
||||
<div className="grow overflow-y-auto px-4">
|
||||
{SETTING_SECTIONS[sectionIdx].fields.map((field, idx) => {
|
||||
const key = `${sectionIdx}-${idx}`;
|
||||
if (field.type === SettingInputType.SHORT_INPUT) {
|
||||
return (
|
||||
<SettingsModalShortInput
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={localConfig[field.key]}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.LONG_INPUT) {
|
||||
return (
|
||||
<SettingsModalLongInput
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={localConfig[field.key].toString()}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.CHECKBOX) {
|
||||
return (
|
||||
<SettingsModalCheckbox
|
||||
key={key}
|
||||
configKey={field.key}
|
||||
value={!!localConfig[field.key]}
|
||||
onChange={onChange(field.key)}
|
||||
label={field.label as string}
|
||||
/>
|
||||
);
|
||||
} else if (field.type === SettingInputType.CUSTOM) {
|
||||
return (
|
||||
<div key={key} className="mb-2">
|
||||
{typeof field.component === 'string'
|
||||
? field.component
|
||||
: field.component({
|
||||
value: localConfig[field.key],
|
||||
onChange: onChange(field.key),
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
})}
|
||||
|
||||
<p className="opacity-40 mb-6 text-sm mt-8">
|
||||
Settings are saved in browser's localStorage
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="modal-action">
|
||||
<button className="btn" onClick={resetConfig}>
|
||||
Reset to default
|
||||
</button>
|
||||
<button className="btn" onClick={onClose}>
|
||||
Close
|
||||
</button>
|
||||
<button className="btn btn-primary" onClick={handleSave}>
|
||||
Save
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</dialog>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalLongInput({
|
||||
configKey,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
value: string;
|
||||
onChange: (value: string) => void;
|
||||
label?: string;
|
||||
}) {
|
||||
return (
|
||||
<label className="form-control">
|
||||
<div className="label inline text-sm">{label || configKey}</div>
|
||||
<textarea
|
||||
className="textarea textarea-bordered h-24 mb-2"
|
||||
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
/>
|
||||
</label>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalShortInput({
|
||||
configKey,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
value: any;
|
||||
onChange: (value: string) => void;
|
||||
label?: string;
|
||||
}) {
|
||||
const helpMsg = CONFIG_INFO[configKey];
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* on mobile, we simply show the help message here */}
|
||||
{helpMsg && (
|
||||
<div className="block mb-1 opacity-75">
|
||||
<p className="text-xs">{helpMsg}</p>
|
||||
</div>
|
||||
)}
|
||||
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
|
||||
<div className="dropdown dropdown-hover">
|
||||
<div tabIndex={0} role="button" className="font-bold hidden md:block">
|
||||
{label || configKey}
|
||||
</div>
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
className="grow"
|
||||
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
/>
|
||||
</label>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsModalCheckbox({
|
||||
configKey,
|
||||
value,
|
||||
onChange,
|
||||
label,
|
||||
}: {
|
||||
configKey: SettKey;
|
||||
value: boolean;
|
||||
onChange: (value: boolean) => void;
|
||||
label: string;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex flex-row items-center mb-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
className="toggle"
|
||||
checked={value}
|
||||
onChange={(e) => onChange(e.target.checked)}
|
||||
/>
|
||||
<span className="ml-4">{label || configKey}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,369 +0,0 @@
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import { classNames } from '../utils/misc';
|
||||
import { Conversation } from '../utils/types';
|
||||
import StorageUtils from '../utils/storage';
|
||||
import { useNavigate, useParams } from 'react-router';
|
||||
import {
|
||||
ArrowDownTrayIcon,
|
||||
EllipsisVerticalIcon,
|
||||
PencilIcon,
|
||||
PencilSquareIcon,
|
||||
TrashIcon,
|
||||
XMarkIcon,
|
||||
} from '@heroicons/react/24/outline';
|
||||
import { BtnWithTooltips } from '../utils/common';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import toast from 'react-hot-toast';
|
||||
import { useModals } from './ModalProvider';
|
||||
|
||||
export default function Sidebar() {
|
||||
const params = useParams();
|
||||
const navigate = useNavigate();
|
||||
|
||||
const { isGenerating } = useAppContext();
|
||||
|
||||
const [conversations, setConversations] = useState<Conversation[]>([]);
|
||||
const [currConv, setCurrConv] = useState<Conversation | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
StorageUtils.getOneConversation(params.convId ?? '').then(setCurrConv);
|
||||
}, [params.convId]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleConversationChange = async () => {
|
||||
setConversations(await StorageUtils.getAllConversations());
|
||||
};
|
||||
StorageUtils.onConversationChanged(handleConversationChange);
|
||||
handleConversationChange();
|
||||
return () => {
|
||||
StorageUtils.offConversationChanged(handleConversationChange);
|
||||
};
|
||||
}, []);
|
||||
const { showConfirm, showPrompt } = useModals();
|
||||
|
||||
const groupedConv = useMemo(
|
||||
() => groupConversationsByDate(conversations),
|
||||
[conversations]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<input
|
||||
id="toggle-drawer"
|
||||
type="checkbox"
|
||||
className="drawer-toggle"
|
||||
aria-label="Toggle sidebar"
|
||||
defaultChecked
|
||||
/>
|
||||
|
||||
<div
|
||||
className="drawer-side h-screen lg:h-screen z-50 lg:max-w-64"
|
||||
role="complementary"
|
||||
aria-label="Sidebar"
|
||||
tabIndex={0}
|
||||
>
|
||||
<label
|
||||
htmlFor="toggle-drawer"
|
||||
aria-label="Close sidebar"
|
||||
className="drawer-overlay"
|
||||
></label>
|
||||
|
||||
<a
|
||||
href="#main-scroll"
|
||||
className="absolute -left-80 top-0 w-1 h-1 overflow-hidden"
|
||||
>
|
||||
Skip to main content
|
||||
</a>
|
||||
|
||||
<div className="flex flex-col bg-base-200 min-h-full max-w-64 py-4 px-4">
|
||||
<div className="flex flex-row items-center justify-between mb-4 mt-4">
|
||||
<h2 className="font-bold ml-4" role="heading">
|
||||
Conversations
|
||||
</h2>
|
||||
|
||||
{/* close sidebar button */}
|
||||
<label
|
||||
htmlFor="toggle-drawer"
|
||||
className="btn btn-ghost lg:hidden"
|
||||
aria-label="Close sidebar"
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
>
|
||||
<XMarkIcon className="w-5 h-5" />
|
||||
</label>
|
||||
</div>
|
||||
|
||||
{/* new conversation button */}
|
||||
<button
|
||||
className={classNames({
|
||||
'btn btn-ghost justify-start px-2': true,
|
||||
'btn-soft': !currConv,
|
||||
})}
|
||||
onClick={() => navigate('/')}
|
||||
aria-label="New conversation"
|
||||
>
|
||||
<PencilSquareIcon className="w-5 h-5" />
|
||||
New conversation
|
||||
</button>
|
||||
|
||||
{/* list of conversations */}
|
||||
{groupedConv.map((group, i) => (
|
||||
<div key={i} role="group">
|
||||
{/* group name (by date) */}
|
||||
{group.title ? (
|
||||
// we use btn class here to make sure that the padding/margin are aligned with the other items
|
||||
<b
|
||||
className="btn btn-ghost btn-xs bg-none btn-disabled block text-xs text-base-content text-start px-2 mb-0 mt-6 font-bold"
|
||||
role="note"
|
||||
aria-description={group.title}
|
||||
tabIndex={0}
|
||||
>
|
||||
{group.title}
|
||||
</b>
|
||||
) : (
|
||||
<div className="h-2" />
|
||||
)}
|
||||
|
||||
{group.conversations.map((conv) => (
|
||||
<ConversationItem
|
||||
key={conv.id}
|
||||
conv={conv}
|
||||
isCurrConv={currConv?.id === conv.id}
|
||||
onSelect={() => {
|
||||
navigate(`/chat/${conv.id}`);
|
||||
}}
|
||||
onDelete={async () => {
|
||||
if (isGenerating(conv.id)) {
|
||||
toast.error(
|
||||
'Cannot delete conversation while generating'
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (
|
||||
await showConfirm(
|
||||
'Are you sure to delete this conversation?'
|
||||
)
|
||||
) {
|
||||
toast.success('Conversation deleted');
|
||||
StorageUtils.remove(conv.id);
|
||||
navigate('/');
|
||||
}
|
||||
}}
|
||||
onDownload={() => {
|
||||
if (isGenerating(conv.id)) {
|
||||
toast.error(
|
||||
'Cannot download conversation while generating'
|
||||
);
|
||||
return;
|
||||
}
|
||||
const conversationJson = JSON.stringify(conv, null, 2);
|
||||
const blob = new Blob([conversationJson], {
|
||||
type: 'application/json',
|
||||
});
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = `conversation_${conv.id}.json`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
}}
|
||||
onRename={async () => {
|
||||
if (isGenerating(conv.id)) {
|
||||
toast.error(
|
||||
'Cannot rename conversation while generating'
|
||||
);
|
||||
return;
|
||||
}
|
||||
const newName = await showPrompt(
|
||||
'Enter new name for the conversation',
|
||||
conv.name
|
||||
);
|
||||
if (newName && newName.trim().length > 0) {
|
||||
StorageUtils.updateConversationName(conv.id, newName);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
<div className="text-center text-xs opacity-40 mt-auto mx-4 pt-8">
|
||||
Conversations are saved to browser's IndexedDB
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function ConversationItem({
|
||||
conv,
|
||||
isCurrConv,
|
||||
onSelect,
|
||||
onDelete,
|
||||
onDownload,
|
||||
onRename,
|
||||
}: {
|
||||
conv: Conversation;
|
||||
isCurrConv: boolean;
|
||||
onSelect: () => void;
|
||||
onDelete: () => void;
|
||||
onDownload: () => void;
|
||||
onRename: () => void;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
role="menuitem"
|
||||
tabIndex={0}
|
||||
aria-label={conv.name}
|
||||
className={classNames({
|
||||
'group flex flex-row btn btn-ghost justify-start items-center font-normal px-2 h-9':
|
||||
true,
|
||||
'btn-soft': isCurrConv,
|
||||
})}
|
||||
>
|
||||
<button
|
||||
key={conv.id}
|
||||
className="w-full overflow-hidden truncate text-start"
|
||||
onClick={onSelect}
|
||||
dir="auto"
|
||||
>
|
||||
{conv.name}
|
||||
</button>
|
||||
<div tabIndex={0} className="dropdown dropdown-end h-5">
|
||||
<BtnWithTooltips
|
||||
// on mobile, we always show the ellipsis icon
|
||||
// on desktop, we only show it when the user hovers over the conversation item
|
||||
// we use opacity instead of hidden to avoid layout shift
|
||||
className="cursor-pointer opacity-100 md:opacity-0 group-hover:opacity-100"
|
||||
onClick={() => {}}
|
||||
tooltipsContent="More"
|
||||
>
|
||||
<EllipsisVerticalIcon className="w-5 h-5" />
|
||||
</BtnWithTooltips>
|
||||
{/* dropdown menu */}
|
||||
<ul
|
||||
aria-label="More options"
|
||||
tabIndex={0}
|
||||
className="dropdown-content menu bg-base-100 rounded-box z-[1] p-2 shadow"
|
||||
>
|
||||
<li onClick={onRename} tabIndex={0}>
|
||||
<a>
|
||||
<PencilIcon className="w-4 h-4" />
|
||||
Rename
|
||||
</a>
|
||||
</li>
|
||||
<li onClick={onDownload} tabIndex={0}>
|
||||
<a>
|
||||
<ArrowDownTrayIcon className="w-4 h-4" />
|
||||
Download
|
||||
</a>
|
||||
</li>
|
||||
<li className="text-error" onClick={onDelete} tabIndex={0}>
|
||||
<a>
|
||||
<TrashIcon className="w-4 h-4" />
|
||||
Delete
|
||||
</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// WARN: vibe code below
|
||||
|
||||
export interface GroupedConversations {
|
||||
title?: string;
|
||||
conversations: Conversation[];
|
||||
}
|
||||
|
||||
// TODO @ngxson : add test for this function
|
||||
// Group conversations by date
|
||||
// - "Previous 7 Days"
|
||||
// - "Previous 30 Days"
|
||||
// - "Month Year" (e.g., "April 2023")
|
||||
export function groupConversationsByDate(
|
||||
conversations: Conversation[]
|
||||
): GroupedConversations[] {
|
||||
const now = new Date();
|
||||
const today = new Date(now.getFullYear(), now.getMonth(), now.getDate()); // Start of today
|
||||
|
||||
const sevenDaysAgo = new Date(today);
|
||||
sevenDaysAgo.setDate(today.getDate() - 7);
|
||||
|
||||
const thirtyDaysAgo = new Date(today);
|
||||
thirtyDaysAgo.setDate(today.getDate() - 30);
|
||||
|
||||
const groups: { [key: string]: Conversation[] } = {
|
||||
Today: [],
|
||||
'Previous 7 Days': [],
|
||||
'Previous 30 Days': [],
|
||||
};
|
||||
const monthlyGroups: { [key: string]: Conversation[] } = {}; // Key format: "Month Year" e.g., "April 2023"
|
||||
|
||||
// Sort conversations by lastModified date in descending order (newest first)
|
||||
// This helps when adding to groups, but the final output order of groups is fixed.
|
||||
const sortedConversations = [...conversations].sort(
|
||||
(a, b) => b.lastModified - a.lastModified
|
||||
);
|
||||
|
||||
for (const conv of sortedConversations) {
|
||||
const convDate = new Date(conv.lastModified);
|
||||
|
||||
if (convDate >= today) {
|
||||
groups['Today'].push(conv);
|
||||
} else if (convDate >= sevenDaysAgo) {
|
||||
groups['Previous 7 Days'].push(conv);
|
||||
} else if (convDate >= thirtyDaysAgo) {
|
||||
groups['Previous 30 Days'].push(conv);
|
||||
} else {
|
||||
const monthName = convDate.toLocaleString('default', { month: 'long' });
|
||||
const year = convDate.getFullYear();
|
||||
const monthYearKey = `${monthName} ${year}`;
|
||||
if (!monthlyGroups[monthYearKey]) {
|
||||
monthlyGroups[monthYearKey] = [];
|
||||
}
|
||||
monthlyGroups[monthYearKey].push(conv);
|
||||
}
|
||||
}
|
||||
|
||||
const result: GroupedConversations[] = [];
|
||||
|
||||
if (groups['Today'].length > 0) {
|
||||
result.push({
|
||||
title: undefined, // no title for Today
|
||||
conversations: groups['Today'],
|
||||
});
|
||||
}
|
||||
|
||||
if (groups['Previous 7 Days'].length > 0) {
|
||||
result.push({
|
||||
title: 'Previous 7 Days',
|
||||
conversations: groups['Previous 7 Days'],
|
||||
});
|
||||
}
|
||||
|
||||
if (groups['Previous 30 Days'].length > 0) {
|
||||
result.push({
|
||||
title: 'Previous 30 Days',
|
||||
conversations: groups['Previous 30 Days'],
|
||||
});
|
||||
}
|
||||
|
||||
// Sort monthly groups by date (most recent month first)
|
||||
const sortedMonthKeys = Object.keys(monthlyGroups).sort((a, b) => {
|
||||
const dateA = new Date(a); // "Month Year" can be parsed by Date constructor
|
||||
const dateB = new Date(b);
|
||||
return dateB.getTime() - dateA.getTime();
|
||||
});
|
||||
|
||||
for (const monthKey of sortedMonthKeys) {
|
||||
if (monthlyGroups[monthKey].length > 0) {
|
||||
result.push({ title: monthKey, conversations: monthlyGroups[monthKey] });
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
@@ -1,371 +0,0 @@
|
||||
import { useState } from 'react';
|
||||
import { MessageExtra } from '../utils/types';
|
||||
import toast from 'react-hot-toast';
|
||||
import { useAppContext } from '../utils/app.context';
|
||||
import * as pdfjs from 'pdfjs-dist';
|
||||
import pdfjsWorkerSrc from 'pdfjs-dist/build/pdf.worker.min.mjs?url';
|
||||
import { TextContent, TextItem } from 'pdfjs-dist/types/src/display/api';
|
||||
|
||||
pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc;
|
||||
|
||||
// This file handles uploading extra context items (a.k.a files)
|
||||
// It allows processing these kinds of files:
|
||||
// - image files (converted to base64)
|
||||
// - audio files (converted to base64)
|
||||
// - text files (including code files)
|
||||
// - pdf (converted to text)
|
||||
|
||||
// Interface describing the API returned by the hook
|
||||
export interface ChatExtraContextApi {
|
||||
items?: MessageExtra[]; // undefined if empty, similar to Message['extra']
|
||||
addItems: (items: MessageExtra[]) => void;
|
||||
removeItem: (idx: number) => void;
|
||||
clearItems: () => void;
|
||||
onFileAdded: (files: File[]) => void; // used by "upload" button
|
||||
}
|
||||
|
||||
export function useChatExtraContext(): ChatExtraContextApi {
|
||||
const { serverProps, config } = useAppContext();
|
||||
const [items, setItems] = useState<MessageExtra[]>([]);
|
||||
|
||||
const addItems = (newItems: MessageExtra[]) => {
|
||||
setItems((prev) => [...prev, ...newItems]);
|
||||
};
|
||||
|
||||
const removeItem = (idx: number) => {
|
||||
setItems((prev) => prev.filter((_, i) => i !== idx));
|
||||
};
|
||||
|
||||
const clearItems = () => {
|
||||
setItems([]);
|
||||
};
|
||||
|
||||
const isSupportVision = serverProps?.modalities?.vision;
|
||||
|
||||
const onFileAdded = async (files: File[]) => {
|
||||
try {
|
||||
for (const file of files) {
|
||||
const mimeType = file.type;
|
||||
|
||||
// this limit is only to prevent accidental uploads of huge files
|
||||
// it can potentially crashes the browser because we read the file as base64
|
||||
if (file.size > 500 * 1024 * 1024) {
|
||||
toast.error('File is too large. Maximum size is 500MB.');
|
||||
break;
|
||||
}
|
||||
|
||||
if (mimeType.startsWith('image/')) {
|
||||
if (!isSupportVision) {
|
||||
toast.error('Multimodal is not supported by this server or model.');
|
||||
break;
|
||||
}
|
||||
|
||||
let base64Url = await getFileAsBase64(file);
|
||||
if (mimeType === 'image/svg+xml') {
|
||||
// Convert SVG to PNG
|
||||
base64Url = await svgBase64UrlToPngDataURL(base64Url);
|
||||
}
|
||||
addItems([
|
||||
{
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url,
|
||||
},
|
||||
]);
|
||||
} else if (mimeType.startsWith('video/')) {
|
||||
toast.error('Video files are not supported yet.');
|
||||
break;
|
||||
} else if (mimeType.startsWith('audio/')) {
|
||||
if (!/mpeg|wav/.test(mimeType)) {
|
||||
toast.error('Only mp3 and wav audio files are supported.');
|
||||
break;
|
||||
}
|
||||
|
||||
// plain base64, not a data URL
|
||||
const base64Data = await getFileAsBase64(file, false);
|
||||
addItems([
|
||||
{
|
||||
type: 'audioFile',
|
||||
name: file.name,
|
||||
mimeType,
|
||||
base64Data,
|
||||
},
|
||||
]);
|
||||
} else if (mimeType.startsWith('application/pdf')) {
|
||||
if (config.pdfAsImage && !isSupportVision) {
|
||||
toast(
|
||||
'Multimodal is not supported, PDF will be converted to text instead of image.'
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
if (config.pdfAsImage && isSupportVision) {
|
||||
// Convert PDF to images
|
||||
const base64Urls = await convertPDFToImage(file);
|
||||
addItems(
|
||||
base64Urls.map((base64Url) => ({
|
||||
type: 'imageFile',
|
||||
name: file.name,
|
||||
base64Url,
|
||||
}))
|
||||
);
|
||||
} else {
|
||||
// Convert PDF to text
|
||||
const content = await convertPDFToText(file);
|
||||
addItems([
|
||||
{
|
||||
type: 'textFile',
|
||||
name: file.name,
|
||||
content,
|
||||
},
|
||||
]);
|
||||
if (isSupportVision) {
|
||||
toast.success(
|
||||
'PDF file converted to text. You can also convert it to image, see in Settings.'
|
||||
);
|
||||
}
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
// Because there can be many text file types (like code file), we will not check the mime type
|
||||
// and will just check if the file is not binary.
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
const content = event.target.result as string;
|
||||
if (!isLikelyNotBinary(content)) {
|
||||
toast.error('File is binary. Please upload a text file.');
|
||||
return;
|
||||
}
|
||||
addItems([
|
||||
{
|
||||
type: 'textFile',
|
||||
name: file.name,
|
||||
content,
|
||||
},
|
||||
]);
|
||||
}
|
||||
};
|
||||
reader.readAsText(file);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const errorMessage = `Error processing file: ${message}`;
|
||||
toast.error(errorMessage);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
items: items.length > 0 ? items : undefined,
|
||||
addItems,
|
||||
removeItem,
|
||||
clearItems,
|
||||
onFileAdded,
|
||||
};
|
||||
}
|
||||
|
||||
async function getFileAsBase64(file: File, outputUrl = true): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
let result = event.target.result as string;
|
||||
if (!outputUrl) {
|
||||
// remove base64 url prefix and correct characters
|
||||
result = result.substring(result.indexOf(',') + 1);
|
||||
}
|
||||
resolve(result);
|
||||
} else {
|
||||
reject(new Error('Failed to read file.'));
|
||||
}
|
||||
};
|
||||
reader.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function getFileAsBuffer(file: File): Promise<ArrayBuffer> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = (event) => {
|
||||
if (event.target?.result) {
|
||||
resolve(event.target.result as ArrayBuffer);
|
||||
} else {
|
||||
reject(new Error('Failed to read file.'));
|
||||
}
|
||||
};
|
||||
reader.readAsArrayBuffer(file);
|
||||
});
|
||||
}
|
||||
|
||||
async function convertPDFToText(file: File): Promise<string> {
|
||||
const buffer = await getFileAsBuffer(file);
|
||||
const pdf = await pdfjs.getDocument(buffer).promise;
|
||||
const numPages = pdf.numPages;
|
||||
const textContentPromises: Promise<TextContent>[] = [];
|
||||
for (let i = 1; i <= numPages; i++) {
|
||||
textContentPromises.push(
|
||||
pdf.getPage(i).then((page) => page.getTextContent())
|
||||
);
|
||||
}
|
||||
const textContents = await Promise.all(textContentPromises);
|
||||
const textItems = textContents.flatMap((textContent: TextContent) =>
|
||||
textContent.items.map((item) => (item as TextItem).str ?? '')
|
||||
);
|
||||
return textItems.join('\n');
|
||||
}
|
||||
|
||||
// returns list of base64 images
|
||||
async function convertPDFToImage(file: File): Promise<string[]> {
|
||||
const buffer = await getFileAsBuffer(file);
|
||||
const doc = await pdfjs.getDocument(buffer).promise;
|
||||
const pages: Promise<string>[] = [];
|
||||
|
||||
for (let i = 1; i <= doc.numPages; i++) {
|
||||
const page = await doc.getPage(i);
|
||||
const viewport = page.getViewport({ scale: 1.5 });
|
||||
const canvas = document.createElement('canvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
canvas.width = viewport.width;
|
||||
canvas.height = viewport.height;
|
||||
if (!ctx) {
|
||||
throw new Error('Failed to get 2D context from canvas');
|
||||
}
|
||||
const task = page.render({ canvasContext: ctx, viewport: viewport });
|
||||
pages.push(
|
||||
task.promise.then(() => {
|
||||
return canvas.toDataURL();
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
return await Promise.all(pages);
|
||||
}
|
||||
|
||||
// WARN: vibe code below
|
||||
// This code is a heuristic to determine if a string is likely not binary.
|
||||
// It is necessary because input file can have various mime types which we don't have time to investigate.
|
||||
// For example, a python file can be text/plain, application/x-python, etc.
|
||||
function isLikelyNotBinary(str: string): boolean {
|
||||
const options = {
|
||||
prefixLength: 1024 * 10, // Check the first 10KB of the string
|
||||
suspiciousCharThresholdRatio: 0.15, // Allow up to 15% suspicious chars
|
||||
maxAbsoluteNullBytes: 2,
|
||||
};
|
||||
|
||||
if (!str) {
|
||||
return true; // Empty string is considered "not binary" or trivially text.
|
||||
}
|
||||
|
||||
const sampleLength = Math.min(str.length, options.prefixLength);
|
||||
if (sampleLength === 0) {
|
||||
return true; // Effectively an empty string after considering prefixLength.
|
||||
}
|
||||
|
||||
let suspiciousCharCount = 0;
|
||||
let nullByteCount = 0;
|
||||
|
||||
for (let i = 0; i < sampleLength; i++) {
|
||||
const charCode = str.charCodeAt(i);
|
||||
|
||||
// 1. Check for Unicode Replacement Character (U+FFFD)
|
||||
// This is a strong indicator if the string was created from decoding bytes as UTF-8.
|
||||
if (charCode === 0xfffd) {
|
||||
suspiciousCharCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// 2. Check for Null Bytes (U+0000)
|
||||
if (charCode === 0x0000) {
|
||||
nullByteCount++;
|
||||
// We also count nulls towards the general suspicious character count,
|
||||
// as they are less common in typical text files.
|
||||
suspiciousCharCount++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// 3. Check for C0 Control Characters (U+0001 to U+001F)
|
||||
// Exclude common text control characters: TAB (9), LF (10), CR (13).
|
||||
// We can also be a bit lenient with BEL (7) and BS (8) which sometimes appear in logs.
|
||||
if (charCode < 32) {
|
||||
if (
|
||||
charCode !== 9 && // TAB
|
||||
charCode !== 10 && // LF
|
||||
charCode !== 13 && // CR
|
||||
charCode !== 7 && // BEL (Bell) - sometimes in logs
|
||||
charCode !== 8 // BS (Backspace) - less common, but possible
|
||||
) {
|
||||
suspiciousCharCount++;
|
||||
}
|
||||
}
|
||||
// Characters from 32 (space) up to 126 (~) are printable ASCII.
|
||||
// Characters 127 (DEL) is a control character.
|
||||
// Characters >= 128 are extended ASCII / multi-byte Unicode.
|
||||
// If they resulted in U+FFFD, we caught it. Otherwise, they are valid
|
||||
// (though perhaps unusual) Unicode characters from JS's perspective.
|
||||
// The main concern is if those higher characters came from misinterpreting
|
||||
// a single-byte encoding as UTF-8, which again, U+FFFD would usually flag.
|
||||
}
|
||||
|
||||
// Check absolute null byte count
|
||||
if (nullByteCount > options.maxAbsoluteNullBytes) {
|
||||
return false; // Too many null bytes is a strong binary indicator
|
||||
}
|
||||
|
||||
// Check ratio of suspicious characters
|
||||
const ratio = suspiciousCharCount / sampleLength;
|
||||
return ratio <= options.suspiciousCharThresholdRatio;
|
||||
}
|
||||
|
||||
// WARN: vibe code below
|
||||
// Converts a Base64URL encoded SVG string to a PNG Data URL using browser Canvas API.
|
||||
function svgBase64UrlToPngDataURL(base64UrlSvg: string): Promise<string> {
|
||||
const backgroundColor = 'white'; // Default background color for PNG
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
const img = new Image();
|
||||
|
||||
img.onload = () => {
|
||||
const canvas = document.createElement('canvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
|
||||
if (!ctx) {
|
||||
reject(new Error('Failed to get 2D canvas context.'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Use provided dimensions or SVG's natural dimensions, with fallbacks
|
||||
// Fallbacks (e.g., 300x300) are for SVGs without explicit width/height
|
||||
// or when naturalWidth/Height might be 0 before full processing.
|
||||
const targetWidth = img.naturalWidth || 300;
|
||||
const targetHeight = img.naturalHeight || 300;
|
||||
|
||||
canvas.width = targetWidth;
|
||||
canvas.height = targetHeight;
|
||||
|
||||
if (backgroundColor) {
|
||||
ctx.fillStyle = backgroundColor;
|
||||
ctx.fillRect(0, 0, canvas.width, canvas.height);
|
||||
}
|
||||
|
||||
ctx.drawImage(img, 0, 0, targetWidth, targetHeight);
|
||||
resolve(canvas.toDataURL('image/png'));
|
||||
};
|
||||
|
||||
img.onerror = () => {
|
||||
reject(
|
||||
new Error('Failed to load SVG image. Ensure the SVG data is valid.')
|
||||
);
|
||||
};
|
||||
|
||||
// Load SVG string into an Image element
|
||||
img.src = base64UrlSvg;
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
const errorMessage = `Error converting SVG to PNG: ${message}`;
|
||||
toast.error(errorMessage);
|
||||
reject(new Error(errorMessage));
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
import React, { useEffect } from 'react';
|
||||
import { throttle } from '../utils/misc';
|
||||
|
||||
export const scrollToBottom = (requiresNearBottom: boolean, delay?: number) => {
|
||||
const mainScrollElem = document.getElementById('main-scroll');
|
||||
if (!mainScrollElem) return;
|
||||
const spaceToBottom =
|
||||
mainScrollElem.scrollHeight -
|
||||
mainScrollElem.scrollTop -
|
||||
mainScrollElem.clientHeight;
|
||||
if (!requiresNearBottom || spaceToBottom < 100) {
|
||||
setTimeout(
|
||||
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
|
||||
delay ?? 80
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const scrollToBottomThrottled = throttle(scrollToBottom, 80);
|
||||
|
||||
export function useChatScroll(msgListRef: React.RefObject<HTMLDivElement>) {
|
||||
useEffect(() => {
|
||||
if (!msgListRef.current) return;
|
||||
|
||||
const resizeObserver = new ResizeObserver((_) => {
|
||||
scrollToBottomThrottled(true, 10);
|
||||
});
|
||||
|
||||
resizeObserver.observe(msgListRef.current);
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [msgListRef]);
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user