mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-16 05:54:06 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
364a7a6d4a | ||
|
|
2df5bcf357 | ||
|
|
075c01567b | ||
|
|
a014310374 | ||
|
|
35fb82497e | ||
|
|
3c62aed89f | ||
|
|
f1eb1cb1eb | ||
|
|
de41f2b7bf | ||
|
|
a74a0d69f3 | ||
|
|
5f7e166cbf | ||
|
|
d72f5f7ba2 |
52
.github/workflows/build-amd.yml
vendored
Normal file
52
.github/workflows/build-amd.yml
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
name: CI (AMD)
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths: [
|
||||
'.github/workflows/build-amd.yml',
|
||||
'**/CMakeLists.txt',
|
||||
'**/.cmake',
|
||||
'**/*.h',
|
||||
'**/*.hpp',
|
||||
'**/*.c',
|
||||
'**/*.cpp',
|
||||
'**/*.cu',
|
||||
'**/*.cuh',
|
||||
'**/*.comp'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ggml-ci-x64-amd-vulkan:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-x64-amd-rocm:
|
||||
runs-on: [self-hosted, Linux, X64, AMD]
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
amd-smi static
|
||||
GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
39
.github/workflows/build.yml
vendored
39
.github/workflows/build.yml
vendored
@@ -1222,11 +1222,12 @@ jobs:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: android-build
|
||||
evict-old-files: 1d
|
||||
# Disabled due to size (400MB) and always 0 cache hits
|
||||
# - name: ccache
|
||||
# uses: ggml-org/ccache-action@v1.2.16
|
||||
# with:
|
||||
# key: android-build
|
||||
# evict-old-files: 1d
|
||||
|
||||
- name: Set up JDK
|
||||
uses: actions/setup-java@v3
|
||||
@@ -1461,34 +1462,6 @@ jobs:
|
||||
run: |
|
||||
bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
# ggml-ci-x64-amd-vulkan:
|
||||
# runs-on: [self-hosted, Linux, X64, AMD]
|
||||
#
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
#
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# vulkaninfo --summary
|
||||
# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
#
|
||||
# ggml-ci-x64-amd-rocm:
|
||||
# runs-on: [self-hosted, Linux, X64, AMD]
|
||||
#
|
||||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
#
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
# run: |
|
||||
# amd-smi static
|
||||
# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp
|
||||
|
||||
ggml-ci-mac-metal:
|
||||
runs-on: [self-hosted, macOS, ARM64]
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@
|
||||
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
|
||||
/ggml/src/ggml-impl.h @ggerganov @slaren
|
||||
/ggml/src/ggml-metal/ @ggerganov
|
||||
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
||||
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
||||
/ggml/src/ggml-quants.* @ggerganov
|
||||
/ggml/src/ggml-rpc/ @rgerganov
|
||||
|
||||
@@ -114,6 +114,7 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then
|
||||
# arm 9 and newer enables sve by default, adjust these flags depending on the cpu used
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm"
|
||||
fi
|
||||
|
||||
## helpers
|
||||
|
||||
# download a file if it does not exist or if it is outdated
|
||||
|
||||
186
common/arg.cpp
186
common/arg.cpp
@@ -217,12 +217,55 @@ struct common_hf_file_res {
|
||||
std::string mmprojFile;
|
||||
};
|
||||
|
||||
#ifdef LLAMA_USE_CURL
|
||||
|
||||
bool common_has_curl() {
|
||||
return true;
|
||||
static void write_etag(const std::string & path, const std::string & etag) {
|
||||
const std::string etag_path = path + ".etag";
|
||||
write_file(etag_path, etag);
|
||||
LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
|
||||
}
|
||||
|
||||
static std::string read_etag(const std::string & path) {
|
||||
std::string none;
|
||||
const std::string etag_path = path + ".etag";
|
||||
|
||||
if (std::filesystem::exists(etag_path)) {
|
||||
std::ifstream etag_in(etag_path);
|
||||
if (!etag_in) {
|
||||
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
||||
return none;
|
||||
}
|
||||
std::string etag;
|
||||
std::getline(etag_in, etag);
|
||||
return etag;
|
||||
}
|
||||
|
||||
// no etag file, but maybe there is an old .json
|
||||
// remove this code later
|
||||
const std::string metadata_path = path + ".json";
|
||||
|
||||
if (std::filesystem::exists(metadata_path)) {
|
||||
std::ifstream metadata_in(metadata_path);
|
||||
try {
|
||||
nlohmann::json metadata_json;
|
||||
metadata_in >> metadata_json;
|
||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
|
||||
metadata_json.dump().c_str());
|
||||
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
|
||||
std::string etag = metadata_json.at("etag");
|
||||
write_etag(path, etag);
|
||||
if (!std::filesystem::remove(metadata_path)) {
|
||||
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
|
||||
}
|
||||
return etag;
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||
}
|
||||
}
|
||||
return none;
|
||||
}
|
||||
|
||||
#ifdef LLAMA_USE_CURL
|
||||
|
||||
//
|
||||
// CURL utils
|
||||
//
|
||||
@@ -373,36 +416,15 @@ static bool common_download_head(CURL * curl,
|
||||
static bool common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token) {
|
||||
// If the file exists, check its JSON metadata companion file.
|
||||
std::string metadata_path = path + ".json";
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead
|
||||
std::string etag;
|
||||
std::string last_modified;
|
||||
std::string etag;
|
||||
|
||||
// Check if the file already exists locally
|
||||
const auto file_exists = std::filesystem::exists(path);
|
||||
if (file_exists) {
|
||||
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
|
||||
std::ifstream metadata_in(metadata_path);
|
||||
if (metadata_in.good()) {
|
||||
try {
|
||||
metadata_in >> metadata;
|
||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
|
||||
metadata.dump().c_str());
|
||||
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
|
||||
etag = metadata.at("etag");
|
||||
}
|
||||
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
|
||||
last_modified = metadata.at("lastModified");
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||
}
|
||||
}
|
||||
// if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again)
|
||||
etag = read_etag(path);
|
||||
} else {
|
||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||
}
|
||||
@@ -440,11 +462,6 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
headers.etag.c_str());
|
||||
should_download = true;
|
||||
should_download_from_scratch = true;
|
||||
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
|
||||
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
|
||||
last_modified.c_str(), headers.last_modified.c_str());
|
||||
should_download = true;
|
||||
should_download_from_scratch = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -475,15 +492,9 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write the updated JSON metadata file.
|
||||
metadata.update({
|
||||
{ "url", url },
|
||||
{ "etag", headers.etag },
|
||||
{ "lastModified", headers.last_modified }
|
||||
});
|
||||
write_file(metadata_path, metadata.dump(4));
|
||||
LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
|
||||
if (head_request_ok) {
|
||||
write_etag(path, headers.etag);
|
||||
}
|
||||
|
||||
// start the download
|
||||
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
|
||||
@@ -570,10 +581,6 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
|
||||
|
||||
#else
|
||||
|
||||
bool common_has_curl() {
|
||||
return false;
|
||||
}
|
||||
|
||||
struct common_url {
|
||||
std::string scheme;
|
||||
std::string user;
|
||||
@@ -664,51 +671,6 @@ static void print_progress(size_t current, size_t total) { // TODO isatty
|
||||
std::cout.flush();
|
||||
}
|
||||
|
||||
struct common_file_metadata {
|
||||
std::string etag;
|
||||
std::string last_modified;
|
||||
};
|
||||
|
||||
static std::optional<common_file_metadata> read_metadata(const std::string & path) {
|
||||
if (!std::filesystem::exists(path)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
nlohmann::json metadata_json;
|
||||
common_file_metadata metadata;
|
||||
|
||||
std::ifstream metadata_in(path);
|
||||
try {
|
||||
metadata_in >> metadata_json;
|
||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, path.c_str(),
|
||||
metadata_json.dump().c_str());
|
||||
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
|
||||
metadata.etag = metadata_json.at("etag");
|
||||
}
|
||||
if (metadata_json.contains("lastModified") && metadata_json.at("lastModified").is_string()) {
|
||||
metadata.last_modified = metadata_json.at("lastModified");
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, path.c_str(), e.what());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return metadata;
|
||||
}
|
||||
|
||||
static void write_metadata(const std::string & path,
|
||||
const std::string & url,
|
||||
const common_file_metadata & metadata) {
|
||||
nlohmann::json metadata_json = {
|
||||
{ "url", url },
|
||||
{ "etag", metadata.etag },
|
||||
{ "lastModified", metadata.last_modified }
|
||||
};
|
||||
|
||||
write_file(path, metadata_json.dump(4));
|
||||
LOG_DBG("%s: file metadata saved: %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
static bool common_pull_file(httplib::Client & cli,
|
||||
const std::string & resolve_path,
|
||||
const std::string & path_tmp,
|
||||
@@ -775,8 +737,6 @@ static bool common_pull_file(httplib::Client & cli,
|
||||
static bool common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token) {
|
||||
// If the file exists, check its JSON metadata companion file.
|
||||
std::string metadata_path = path + ".json";
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
@@ -788,12 +748,11 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
cli.set_default_headers(default_headers);
|
||||
|
||||
common_file_metadata last;
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
std::string last_etag;
|
||||
if (file_exists) {
|
||||
if (auto opt = read_metadata(metadata_path)) {
|
||||
last = *opt;
|
||||
}
|
||||
last_etag = read_etag(path);
|
||||
} else {
|
||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||
}
|
||||
@@ -809,14 +768,9 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
}
|
||||
|
||||
common_file_metadata current;
|
||||
if (head_ok) {
|
||||
if (head->has_header("ETag")) {
|
||||
current.etag = head->get_header_value("ETag");
|
||||
}
|
||||
if (head->has_header("Last-Modified")) {
|
||||
current.last_modified = head->get_header_value("Last-Modified");
|
||||
}
|
||||
std::string etag;
|
||||
if (head_ok && head->has_header("ETag")) {
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
@@ -834,16 +788,10 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
|
||||
bool should_download_from_scratch = false;
|
||||
if (head_ok) {
|
||||
if (!last.etag.empty() && last.etag != current.etag) {
|
||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
|
||||
last.etag.c_str(), current.etag.c_str());
|
||||
should_download_from_scratch = true;
|
||||
} else if (!last.last_modified.empty() && last.last_modified != current.last_modified) {
|
||||
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
|
||||
last.last_modified.c_str(), current.last_modified.c_str());
|
||||
should_download_from_scratch = true;
|
||||
}
|
||||
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
|
||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
|
||||
last_etag.c_str(), etag.c_str());
|
||||
should_download_from_scratch = true;
|
||||
}
|
||||
|
||||
if (file_exists) {
|
||||
@@ -871,9 +819,8 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
|
||||
// start the download
|
||||
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n",
|
||||
__func__, show_masked_url(parts).c_str(), path_temporary.c_str(),
|
||||
current.etag.c_str(), current.last_modified.c_str());
|
||||
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
|
||||
__func__, show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
|
||||
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
|
||||
if (!was_pull_successful) {
|
||||
if (i + 1 < max_attempts) {
|
||||
@@ -883,7 +830,6 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
} else {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -891,7 +837,9 @@ static bool common_download_file_single_online(const std::string & url,
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return false;
|
||||
}
|
||||
write_metadata(metadata_path, url, current);
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
@@ -78,7 +78,6 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
|
||||
|
||||
// function to be used by test-arg-parser
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
bool common_has_curl();
|
||||
|
||||
struct common_remote_params {
|
||||
std::vector<std::string> headers;
|
||||
|
||||
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 3)
|
||||
set(GGML_VERSION_PATCH 4)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
@@ -513,9 +513,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.13.0")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.14.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
@@ -592,6 +592,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
|
||||
${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
|
||||
|
||||
@@ -87,15 +87,38 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
||||
return tensor->ne[dim];
|
||||
}
|
||||
|
||||
template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
|
||||
constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
|
||||
using V = std::remove_reference_t<Variant>;
|
||||
return (std::is_invocable_r_v<
|
||||
Ret,
|
||||
std::variant_alternative_t<Is, V>,
|
||||
Args...> || ...);
|
||||
}
|
||||
|
||||
template <typename Variant, typename Ret, typename... Args>
|
||||
constexpr bool variant_any_invocable_v =
|
||||
variant_any_invocable_impl<Variant, Ret, Args...>(
|
||||
std::make_index_sequence<
|
||||
std::variant_size_v<std::remove_reference_t<Variant>>>{});
|
||||
|
||||
template<typename Ret, typename Variant, typename... Args>
|
||||
static Ret variant_call(const Variant & var, Args&&... args) {
|
||||
return std::visit([&](auto&& func) -> Ret {
|
||||
if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
|
||||
return func(std::forward<Args>(args)...);
|
||||
} else {
|
||||
throw std::runtime_error("Invalid function type in variant_call");
|
||||
}
|
||||
}, var);
|
||||
static inline Ret variant_call(Variant && var, Args&&... args) {
|
||||
static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
|
||||
"No alternative in Variant is invocable with the provided arguments and return type.");
|
||||
|
||||
return std::visit(
|
||||
[&](auto && f) -> Ret {
|
||||
using F = std::decay_t<decltype(f)>;
|
||||
if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
|
||||
return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
|
||||
} else {
|
||||
GGML_ABORT("Invalid function type in variant_call");
|
||||
GGML_UNREACHABLE();
|
||||
}
|
||||
},
|
||||
std::forward<Variant>(var)
|
||||
);
|
||||
}
|
||||
|
||||
namespace ggml::cpu::kleidiai {
|
||||
@@ -138,7 +161,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
|
||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
|
||||
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
||||
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
|
||||
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
|
||||
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
||||
k * n * sizeof(float) + n * sizeof(float);
|
||||
} else {
|
||||
@@ -148,7 +174,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
|
||||
if (dst->op == GGML_OP_MUL_MAT) {
|
||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||
@@ -165,8 +190,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
}
|
||||
|
||||
bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
@@ -175,7 +198,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
|
||||
bool is_gemv = src1->ne[1] == 1;
|
||||
const bool is_gemv = src1->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
GGML_ASSERT(kernel);
|
||||
@@ -185,27 +208,30 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
|
||||
const int64_t lhs_batch_size0 = ne12;
|
||||
const int64_t rhs_batch_size0 = ne02;
|
||||
const int64_t batch_size = rhs_batch_size0;
|
||||
const int64_t batch_size = lhs_batch_size0;
|
||||
|
||||
GGML_ASSERT(rhs_batch_size0 > 0);
|
||||
GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
|
||||
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
||||
|
||||
const int64_t m = ne11 * r;
|
||||
const int64_t n = ne01;
|
||||
const int64_t k = ne00;
|
||||
const int64_t m_group = ne11;
|
||||
const int64_t m = m_group;
|
||||
const int64_t n = ne01;
|
||||
const int64_t k = ne00;
|
||||
|
||||
const size_t lhs_stride = src1->nb[1];
|
||||
const size_t rhs_stride = src0->nb[1];
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
|
||||
const int64_t mr = static_cast<int64_t>(kernel->get_mr());
|
||||
const int64_t nr = static_cast<int64_t>(kernel->get_nr());
|
||||
const int64_t kr = static_cast<int64_t>(kernel->get_kr());
|
||||
const int64_t sr = static_cast<int64_t>(kernel->get_sr());
|
||||
const int64_t mr = (int64_t) kernel->get_mr();
|
||||
const int64_t nr = (int64_t) kernel->get_nr();
|
||||
const int64_t kr = (int64_t) kernel->get_kr();
|
||||
const int64_t sr = (int64_t) kernel->get_sr();
|
||||
|
||||
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
|
||||
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
|
||||
const size_t kxn_size = k * n * sizeof(float);
|
||||
const size_t bias_size = n * sizeof(float);
|
||||
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
|
||||
const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
|
||||
const size_t bias_size = (size_t)n * sizeof(float);
|
||||
|
||||
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
||||
GGML_ASSERT(wsize_required <= params->wsize);
|
||||
@@ -216,82 +242,102 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
uint8_t * bias = rhs_kxn + kxn_size;
|
||||
|
||||
for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
|
||||
const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
|
||||
const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
|
||||
uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
|
||||
const int64_t rhs_batch_idx = batch_idx / r;
|
||||
const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
|
||||
uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
|
||||
|
||||
// LHS packing
|
||||
// LHS packing (threaded over m, honoring mr alignment and KV groups)
|
||||
{
|
||||
const int64_t m_roundup_mr = kai_roundup(m, mr);
|
||||
const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
|
||||
|
||||
if (ith < num_threads) {
|
||||
const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
|
||||
const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
|
||||
const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
|
||||
|
||||
const int64_t m_start = ith * num_m_per_thread0;
|
||||
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
||||
const int64_t m_start = ith * num_m_per_thread0;
|
||||
const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
||||
|
||||
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
|
||||
// Base packed offset (aligned) and per-row stride in bytes
|
||||
const size_t base_packed_off = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t next_block_off = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
|
||||
|
||||
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
|
||||
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
|
||||
int64_t remaining = m_count;
|
||||
int64_t cur = m_start;
|
||||
|
||||
variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
||||
while (remaining > 0) {
|
||||
const int64_t row_in_group = cur;
|
||||
const int64_t avail = m_group - row_in_group;
|
||||
const int64_t take = std::min(avail, remaining);
|
||||
|
||||
const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
|
||||
const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
|
||||
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
|
||||
void * dst_ptr = lhs_packed + dst_off;
|
||||
|
||||
variant_call<void>(lhs_info->pack_func,
|
||||
(size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
|
||||
/*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
|
||||
|
||||
cur += take;
|
||||
remaining -= take;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RHS packing
|
||||
if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
|
||||
// First thread to reach this point handles RHS packing
|
||||
memset(bias, 0, n * sizeof(float));
|
||||
transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
|
||||
reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
|
||||
// RHS packing (single thread), then synchronize
|
||||
if (ith == 0) {
|
||||
memset(bias, 0, (size_t)n * sizeof(float));
|
||||
transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
|
||||
reinterpret_cast<float *>(rhs_kxn),
|
||||
reinterpret_cast<const uint16_t *>(rhs_batch_base),
|
||||
rhs_stride);
|
||||
|
||||
variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
|
||||
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
|
||||
variant_call<void>(kernels->rhs_info.pack_func,
|
||||
/*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
|
||||
/*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
|
||||
rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
first_to_arrive.clear(std::memory_order_release);
|
||||
|
||||
// Perform the matmul
|
||||
// Matmul (threaded over n)
|
||||
{
|
||||
const int64_t m_to_process = m;
|
||||
const int64_t m_start = 0;
|
||||
|
||||
const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
|
||||
int64_t num_threads = KAI_MIN(n / n_step, nth);
|
||||
if (num_threads <= 0) {
|
||||
num_threads = 1;
|
||||
const int64_t n_step = (int64_t) kernel->get_n_step();
|
||||
int64_t num_threads_n = KAI_MIN(n / n_step, nth);
|
||||
if (num_threads_n <= 0) {
|
||||
num_threads_n = 1;
|
||||
}
|
||||
|
||||
if (ith < num_threads) {
|
||||
const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
|
||||
const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
|
||||
if (ith < num_threads_n) {
|
||||
const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step);
|
||||
const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;
|
||||
|
||||
const int64_t n_start = ith * num_n_per_thread0;
|
||||
const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
||||
const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
||||
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
|
||||
const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
|
||||
// LHS packed base at row 0 (consistent with packing above)
|
||||
const size_t lhs_packed_offset0 = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
|
||||
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
|
||||
|
||||
const void * lhs_ptr = lhs_packed + lhs_packed_offset;
|
||||
const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
|
||||
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
||||
float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
|
||||
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
|
||||
|
||||
variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
variant_call<void>(kernel->run_kernel,
|
||||
(size_t)m, (size_t)n_to_process, (size_t)k,
|
||||
lhs_ptr, rhs_ptr,
|
||||
dst_ptr, dst_stride, sizeof(float),
|
||||
-FLT_MAX, FLT_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
if (batch_idx != batch_size - 1) {
|
||||
// This barrier is necessary when the batch size is larger than 1. While processing a batch,
|
||||
// the work data buffer (params->wdata) is used as temporary storage which means that only
|
||||
// a single batch can be processed at any given time. No barrier is needed for the last
|
||||
// batch since GGML inserts a barrier between the execution of every operator.
|
||||
ggml_barrier(params->threadpool);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,7 +329,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
} else
|
||||
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
||||
{
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else {
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
@@ -400,7 +404,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
return nullptr;
|
||||
// Prioritize CUDA graph compatibility over direct memory copy optimization.
|
||||
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
|
||||
@@ -2641,6 +2641,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
|
||||
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
|
||||
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
|
||||
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
|
||||
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
@@ -2669,7 +2671,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
|
||||
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
|
||||
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
|
||||
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
|
||||
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
||||
// by means of matching node names. See
|
||||
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
||||
|
||||
@@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
if (ne00 == 4) {
|
||||
if (ne00 < 32) {
|
||||
nsg = 1;
|
||||
nr0 = 32;
|
||||
nr1 = 4;
|
||||
suffix = "_c4";
|
||||
} else if (ne00 % 4 == 0) {
|
||||
nsg = N_SG_F;
|
||||
nr0 = N_R0_F;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*N_R0_F;
|
||||
suffix = "_4";
|
||||
suffix = "_short";
|
||||
} else {
|
||||
nsg = N_SG_F;
|
||||
nr0 = N_R0_F;
|
||||
nsg = std::min(4, (ne00 + 127) / 128);
|
||||
nr0 = 2;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*N_R0_F;
|
||||
smem = 32*sizeof(float)*nr0;
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
@@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
if (ne00 % 4 == 0) {
|
||||
nsg = N_SG_F;
|
||||
nr0 = N_R0_F;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*N_R0_F;
|
||||
suffix = "_4";
|
||||
} else {
|
||||
nsg = N_SG_F;
|
||||
nr0 = N_R0_F;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*N_R0_F;
|
||||
}
|
||||
nsg = std::min(4, (ne00 + 127) / 128);
|
||||
nr0 = 2;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*nr0;
|
||||
suffix = ne00 % 4 == 0 ? "_4" : "";
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
|
||||
@@ -8,9 +8,6 @@
|
||||
//
|
||||
// TODO: for optimal performance, become function of the device and work size
|
||||
|
||||
#define N_R0_F 2
|
||||
#define N_SG_F 4
|
||||
|
||||
#define N_R0_Q4_0 4
|
||||
#define N_SG_Q4_0 2
|
||||
|
||||
@@ -352,6 +349,7 @@ typedef struct {
|
||||
uint64_t nb13;
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
int32_t nr0;
|
||||
int16_t r2;
|
||||
int16_t r3;
|
||||
} ggml_metal_kargs_mul_mv;
|
||||
@@ -427,6 +425,7 @@ typedef struct {
|
||||
int32_t ne0;
|
||||
int32_t ne1;
|
||||
uint64_t nb1;
|
||||
int32_t nr0;
|
||||
} ggml_metal_kargs_mul_mv_id;
|
||||
|
||||
// NORM
|
||||
|
||||
@@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
} else {
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
|
||||
|
||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
|
||||
ggml_metal_kargs_mul_mv args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
@@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.nr0 =*/ nr0,
|
||||
/*.r2 =*/ r2,
|
||||
/*.r3 =*/ r3,
|
||||
};
|
||||
|
||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
@@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
|
||||
}
|
||||
} else {
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
||||
|
||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
|
||||
ggml_metal_kargs_mul_mv_id args = {
|
||||
/*.nei0 =*/ ne20,
|
||||
/*.nei1 =*/ ne21,
|
||||
@@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nr0 =*/ nr0,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
|
||||
|
||||
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
|
||||
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
|
||||
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
|
||||
|
||||
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
||||
|
||||
if (ggml_is_quantized(op->src[0]->type)) {
|
||||
GGML_ASSERT(ne00 >= nsg*nr0);
|
||||
}
|
||||
|
||||
@@ -3531,7 +3531,25 @@ void kernel_mul_mv_t_t_impl(
|
||||
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
}
|
||||
|
||||
template<typename T0, typename T1, short NR0>
|
||||
template<typename T0, typename T1, typename args_t>
|
||||
void kernel_mul_mv_t_t_disp(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
switch (args.nr0) {
|
||||
//case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||
case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||
//case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||
//case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_mul_mv_t_t(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
@@ -3541,17 +3559,17 @@ kernel void kernel_mul_mv_t_t(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
|
||||
typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
|
||||
#endif
|
||||
|
||||
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
|
||||
@@ -3637,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl(
|
||||
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
||||
}
|
||||
|
||||
template<typename T0, typename T04, typename T1, typename T14, short NR0>
|
||||
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
||||
void kernel_mul_mv_t_t_4_disp(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
switch (args.nr0) {
|
||||
//case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||
case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||
//case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||
//case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
|
||||
};
|
||||
}
|
||||
|
||||
template<typename T0, typename T04, typename T1, typename T14>
|
||||
kernel void kernel_mul_mv_t_t_4(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
@@ -3647,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4(
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
|
||||
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
|
||||
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
|
||||
#endif
|
||||
|
||||
#define N_MV_T_T 4
|
||||
|
||||
template<typename T04, typename T14, typename args_t>
|
||||
void kernel_mul_mv_c4_impl(
|
||||
template<typename T0, typename T1, typename args_t>
|
||||
void kernel_mul_mv_t_t_short_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@@ -3671,7 +3705,7 @@ void kernel_mul_mv_c4_impl(
|
||||
uint3 tgpig,
|
||||
ushort tiisg) {
|
||||
const int r0 = tgpig.x*32 + tiisg;
|
||||
const int rb = tgpig.y*N_MV_T_T;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
if (r0 >= args.ne01) {
|
||||
@@ -3683,33 +3717,32 @@ void kernel_mul_mv_c4_impl(
|
||||
|
||||
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
|
||||
device const T04 * x = (device const T04 *) (src0 + offset0);
|
||||
device const T0 * x = (device const T0 *) (src0 + offset0);
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
|
||||
|
||||
for (int row = 0; row < N_MV_T_T; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= args.ne11) {
|
||||
break;
|
||||
}
|
||||
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
device const T1 * y = (device const T1 *) (src1 + offset1);
|
||||
|
||||
device const T14 * y = (device const T14 *) (src1 + offset1);
|
||||
float res = 0.0f;
|
||||
|
||||
dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
|
||||
for (int i = 0; i < args.ne00; ++i) {
|
||||
res += (float) x[i] * (float) y[i];
|
||||
}
|
||||
|
||||
dst_f32[(uint64_t)r1*args.ne0 + r0] = res;
|
||||
}
|
||||
|
||||
template<typename T04, typename T14>
|
||||
kernel void kernel_mul_mv_c4(
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_mul_mv_t_t_short(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]]) {
|
||||
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
|
||||
kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(
|
||||
args,
|
||||
src0,
|
||||
src1,
|
||||
@@ -3718,14 +3751,14 @@ kernel void kernel_mul_mv_c4(
|
||||
tiisg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
|
||||
typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, half4>;
|
||||
template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, float>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, half>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
|
||||
template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
|
||||
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
|
||||
#endif
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
@@ -8458,7 +8491,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_m
|
||||
// matrix-vector multiplication
|
||||
//
|
||||
|
||||
typedef void (kernel_mul_mv_impl_t)(
|
||||
typedef void (kernel_mul_mv_disp_t)(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@@ -8466,7 +8499,7 @@ typedef void (kernel_mul_mv_impl_t)(
|
||||
uint3 tgpig,
|
||||
ushort tiisg);
|
||||
|
||||
typedef void (kernel_mul_mv2_impl_t)(
|
||||
typedef void (kernel_mul_mv2_disp_t)(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@@ -8476,7 +8509,7 @@ typedef void (kernel_mul_mv2_impl_t)(
|
||||
ushort tiisg,
|
||||
ushort sgitg);
|
||||
|
||||
template<kernel_mul_mv_impl_t impl_fn>
|
||||
template<kernel_mul_mv_disp_t disp_fn>
|
||||
void mmv_fn(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const char * src0,
|
||||
@@ -8487,10 +8520,10 @@ void mmv_fn(
|
||||
ushort tiitg,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
impl_fn(args, src0, src1, dst, tgpig, tiisg);
|
||||
disp_fn(args, src0, src1, dst, tgpig, tiisg);
|
||||
}
|
||||
|
||||
template<kernel_mul_mv2_impl_t impl_fn>
|
||||
template<kernel_mul_mv2_disp_t disp_fn>
|
||||
void mmv_fn(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const char * src0,
|
||||
@@ -8501,12 +8534,12 @@ void mmv_fn(
|
||||
ushort tiitg,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;
|
||||
|
||||
template<mul_mv_impl_fn_t impl_fn>
|
||||
template<mul_mv_disp_fn_t disp_fn>
|
||||
kernel void kernel_mul_mv_id(
|
||||
constant ggml_metal_kargs_mul_mv_id & args,
|
||||
device const char * src0s,
|
||||
@@ -8553,11 +8586,12 @@ kernel void kernel_mul_mv_id(
|
||||
/*.nb13 =*/ args.nb12, // ne12 == 1
|
||||
/*.ne0 =*/ args.ne0,
|
||||
/*.ne1 =*/ 1, // args.ne1,
|
||||
/*.nr0 =*/ args.nr0,
|
||||
/*.r2 =*/ 1,
|
||||
/*.r3 =*/ 1,
|
||||
};
|
||||
|
||||
impl_fn(
|
||||
disp_fn(
|
||||
args0,
|
||||
/* src0 */ src0_cur,
|
||||
/* src1 */ src1_cur,
|
||||
@@ -8569,19 +8603,19 @@ kernel void kernel_mul_mv_id(
|
||||
sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;
|
||||
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half, float>>>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
|
||||
#endif
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half, half4, float, float4>>>;
|
||||
#if defined(GGML_METAL_HAS_BF16)
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
|
||||
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
|
||||
#endif
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
||||
|
||||
@@ -1 +1 @@
|
||||
83a15e113b130337a892fb6575c337754557d56f
|
||||
72632094336524a9c809e129e8b1c52154543a5a
|
||||
|
||||
@@ -11751,6 +11751,7 @@ struct llm_graph_context_mamba : public llm_graph_context {
|
||||
// TODO: skip computing output earlier for unused tokens
|
||||
|
||||
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
|
||||
cb(y, "mamba2_y_add_d", il);
|
||||
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
|
||||
|
||||
// grouped RMS norm
|
||||
@@ -14705,6 +14706,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
ggml_build_forward_expand(gf, inpL);
|
||||
|
||||
auto * inp = build_inp_mem_hybrid();
|
||||
|
||||
@@ -14736,7 +14738,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba {
|
||||
|
||||
// add residual
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "block_out", il);
|
||||
cb(cur, "nemotron_h_block_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
|
||||
@@ -126,52 +126,35 @@ int main(void) {
|
||||
assert(params.cpuparams.n_threads == 1010);
|
||||
#endif // _WIN32
|
||||
|
||||
if (common_has_curl()) {
|
||||
printf("test-arg-parser: test curl-related functions\n\n");
|
||||
const char * GOOD_URL = "https://ggml.ai/";
|
||||
const char * BAD_URL = "https://www.google.com/404";
|
||||
const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
|
||||
printf("test-arg-parser: test curl-related functions\n\n");
|
||||
const char * GOOD_URL = "http://ggml.ai/";
|
||||
const char * BAD_URL = "http://ggml.ai/404";
|
||||
|
||||
{
|
||||
printf("test-arg-parser: test good URL\n\n");
|
||||
auto res = common_remote_get_content(GOOD_URL, {});
|
||||
assert(res.first == 200);
|
||||
assert(res.second.size() > 0);
|
||||
std::string str(res.second.data(), res.second.size());
|
||||
assert(str.find("llama.cpp") != std::string::npos);
|
||||
}
|
||||
{
|
||||
printf("test-arg-parser: test good URL\n\n");
|
||||
auto res = common_remote_get_content(GOOD_URL, {});
|
||||
assert(res.first == 200);
|
||||
assert(res.second.size() > 0);
|
||||
std::string str(res.second.data(), res.second.size());
|
||||
assert(str.find("llama.cpp") != std::string::npos);
|
||||
}
|
||||
|
||||
{
|
||||
printf("test-arg-parser: test bad URL\n\n");
|
||||
auto res = common_remote_get_content(BAD_URL, {});
|
||||
assert(res.first == 404);
|
||||
}
|
||||
{
|
||||
printf("test-arg-parser: test bad URL\n\n");
|
||||
auto res = common_remote_get_content(BAD_URL, {});
|
||||
assert(res.first == 404);
|
||||
}
|
||||
|
||||
{
|
||||
printf("test-arg-parser: test max size error\n");
|
||||
common_remote_params params;
|
||||
params.max_size = 1;
|
||||
try {
|
||||
common_remote_get_content(GOOD_URL, params);
|
||||
assert(false && "it should throw an error");
|
||||
} catch (std::exception & e) {
|
||||
printf(" expected error: %s\n\n", e.what());
|
||||
}
|
||||
{
|
||||
printf("test-arg-parser: test max size error\n");
|
||||
common_remote_params params;
|
||||
params.max_size = 1;
|
||||
try {
|
||||
common_remote_get_content(GOOD_URL, params);
|
||||
assert(false && "it should throw an error");
|
||||
} catch (std::exception & e) {
|
||||
printf(" expected error: %s\n\n", e.what());
|
||||
}
|
||||
|
||||
{
|
||||
printf("test-arg-parser: test timeout error\n");
|
||||
common_remote_params params;
|
||||
params.timeout = 1;
|
||||
try {
|
||||
common_remote_get_content(BIG_FILE, params);
|
||||
assert(false && "it should throw an error");
|
||||
} catch (std::exception & e) {
|
||||
printf(" expected error: %s\n\n", e.what());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
printf("test-arg-parser: no curl, skipping curl-related functions\n");
|
||||
}
|
||||
|
||||
printf("test-arg-parser: all tests OK\n\n");
|
||||
|
||||
@@ -2140,6 +2140,27 @@ struct test_set_rows : public test_case {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL ||
|
||||
type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q8_0) {
|
||||
// estimate what the max nmse error would be if one quantized value is
|
||||
// off by one. The test values are distributed in [-1,1], so it'll be
|
||||
// roughly (2.0 / 2^bits)^2, divided by the mean square value of the reference,
|
||||
// which is roughly 0.25 times the number of elements.
|
||||
double err_estimate = 1.0f/8.0f;
|
||||
if (type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
|
||||
err_estimate /= 2.0f;
|
||||
}
|
||||
if (type == GGML_TYPE_Q8_0) {
|
||||
err_estimate /= 8.0f;
|
||||
}
|
||||
err_estimate *= err_estimate;
|
||||
err_estimate /= 0.25f*float(ne[0] * r * ne[2]*nr23[0] * ne[3]*nr23[1]);
|
||||
return err_estimate;
|
||||
}
|
||||
return 1e-7;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_ARGMAX
|
||||
@@ -2430,6 +2451,30 @@ struct test_cpy : public test_case {
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
if (type_src == type_dst) {
|
||||
return 0.0;
|
||||
}
|
||||
if (type_dst == GGML_TYPE_Q4_0 || type_dst == GGML_TYPE_Q4_1 || type_dst == GGML_TYPE_IQ4_NL ||
|
||||
type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1 || type_dst == GGML_TYPE_Q8_0) {
|
||||
// estimate what the max nmse error would be if one quantized value is
|
||||
// off by one. The test values are distributed in [-150,150], so it'll be
|
||||
// roughly (150*2.0 / 2^bits)^2, divided by the mean square value of the reference,
|
||||
// which is roughly 0.25*150^2 times the number of elements.
|
||||
double err_estimate = 1.0f/8.0f * 150.0f;
|
||||
if (type_dst == GGML_TYPE_IQ4_NL) {
|
||||
// iq4_nl values are a bit more spread out
|
||||
err_estimate *= 2.0f;
|
||||
}
|
||||
if (type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1) {
|
||||
err_estimate /= 2.0f;
|
||||
}
|
||||
if (type_dst == GGML_TYPE_Q8_0) {
|
||||
err_estimate /= 8.0f;
|
||||
}
|
||||
err_estimate *= err_estimate;
|
||||
err_estimate /= (150.0f*150.0f*0.25f)*float(ne[0] * ne[1] * ne[2] * ne[3]);
|
||||
return err_estimate;
|
||||
}
|
||||
return 1e-6;
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
@@ -1,7 +1,8 @@
|
||||
/**
|
||||
* Parses thinking content from a message that may contain <think> tags
|
||||
* Parses thinking content from a message that may contain <think> tags or [THINK] tags
|
||||
* Returns an object with thinking content and cleaned message content
|
||||
* Handles both complete <think>...</think> blocks and incomplete <think> blocks (streaming)
|
||||
* Handles both complete blocks and incomplete blocks (streaming)
|
||||
* Supports formats: <think>...</think> and [THINK]...[/THINK]
|
||||
* @param content - The message content to parse
|
||||
* @returns An object containing the extracted thinking content and the cleaned message content
|
||||
*/
|
||||
@@ -9,12 +10,11 @@ export function parseThinkingContent(content: string): {
|
||||
thinking: string | null;
|
||||
cleanContent: string;
|
||||
} {
|
||||
const incompleteMatch = content.includes('<think>') && !content.includes('</think>');
|
||||
const incompleteThinkMatch = content.includes('<think>') && !content.includes('</think>');
|
||||
const incompleteThinkBracketMatch = content.includes('[THINK]') && !content.includes('[/THINK]');
|
||||
|
||||
if (incompleteMatch) {
|
||||
// Remove the entire <think>... part from clean content
|
||||
if (incompleteThinkMatch) {
|
||||
const cleanContent = content.split('</think>')?.[1]?.trim();
|
||||
// Extract everything after <think> as thinking content
|
||||
const thinkingContent = content.split('<think>')?.[1]?.trim();
|
||||
|
||||
return {
|
||||
@@ -23,12 +23,40 @@ export function parseThinkingContent(content: string): {
|
||||
};
|
||||
}
|
||||
|
||||
const completeMatch = content.includes('</think>');
|
||||
if (incompleteThinkBracketMatch) {
|
||||
const cleanContent = content.split('[/THINK]')?.[1]?.trim();
|
||||
const thinkingContent = content.split('[THINK]')?.[1]?.trim();
|
||||
|
||||
if (completeMatch) {
|
||||
return {
|
||||
thinking: content.split('</think>')?.[0]?.trim(),
|
||||
cleanContent: content.split('</think>')?.[1]?.trim()
|
||||
cleanContent,
|
||||
thinking: thinkingContent
|
||||
};
|
||||
}
|
||||
|
||||
const completeThinkMatch = content.match(/<think>([\s\S]*?)<\/think>/);
|
||||
const completeThinkBracketMatch = content.match(/\[THINK\]([\s\S]*?)\[\/THINK\]/);
|
||||
|
||||
if (completeThinkMatch) {
|
||||
const thinkingContent = completeThinkMatch[1]?.trim() ?? '';
|
||||
const cleanContent = `${content.slice(0, completeThinkMatch.index ?? 0)}${content.slice(
|
||||
(completeThinkMatch.index ?? 0) + completeThinkMatch[0].length
|
||||
)}`.trim();
|
||||
|
||||
return {
|
||||
thinking: thinkingContent,
|
||||
cleanContent
|
||||
};
|
||||
}
|
||||
|
||||
if (completeThinkBracketMatch) {
|
||||
const thinkingContent = completeThinkBracketMatch[1]?.trim() ?? '';
|
||||
const cleanContent = `${content.slice(0, completeThinkBracketMatch.index ?? 0)}${content.slice(
|
||||
(completeThinkBracketMatch.index ?? 0) + completeThinkBracketMatch[0].length
|
||||
)}`.trim();
|
||||
|
||||
return {
|
||||
thinking: thinkingContent,
|
||||
cleanContent
|
||||
};
|
||||
}
|
||||
|
||||
@@ -39,26 +67,33 @@ export function parseThinkingContent(content: string): {
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if content contains an opening <think> tag (for streaming)
|
||||
* Checks if content contains an opening thinking tag (for streaming)
|
||||
* Supports both <think> and [THINK] formats
|
||||
* @param content - The message content to check
|
||||
* @returns True if the content contains an opening <think> tag
|
||||
* @returns True if the content contains an opening thinking tag
|
||||
*/
|
||||
export function hasThinkingStart(content: string): boolean {
|
||||
return content.includes('<think>') || content.includes('<|channel|>analysis');
|
||||
return (
|
||||
content.includes('<think>') ||
|
||||
content.includes('[THINK]') ||
|
||||
content.includes('<|channel|>analysis')
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if content contains a closing </think> tag (for streaming)
|
||||
* Checks if content contains a closing thinking tag (for streaming)
|
||||
* Supports both </think> and [/THINK] formats
|
||||
* @param content - The message content to check
|
||||
* @returns True if the content contains a closing </think> tag
|
||||
* @returns True if the content contains a closing thinking tag
|
||||
*/
|
||||
export function hasThinkingEnd(content: string): boolean {
|
||||
return content.includes('</think>');
|
||||
return content.includes('</think>') || content.includes('[/THINK]');
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts partial thinking content during streaming
|
||||
* Used when we have <think> but not yet </think>
|
||||
* Supports both <think> and [THINK] formats
|
||||
* Used when we have opening tag but not yet closing tag
|
||||
* @param content - The message content to extract partial thinking from
|
||||
* @returns An object containing the extracted partial thinking content and the remaining content
|
||||
*/
|
||||
@@ -66,23 +101,41 @@ export function extractPartialThinking(content: string): {
|
||||
thinking: string | null;
|
||||
remainingContent: string;
|
||||
} {
|
||||
const startIndex = content.indexOf('<think>');
|
||||
if (startIndex === -1) {
|
||||
const thinkStartIndex = content.indexOf('<think>');
|
||||
const thinkEndIndex = content.indexOf('</think>');
|
||||
|
||||
const bracketStartIndex = content.indexOf('[THINK]');
|
||||
const bracketEndIndex = content.indexOf('[/THINK]');
|
||||
|
||||
const useThinkFormat =
|
||||
thinkStartIndex !== -1 && (bracketStartIndex === -1 || thinkStartIndex < bracketStartIndex);
|
||||
const useBracketFormat =
|
||||
bracketStartIndex !== -1 && (thinkStartIndex === -1 || bracketStartIndex < thinkStartIndex);
|
||||
|
||||
if (useThinkFormat) {
|
||||
if (thinkEndIndex === -1) {
|
||||
const thinkingStart = thinkStartIndex + '<think>'.length;
|
||||
|
||||
return {
|
||||
thinking: content.substring(thinkingStart),
|
||||
remainingContent: content.substring(0, thinkStartIndex)
|
||||
};
|
||||
}
|
||||
} else if (useBracketFormat) {
|
||||
if (bracketEndIndex === -1) {
|
||||
const thinkingStart = bracketStartIndex + '[THINK]'.length;
|
||||
|
||||
return {
|
||||
thinking: content.substring(thinkingStart),
|
||||
remainingContent: content.substring(0, bracketStartIndex)
|
||||
};
|
||||
}
|
||||
} else {
|
||||
return { thinking: null, remainingContent: content };
|
||||
}
|
||||
|
||||
const endIndex = content.indexOf('</think>');
|
||||
if (endIndex === -1) {
|
||||
// Still streaming thinking content
|
||||
const thinkingStart = startIndex + '<think>'.length;
|
||||
return {
|
||||
thinking: content.substring(thinkingStart),
|
||||
remainingContent: content.substring(0, startIndex)
|
||||
};
|
||||
}
|
||||
|
||||
// Complete thinking block found
|
||||
const parsed = parseThinkingContent(content);
|
||||
|
||||
return {
|
||||
thinking: parsed.thinking,
|
||||
remainingContent: parsed.cleanContent
|
||||
|
||||
@@ -59,6 +59,60 @@
|
||||
thinking: '',
|
||||
children: []
|
||||
});
|
||||
|
||||
// Message with <think> format thinking content
|
||||
const thinkTagMessage: DatabaseMessage = {
|
||||
id: '6',
|
||||
convId: 'conv-1',
|
||||
type: 'message',
|
||||
timestamp: Date.now() - 1000 * 60 * 2,
|
||||
role: 'assistant',
|
||||
content:
|
||||
"<think>\nLet me analyze this step by step:\n\n1. The user is asking about thinking formats\n2. I need to demonstrate the <think> tag format\n3. This content should be displayed in the thinking section\n4. The main response should be separate\n\nThis is a good example of reasoning content.\n</think>\n\nHere's my response after thinking through the problem. The thinking content above should be displayed separately from this main response content.",
|
||||
parent: '1',
|
||||
thinking: '',
|
||||
children: []
|
||||
};
|
||||
|
||||
// Message with [THINK] format thinking content
|
||||
const thinkBracketMessage: DatabaseMessage = {
|
||||
id: '7',
|
||||
convId: 'conv-1',
|
||||
type: 'message',
|
||||
timestamp: Date.now() - 1000 * 60 * 1,
|
||||
role: 'assistant',
|
||||
content:
|
||||
'[THINK]\nThis is the DeepSeek-style thinking format:\n\n- Using square brackets instead of angle brackets\n- Should work identically to the <think> format\n- Content parsing should extract this reasoning\n- Display should be the same as <think> format\n\nBoth formats should be supported seamlessly.\n[/THINK]\n\nThis is the main response content that comes after the [THINK] block. The reasoning above should be parsed and displayed in the thinking section.',
|
||||
parent: '1',
|
||||
thinking: '',
|
||||
children: []
|
||||
};
|
||||
|
||||
// Streaming message for <think> format
|
||||
let streamingThinkMessage = $state({
|
||||
id: '8',
|
||||
convId: 'conv-1',
|
||||
type: 'message',
|
||||
timestamp: 0, // No timestamp = streaming
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
parent: '1',
|
||||
thinking: '',
|
||||
children: []
|
||||
});
|
||||
|
||||
// Streaming message for [THINK] format
|
||||
let streamingBracketMessage = $state({
|
||||
id: '9',
|
||||
convId: 'conv-1',
|
||||
type: 'message',
|
||||
timestamp: 0, // No timestamp = streaming
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
parent: '1',
|
||||
thinking: '',
|
||||
children: []
|
||||
});
|
||||
</script>
|
||||
|
||||
<Story
|
||||
@@ -144,3 +198,115 @@
|
||||
await new Promise(resolve => setTimeout(resolve, 100));
|
||||
}}
|
||||
/>
|
||||
|
||||
<Story
|
||||
name="ThinkTagFormat"
|
||||
args={{
|
||||
class: 'max-w-[56rem] w-[calc(100vw-2rem)]',
|
||||
message: thinkTagMessage
|
||||
}}
|
||||
/>
|
||||
|
||||
<Story
|
||||
name="ThinkBracketFormat"
|
||||
args={{
|
||||
class: 'max-w-[56rem] w-[calc(100vw-2rem)]',
|
||||
message: thinkBracketMessage
|
||||
}}
|
||||
/>
|
||||
|
||||
<Story
|
||||
name="StreamingThinkTag"
|
||||
args={{
|
||||
message: streamingThinkMessage
|
||||
}}
|
||||
parameters={{
|
||||
test: {
|
||||
timeout: 30000
|
||||
}
|
||||
}}
|
||||
asChild
|
||||
play={async () => {
|
||||
// Phase 1: Stream <think> reasoning content
|
||||
const thinkingContent =
|
||||
'Let me work through this problem systematically:\n\n1. First, I need to understand what the user is asking\n2. Then I should consider different approaches\n3. I need to evaluate the pros and cons\n4. Finally, I should provide a clear recommendation\n\nThis step-by-step approach will ensure accuracy.';
|
||||
|
||||
let currentContent = '<think>\n';
|
||||
streamingThinkMessage.content = currentContent;
|
||||
|
||||
for (let i = 0; i < thinkingContent.length; i++) {
|
||||
currentContent += thinkingContent[i];
|
||||
streamingThinkMessage.content = currentContent;
|
||||
await new Promise((resolve) => setTimeout(resolve, 5));
|
||||
}
|
||||
|
||||
// Close the thinking block
|
||||
currentContent += '\n</think>\n\n';
|
||||
streamingThinkMessage.content = currentContent;
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
|
||||
// Phase 2: Stream main response content
|
||||
const responseContent =
|
||||
"Based on my analysis above, here's the solution:\n\n**Key Points:**\n- The approach should be systematic\n- We need to consider all factors\n- Implementation should be step-by-step\n\nThis ensures the best possible outcome.";
|
||||
|
||||
for (let i = 0; i < responseContent.length; i++) {
|
||||
currentContent += responseContent[i];
|
||||
streamingThinkMessage.content = currentContent;
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
}
|
||||
|
||||
streamingThinkMessage.timestamp = Date.now();
|
||||
}}
|
||||
>
|
||||
<div class="w-[56rem]">
|
||||
<ChatMessage message={streamingThinkMessage} />
|
||||
</div>
|
||||
</Story>
|
||||
|
||||
<Story
|
||||
name="StreamingThinkBracket"
|
||||
args={{
|
||||
message: streamingBracketMessage
|
||||
}}
|
||||
parameters={{
|
||||
test: {
|
||||
timeout: 30000
|
||||
}
|
||||
}}
|
||||
asChild
|
||||
play={async () => {
|
||||
// Phase 1: Stream [THINK] reasoning content
|
||||
const thinkingContent =
|
||||
'Using the DeepSeek format now:\n\n- This demonstrates the [THINK] bracket format\n- Should parse identically to <think> tags\n- The UI should display this in the thinking section\n- Main content should be separate\n\nBoth formats provide the same functionality.';
|
||||
|
||||
let currentContent = '[THINK]\n';
|
||||
streamingBracketMessage.content = currentContent;
|
||||
|
||||
for (let i = 0; i < thinkingContent.length; i++) {
|
||||
currentContent += thinkingContent[i];
|
||||
streamingBracketMessage.content = currentContent;
|
||||
await new Promise((resolve) => setTimeout(resolve, 5));
|
||||
}
|
||||
|
||||
// Close the thinking block
|
||||
currentContent += '\n[/THINK]\n\n';
|
||||
streamingBracketMessage.content = currentContent;
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
|
||||
// Phase 2: Stream main response content
|
||||
const responseContent =
|
||||
"Here's my response after using the [THINK] format:\n\n**Observations:**\n- Both <think> and [THINK] formats work seamlessly\n- The parsing logic handles both cases\n- UI display is consistent across formats\n\nThis demonstrates the enhanced thinking content support.";
|
||||
|
||||
for (let i = 0; i < responseContent.length; i++) {
|
||||
currentContent += responseContent[i];
|
||||
streamingBracketMessage.content = currentContent;
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
}
|
||||
|
||||
streamingBracketMessage.timestamp = Date.now();
|
||||
}}
|
||||
>
|
||||
<div class="w-[56rem]">
|
||||
<ChatMessage message={streamingBracketMessage} />
|
||||
</div>
|
||||
</Story>
|
||||
|
||||
Reference in New Issue
Block a user