mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-18 06:54:06 +00:00
Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05fa625eac | ||
|
|
d612901116 | ||
|
|
cceb1b4e33 | ||
|
|
d23a55997d | ||
|
|
5f28c53d11 | ||
|
|
4408494144 | ||
|
|
2ba9adc093 | ||
|
|
cc45f2ada6 | ||
|
|
d5dfc33027 | ||
|
|
267ba5a1d9 | ||
|
|
ff4affb4c1 | ||
|
|
55d58599c8 | ||
|
|
1a8c700bfd | ||
|
|
27b93cbd15 | ||
|
|
6e67fd2144 | ||
|
|
9e118b97c4 | ||
|
|
57088276d4 | ||
|
|
341bc7d23c | ||
|
|
08e6d914b8 | ||
|
|
184c694f45 | ||
|
|
684b36101c | ||
|
|
3a00c98584 | ||
|
|
079feab9e3 | ||
|
|
01d8eaa28d | ||
|
|
1725e316c1 | ||
|
|
b7742cf321 | ||
|
|
badba89320 | ||
|
|
baa12f3831 | ||
|
|
2d8015e8a4 | ||
|
|
eb145c0753 | ||
|
|
6e473fb384 | ||
|
|
c7db95f106 | ||
|
|
0d00ef65ed | ||
|
|
91ea5d67f2 | ||
|
|
dbb023336b | ||
|
|
53aef25a88 | ||
|
|
2dec548094 | ||
|
|
0ccbfdef3e | ||
|
|
94a602db66 | ||
|
|
05a6f0e894 | ||
|
|
b48e80f677 |
@@ -41,7 +41,7 @@ body:
|
||||
attributes:
|
||||
label: GGML backends
|
||||
description: Which GGML backends do you know to be affected?
|
||||
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
|
||||
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
|
||||
multiple: true
|
||||
validations:
|
||||
required: true
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
2
.github/ISSUE_TEMPLATE/011-bug-results.yml
vendored
@@ -42,7 +42,7 @@ body:
|
||||
attributes:
|
||||
label: GGML backends
|
||||
description: Which GGML backends do you know to be affected?
|
||||
options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN]
|
||||
options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN]
|
||||
multiple: true
|
||||
validations:
|
||||
required: true
|
||||
|
||||
@@ -112,15 +112,9 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_
|
||||
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
|
||||
|
||||
# 3rd party libs
|
||||
option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON)
|
||||
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON)
|
||||
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
|
||||
|
||||
# deprecated
|
||||
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
|
||||
if (LLAMA_CURL)
|
||||
message(WARNING "LLAMA_CURL option is deprecated and will be ignored")
|
||||
endif()
|
||||
|
||||
# Required for relocatable CMake package
|
||||
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
|
||||
@@ -148,10 +142,15 @@ if (NOT DEFINED GGML_CUDA_GRAPHS)
|
||||
endif()
|
||||
|
||||
# transition helpers
|
||||
function (llama_option_depr TYPE OLD NEW)
|
||||
function (llama_option_depr TYPE OLD)
|
||||
if (${OLD})
|
||||
message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n")
|
||||
set(${NEW} ON PARENT_SCOPE)
|
||||
set(NEW "${ARGV2}")
|
||||
if(NEW)
|
||||
message(${TYPE} "${OLD} is deprecated, use ${NEW} instead")
|
||||
set(${NEW} ON PARENT_SCOPE)
|
||||
else()
|
||||
message(${TYPE} "${OLD} is deprecated and will be ignored")
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
@@ -164,6 +163,7 @@ llama_option_depr(WARNING LLAMA_RPC GGML_RPC)
|
||||
llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
|
||||
llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
|
||||
llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
|
||||
llama_option_depr(WARNING LLAMA_CURL)
|
||||
|
||||
include("cmake/license.cmake")
|
||||
license_add_file("llama.cpp" "LICENSE")
|
||||
@@ -197,9 +197,7 @@ add_subdirectory(src)
|
||||
|
||||
if (LLAMA_BUILD_COMMON)
|
||||
add_subdirectory(common)
|
||||
if (LLAMA_HTTPLIB)
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
endif()
|
||||
add_subdirectory(vendor/cpp-httplib)
|
||||
endif()
|
||||
|
||||
if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION)
|
||||
|
||||
@@ -43,11 +43,6 @@ COMMON_CMAKE_ARGS=(
|
||||
-DGGML_OPENMP=${GGML_OPENMP}
|
||||
)
|
||||
|
||||
XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
|
||||
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
|
||||
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
|
||||
echo "Detected Xcode version: $XCODE_VERSION"
|
||||
|
||||
check_required_tool() {
|
||||
local tool=$1
|
||||
local install_message=$2
|
||||
@@ -60,9 +55,12 @@ check_required_tool() {
|
||||
}
|
||||
echo "Checking for required tools..."
|
||||
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
|
||||
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
|
||||
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
check_required_tool "xcrun" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
||||
|
||||
XCODE_VERSION=$(xcrun xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }')
|
||||
MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1)
|
||||
MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2)
|
||||
echo "Detected Xcode version: $XCODE_VERSION"
|
||||
|
||||
set -e
|
||||
|
||||
@@ -260,7 +258,7 @@ combine_static_libraries() {
|
||||
|
||||
# Since we have multiple architectures libtool will find object files that do not
|
||||
# match the target architecture. We suppress these warnings.
|
||||
libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
|
||||
xcrun libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
|
||||
|
||||
# Determine SDK, architectures, and install_name based on platform and simulator flag.
|
||||
local sdk=""
|
||||
@@ -333,7 +331,7 @@ combine_static_libraries() {
|
||||
|
||||
# Platform-specific post-processing for device builds
|
||||
if [[ "$is_simulator" == "false" ]]; then
|
||||
if command -v xcrun vtool &>/dev/null; then
|
||||
if xcrun -f vtool &>/dev/null; then
|
||||
case "$platform" in
|
||||
"ios")
|
||||
echo "Marking binary as a framework binary for iOS..."
|
||||
@@ -451,10 +449,9 @@ cmake -B build-visionos -G Xcode \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_SYSROOT=xros \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos --config Release -- -quiet
|
||||
@@ -467,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \
|
||||
-DCMAKE_SYSTEM_NAME=visionOS \
|
||||
-DCMAKE_OSX_SYSROOT=xrsimulator \
|
||||
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
|
||||
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_HTTPLIB=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos-sim --config Release -- -quiet
|
||||
@@ -528,7 +524,7 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
|
||||
|
||||
# Create XCFramework with correct debug symbols paths
|
||||
echo "Creating XCFramework..."
|
||||
xcodebuild -create-xcframework \
|
||||
xcrun xcodebuild -create-xcframework \
|
||||
-framework $(pwd)/build-ios-sim/framework/llama.framework \
|
||||
-debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
|
||||
-framework $(pwd)/build-ios-device/framework/llama.framework \
|
||||
|
||||
@@ -112,11 +112,7 @@ endif()
|
||||
|
||||
# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS build_info)
|
||||
|
||||
if (LLAMA_HTTPLIB)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB)
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
||||
endif()
|
||||
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib)
|
||||
|
||||
if (LLAMA_LLGUIDANCE)
|
||||
include(ExternalProject)
|
||||
|
||||
@@ -879,7 +879,8 @@ std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
|
||||
defined(__OpenBSD__) || defined(__NetBSD__)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else if (std::getenv("HOME")) {
|
||||
@@ -1223,7 +1224,7 @@ common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
return res;
|
||||
}
|
||||
|
||||
int err = llama_apply_adapter_cvec(
|
||||
int err = llama_set_adapter_cvec(
|
||||
lctx,
|
||||
cvec.data.data(),
|
||||
cvec.data.size(),
|
||||
@@ -1325,12 +1326,15 @@ std::string get_model_endpoint() {
|
||||
}
|
||||
|
||||
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
|
||||
llama_clear_adapter_lora(ctx);
|
||||
for (auto & la : lora) {
|
||||
if (la.scale != 0.0f) {
|
||||
llama_set_adapter_lora(ctx, la.ptr, la.scale);
|
||||
}
|
||||
std::vector<llama_adapter_lora *> loras;
|
||||
std::vector<float> scales;
|
||||
|
||||
for (auto & la: lora) {
|
||||
loras.push_back(la.ptr);
|
||||
scales.push_back(la.scale);
|
||||
}
|
||||
|
||||
llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
|
||||
}
|
||||
|
||||
struct llama_model_params common_model_params_to_llama(common_params & params) {
|
||||
|
||||
@@ -670,7 +670,7 @@ static std::vector<T> string_split(const std::string & str, char delim) {
|
||||
}
|
||||
|
||||
template<>
|
||||
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
||||
inline std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
||||
{
|
||||
std::vector<std::string> parts;
|
||||
size_t begin_pos = 0;
|
||||
@@ -685,7 +685,7 @@ std::vector<std::string> string_split<std::string>(const std::string & input, ch
|
||||
return parts;
|
||||
}
|
||||
|
||||
static bool string_starts_with(const std::string & str,
|
||||
inline bool string_starts_with(const std::string & str,
|
||||
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
|
||||
return str.rfind(prefix, 0) == 0;
|
||||
}
|
||||
@@ -870,11 +870,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
||||
|
||||
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
|
||||
|
||||
static std::string llm_ffn_exps_block_regex(int idx) {
|
||||
inline std::string llm_ffn_exps_block_regex(int idx) {
|
||||
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
|
||||
}
|
||||
|
||||
static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
||||
return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() };
|
||||
}
|
||||
|
||||
|
||||
@@ -19,9 +19,7 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#if defined(LLAMA_USE_HTTPLIB)
|
||||
#include "http.h"
|
||||
#endif
|
||||
|
||||
#ifndef __EMSCRIPTEN__
|
||||
#ifdef __linux__
|
||||
@@ -114,44 +112,18 @@ static void write_etag(const std::string & path, const std::string & etag) {
|
||||
}
|
||||
|
||||
static std::string read_etag(const std::string & path) {
|
||||
std::string none;
|
||||
const std::string etag_path = path + ".etag";
|
||||
|
||||
if (std::filesystem::exists(etag_path)) {
|
||||
std::ifstream etag_in(etag_path);
|
||||
if (!etag_in) {
|
||||
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
||||
return none;
|
||||
}
|
||||
std::string etag;
|
||||
std::getline(etag_in, etag);
|
||||
return etag;
|
||||
if (!std::filesystem::exists(etag_path)) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// no etag file, but maybe there is an old .json
|
||||
// remove this code later
|
||||
const std::string metadata_path = path + ".json";
|
||||
|
||||
if (std::filesystem::exists(metadata_path)) {
|
||||
std::ifstream metadata_in(metadata_path);
|
||||
try {
|
||||
nlohmann::json metadata_json;
|
||||
metadata_in >> metadata_json;
|
||||
LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(),
|
||||
metadata_json.dump().c_str());
|
||||
if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) {
|
||||
std::string etag = metadata_json.at("etag");
|
||||
write_etag(path, etag);
|
||||
if (!std::filesystem::remove(metadata_path)) {
|
||||
LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str());
|
||||
}
|
||||
return etag;
|
||||
}
|
||||
} catch (const nlohmann::json::exception & e) {
|
||||
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
||||
}
|
||||
std::ifstream etag_in(etag_path);
|
||||
if (!etag_in) {
|
||||
LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
||||
return {};
|
||||
}
|
||||
return none;
|
||||
std::string etag;
|
||||
std::getline(etag_in, etag);
|
||||
return etag;
|
||||
}
|
||||
|
||||
static bool is_http_status_ok(int status) {
|
||||
@@ -168,8 +140,6 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
|
||||
return {hf_repo, tag};
|
||||
}
|
||||
|
||||
#if defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
class ProgressBar {
|
||||
static inline std::mutex mutex;
|
||||
static inline std::map<const ProgressBar *, int> lines;
|
||||
@@ -347,62 +317,64 @@ static int common_download_file_single_online(const std::string & url,
|
||||
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
auto head = cli.Head(parts.path);
|
||||
bool head_ok = head && head->status >= 200 && head->status < 300;
|
||||
if (!head_ok) {
|
||||
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
|
||||
if (file_exists) {
|
||||
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
return head->status; // cannot use cached file, return raw status code
|
||||
// TODO: maybe retry only on certain codes
|
||||
}
|
||||
|
||||
std::string etag;
|
||||
if (head_ok && head->has_header("ETag")) {
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
if (head_ok && head->has_header("Content-Length")) {
|
||||
try {
|
||||
total_size = std::stoull(head->get_header_value("Content-Length"));
|
||||
} catch (const std::exception& e) {
|
||||
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool supports_ranges = false;
|
||||
if (head_ok && head->has_header("Accept-Ranges")) {
|
||||
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
|
||||
}
|
||||
|
||||
bool should_download_from_scratch = false;
|
||||
if (!last_etag.empty() && !etag.empty() && last_etag != etag) {
|
||||
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__,
|
||||
last_etag.c_str(), etag.c_str());
|
||||
should_download_from_scratch = true;
|
||||
}
|
||||
|
||||
auto head = cli.Head(parts.path);
|
||||
if (!head || head->status < 200 || head->status >= 300) {
|
||||
LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1);
|
||||
if (file_exists) {
|
||||
if (!should_download_from_scratch) {
|
||||
LOG_INF("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return -1;
|
||||
}
|
||||
LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
return head ? head->status : -1;
|
||||
}
|
||||
|
||||
std::string etag;
|
||||
if (head->has_header("ETag")) {
|
||||
etag = head->get_header_value("ETag");
|
||||
}
|
||||
|
||||
size_t total_size = 0;
|
||||
if (head->has_header("Content-Length")) {
|
||||
try {
|
||||
total_size = std::stoull(head->get_header_value("Content-Length"));
|
||||
} catch (const std::exception& e) {
|
||||
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool supports_ranges = false;
|
||||
if (head->has_header("Accept-Ranges")) {
|
||||
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
|
||||
}
|
||||
|
||||
if (file_exists) {
|
||||
if (etag.empty()) {
|
||||
LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
if (!last_etag.empty() && last_etag == etag) {
|
||||
LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
int delay = retry_delay_seconds;
|
||||
|
||||
for (int i = 0; i < max_attempts; ++i) {
|
||||
if (i) {
|
||||
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(delay));
|
||||
delay *= retry_delay_seconds;
|
||||
}
|
||||
|
||||
const std::string path_temporary = path + ".downloadInProgress";
|
||||
size_t existing_size = 0;
|
||||
|
||||
if (std::filesystem::exists(path_temporary)) {
|
||||
if (supports_ranges && !should_download_from_scratch) {
|
||||
if (supports_ranges) {
|
||||
existing_size = std::filesystem::file_size(path_temporary);
|
||||
} else if (remove(path_temporary.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
||||
@@ -410,32 +382,23 @@ static int common_download_file_single_online(const std::string & url,
|
||||
}
|
||||
}
|
||||
|
||||
// start the download
|
||||
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
|
||||
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
|
||||
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
|
||||
if (!was_pull_successful) {
|
||||
if (i + 1 < max_attempts) {
|
||||
const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000;
|
||||
LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay);
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
|
||||
} else {
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
LOG_INF("%s: downloading from %s to %s (etag:%s)...\n",
|
||||
__func__, common_http_show_masked_url(parts).c_str(),
|
||||
path_temporary.c_str(), etag.c_str());
|
||||
|
||||
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
}
|
||||
continue;
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
return head->status;
|
||||
}
|
||||
|
||||
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
||||
return -1;
|
||||
}
|
||||
if (!etag.empty()) {
|
||||
write_etag(path, etag);
|
||||
}
|
||||
|
||||
return head->status; // TODO: use actual GET status?
|
||||
}
|
||||
|
||||
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
||||
return -1; // max attempts reached
|
||||
}
|
||||
|
||||
@@ -801,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) {
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
std::string common_docker_resolve_model(const std::string &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
int common_download_file_single(const std::string &,
|
||||
const std::string &,
|
||||
const std::string &,
|
||||
bool,
|
||||
const common_header_list &) {
|
||||
throw std::runtime_error("download functionality is not enabled in this build");
|
||||
}
|
||||
|
||||
#endif // defined(LLAMA_USE_HTTPLIB)
|
||||
|
||||
std::vector<common_cached_model_info> common_list_cached_models() {
|
||||
std::vector<common_cached_model_info> models;
|
||||
const std::string cache_dir = fs_get_cache_directory();
|
||||
|
||||
@@ -570,6 +570,7 @@ class ModelBase:
|
||||
self.match_model_tensor_name(new_name, key, bid)
|
||||
for key in (
|
||||
gguf.MODEL_TENSOR.FFN_GATE_INP,
|
||||
gguf.MODEL_TENSOR.FFN_GATE_INP_SHEXP,
|
||||
gguf.MODEL_TENSOR.POS_EMBD,
|
||||
gguf.MODEL_TENSOR.TOKEN_TYPES,
|
||||
gguf.MODEL_TENSOR.SSM_CONV1D,
|
||||
@@ -1048,6 +1049,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902":
|
||||
# ref: https://huggingface.co/zai-org/GLM-4.5-Air
|
||||
res = "glm4"
|
||||
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
|
||||
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
|
||||
res = "glm4"
|
||||
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
|
||||
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
|
||||
res = "minerva-7b"
|
||||
@@ -1081,9 +1085,6 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
|
||||
# ref: https://huggingface.co/aari1995/German_Semantic_V3
|
||||
res = "jina-v2-de"
|
||||
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
|
||||
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
|
||||
res = "glm4"
|
||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
||||
res = "llama-bpe"
|
||||
@@ -1123,6 +1124,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
|
||||
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
|
||||
res = "command-r"
|
||||
if chkhsh == "d772b220ace2baec124bed8cfafce0ead7d6c38a4b65ef11261cf9d5d62246d1":
|
||||
# ref: https://huggingface.co/CohereLabs/tiny-aya-base
|
||||
res = "tiny_aya"
|
||||
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
|
||||
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
|
||||
res = "qwen2"
|
||||
@@ -1264,6 +1268,9 @@ class TextModel(ModelBase):
|
||||
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
|
||||
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
|
||||
res = "qwen35"
|
||||
if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d":
|
||||
# ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash
|
||||
res = "joyai-llm"
|
||||
|
||||
if res is None:
|
||||
logger.warning("\n")
|
||||
@@ -2725,8 +2732,6 @@ class AfmoeModel(LlamaModel):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# MoE parameters
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None:
|
||||
self.gguf_writer.add_expert_shared_count(n_shared_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
@@ -2748,7 +2753,7 @@ class AfmoeModel(LlamaModel):
|
||||
# Handle expert weights - they're already merged in the HF format
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -4073,6 +4078,87 @@ class InternVisionModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"NemotronH_Nano_VL_V2",
|
||||
"RADIOModel",
|
||||
)
|
||||
class NemotronNanoV2VLModel(MmprojModel):
|
||||
# ViT-Huge architecture parameters for RADIO v2.5-h
|
||||
_vit_hidden_size = 1280
|
||||
_vit_intermediate_size = 5120
|
||||
_vit_num_layers = 32
|
||||
_vit_num_heads = 16
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
# RADIO config doesn't have standard ViT parameters, so they need to be constructed manually
|
||||
vision_config = self.global_config.get("vision_config")
|
||||
if vision_config is None:
|
||||
return None
|
||||
# Add ViT-H parameters
|
||||
vision_config = {
|
||||
**vision_config,
|
||||
"hidden_size": self._vit_hidden_size,
|
||||
"intermediate_size": self._vit_intermediate_size,
|
||||
"num_hidden_layers": self._vit_num_layers,
|
||||
"num_attention_heads": self._vit_num_heads,
|
||||
"image_size": self.global_config.get("force_image_size", 512),
|
||||
}
|
||||
return vision_config
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if "image_mean" not in self.preprocessor_config:
|
||||
self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406]
|
||||
if "image_std" not in self.preprocessor_config:
|
||||
self.preprocessor_config["image_std"] = [0.229, 0.224, 0.225]
|
||||
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.global_config
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL)
|
||||
self.gguf_writer.add_vision_attention_layernorm_eps(1e-6)
|
||||
self.gguf_writer.add_vision_use_gelu(True)
|
||||
downsample_ratio = hparams.get("downsample_ratio", 0.5)
|
||||
self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
|
||||
|
||||
def tensor_force_quant(self, name, new_name, bid, n_dims):
|
||||
if ".position_embd." in new_name or "pos_embed" in new_name:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
return super().tensor_force_quant(name, new_name, bid, n_dims)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "input_conditioner" in name:
|
||||
return
|
||||
|
||||
# RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it
|
||||
if "patch_generator.pos_embed" in name:
|
||||
if not name.endswith(".weight"):
|
||||
name += ".weight"
|
||||
# Downsample position embeddings for fixed 512x512 image size
|
||||
import torch.nn.functional as F
|
||||
n_embd = self.hparams["hidden_size"]
|
||||
image_size = self.global_config.get("force_image_size", 512)
|
||||
patch_size = self.hparams["patch_size"]
|
||||
target_patches_per_side = image_size // patch_size # 32
|
||||
max_patches_per_side = int((data_torch.shape[1]) ** 0.5) # 128
|
||||
if target_patches_per_side != max_patches_per_side:
|
||||
# Reshape to grid, interpolate, flatten back
|
||||
data_torch = data_torch.reshape(1, max_patches_per_side, max_patches_per_side, n_embd)
|
||||
data_torch = data_torch.permute(0, 3, 1, 2).float() # [1, n_embd, 128, 128]
|
||||
data_torch = F.interpolate(data_torch, size=(target_patches_per_side, target_patches_per_side),
|
||||
mode='bilinear', align_corners=True)
|
||||
data_torch = data_torch.permute(0, 2, 3, 1) # [1, 32, 32, n_embd]
|
||||
data_torch = data_torch.reshape(1, target_patches_per_side * target_patches_per_side, n_embd)
|
||||
|
||||
# Reshape linear patch embedding to conv2d format for ggml_conv_2d
|
||||
# From [n_embd, patch_size*patch_size*3] to [n_embd, 3, patch_size, patch_size]
|
||||
if "patch_generator.embedder" in name:
|
||||
patch_size = self.hparams["patch_size"]
|
||||
n_embd = self.hparams["hidden_size"]
|
||||
data_torch = data_torch.reshape(n_embd, 3, patch_size, patch_size)
|
||||
|
||||
if name.startswith("vision_model.radio_model.model.") or name.startswith("mlp1."):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("WavTokenizerDec")
|
||||
class WavTokenizerDecModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
|
||||
@@ -4115,8 +4201,6 @@ class Qwen2MoeModel(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
|
||||
@@ -4161,7 +4245,7 @@ class Qwen2MoeModel(TextModel):
|
||||
return
|
||||
|
||||
if name.find("experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -4912,13 +4996,13 @@ class PhiMoeModel(Phi3MiniModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
|
||||
self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
|
||||
self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("block_sparse_moe.experts") != -1:
|
||||
n_experts = self.hparams["num_local_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -5330,7 +5414,7 @@ class KimiLinearModel(TextModel):
|
||||
|
||||
# process the experts separately
|
||||
if name.find("block_sparse_moe.experts") != -1:
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False)
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -5925,12 +6009,13 @@ class NomicBertModel(BertModel):
|
||||
if "mlp.experts.bias" in name:
|
||||
return # Explicitly return.
|
||||
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
if "mlp.experts.mlp.w1" in name:
|
||||
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
|
||||
data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
|
||||
name += ".weight"
|
||||
|
||||
if "mlp.experts.mlp.w2" in name:
|
||||
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
|
||||
data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"])
|
||||
data_torch = data_torch.transpose(1, 2)
|
||||
name += ".weight"
|
||||
|
||||
@@ -5940,7 +6025,6 @@ class NomicBertModel(BertModel):
|
||||
super().set_gguf_parameters()
|
||||
if self.is_moe:
|
||||
self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
|
||||
|
||||
def _is_tokenizer_xlmroberta(self) -> bool:
|
||||
@@ -7054,6 +7138,8 @@ class Mamba2Model(TextModel):
|
||||
if hparams is None:
|
||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
hparams = json.load(f)
|
||||
if "llm_config" in hparams:
|
||||
hparams["text_config"] = hparams["llm_config"]
|
||||
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
|
||||
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
|
||||
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
|
||||
@@ -7175,8 +7261,8 @@ class JambaModel(TextModel):
|
||||
self.gguf_writer.add_ssm_state_size(d_state)
|
||||
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
|
||||
self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"]))
|
||||
self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"]))
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
@@ -7194,7 +7280,7 @@ class JambaModel(TextModel):
|
||||
|
||||
# process the experts separately
|
||||
if ".feed_forward.experts." in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
|
||||
assert bid is not None
|
||||
|
||||
@@ -7280,6 +7366,17 @@ class Cohere2Model(TextModel):
|
||||
self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads)))
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Cohere2 runtime in llama.cpp expects no bias tensors;
|
||||
# the actual weight only contains 0-value tensors as bias, we can skip them
|
||||
if name.endswith(".bias"):
|
||||
if torch.any(data_torch != 0):
|
||||
raise ValueError(f"Bias tensor {name!r} is not zero.")
|
||||
logger.debug(f"Skipping bias tensor {name!r} for Cohere2 conversion.")
|
||||
return
|
||||
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("OlmoForCausalLM")
|
||||
@ModelBase.register("OLMoForCausalLM")
|
||||
@@ -7342,8 +7439,6 @@ class OlmoeModel(TextModel):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_layer_norm_rms_eps(1e-5)
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
@@ -7351,7 +7446,7 @@ class OlmoeModel(TextModel):
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -7932,10 +8027,6 @@ class MiniMaxM2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.MINIMAXM2
|
||||
_experts_cache: dict[int, dict[str, Tensor]] = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hparams["num_experts"] = self.hparams["num_local_experts"]
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
@@ -7948,7 +8039,7 @@ class MiniMaxM2Model(TextModel):
|
||||
|
||||
# merge expert weights
|
||||
if 'experts' in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
expert_cache = self._experts_cache.setdefault(bid, {})
|
||||
@@ -9153,7 +9244,6 @@ class ExaoneMoEModel(Exaone4Model):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
|
||||
moe_intermediate_size = self.hparams["moe_intermediate_size"]
|
||||
num_shared_experts = self.hparams["num_shared_experts"]
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
@@ -9194,7 +9284,7 @@ class ExaoneMoEModel(Exaone4Model):
|
||||
name = name.replace("e_score_correction_bias", "e_score_correction.bias")
|
||||
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -9345,7 +9435,7 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
||||
# case, the model architecture needs to be updated to a standard
|
||||
# "granite" or "granitemoe" model
|
||||
if not self._ssm_layers:
|
||||
has_experts = self.find_hparam(["num_experts_per_tok"], optional=True)
|
||||
has_experts = self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True)
|
||||
new_arch = (
|
||||
gguf.MODEL_ARCH.GRANITE_MOE
|
||||
if has_experts else
|
||||
@@ -9541,6 +9631,14 @@ class NemotronHModel(GraniteHybridModel):
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# Skip vision model and projector tensors for VLM models (handled by mmproj) (e.g., Nemotron Nano 12B v2 VL)
|
||||
if name.startswith(("vision_model.", "mlp1.")):
|
||||
return
|
||||
|
||||
# Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL)
|
||||
if name.startswith("language_model."):
|
||||
name = name[len("language_model."):]
|
||||
|
||||
if self.is_moe and bid is not None:
|
||||
if name.endswith("mixer.gate.e_score_correction_bias"):
|
||||
new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
|
||||
@@ -9635,7 +9733,6 @@ class BailingMoeModel(TextModel):
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_weights_scale(1.0)
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
@@ -9669,7 +9766,7 @@ class BailingMoeModel(TextModel):
|
||||
yield from super().modify_tensors(v,self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid)
|
||||
return
|
||||
elif name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -9740,7 +9837,6 @@ class BailingMoeV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
@@ -9751,7 +9847,7 @@ class BailingMoeV2Model(TextModel):
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "mlp.experts" in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -9797,8 +9893,6 @@ class GroveMoeModel(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
|
||||
@@ -9819,7 +9913,7 @@ class GroveMoeModel(TextModel):
|
||||
|
||||
# process the experts separately
|
||||
if name.find("chunk_experts") != -1:
|
||||
n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"]) // 2 # see add_experts_per_group
|
||||
assert bid is not None
|
||||
|
||||
if self._chunk_experts is None:
|
||||
@@ -9846,7 +9940,7 @@ class GroveMoeModel(TextModel):
|
||||
else:
|
||||
return
|
||||
elif name.find("experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -10239,7 +10333,6 @@ class HunYuanMoEModel(TextModel):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"])
|
||||
|
||||
moe_intermediate_size = hparams["moe_intermediate_size"]
|
||||
@@ -10282,7 +10375,7 @@ class HunYuanMoEModel(TextModel):
|
||||
return
|
||||
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -10324,16 +10417,9 @@ class LLaDAMoEModel(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
|
||||
if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
|
||||
|
||||
# number of experts used per token (top-k)
|
||||
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
|
||||
self.gguf_writer.add_mask_token_id(156895)
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self.gguf_writer.add_diffusion_shift_logits(False)
|
||||
@@ -10344,7 +10430,7 @@ class LLaDAMoEModel(TextModel):
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
@@ -10681,7 +10767,6 @@ class LFM2MoeModel(TextModel):
|
||||
|
||||
super().set_gguf_parameters()
|
||||
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"])
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
@@ -10702,7 +10787,7 @@ class LFM2MoeModel(TextModel):
|
||||
|
||||
# merge expert weights
|
||||
if 'experts' in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
expert_cache = self._experts_cache.setdefault(bid, {})
|
||||
@@ -10812,9 +10897,9 @@ class SmallThinkerModel(TextModel):
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
|
||||
if (n_experts := self.hparams.get("moe_num_primary_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
|
||||
if (n_experts_used := self.hparams.get("moe_num_active_primary_experts")) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
@@ -10839,7 +10924,7 @@ class SmallThinkerModel(TextModel):
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("experts") != -1:
|
||||
n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
|
||||
n_experts = self.hparams.get("moe_num_primary_experts") or self.find_hparam(["num_local_experts", "num_experts"])
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
|
||||
@@ -99,6 +99,7 @@ models = [
|
||||
{"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", },
|
||||
{"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", },
|
||||
{"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", },
|
||||
{"name": "tiny_aya", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereLabs/tiny-aya-base", },
|
||||
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", },
|
||||
{"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", },
|
||||
{"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", },
|
||||
@@ -148,7 +149,8 @@ models = [
|
||||
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
|
||||
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
|
||||
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
|
||||
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", },
|
||||
{"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", },
|
||||
]
|
||||
|
||||
# some models are known to be broken upstream, so we will skip them as exceptions
|
||||
@@ -158,6 +160,7 @@ pre_computed_hashes = [
|
||||
{"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
|
||||
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
|
||||
{"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
|
||||
@@ -171,7 +174,6 @@ pre_computed_hashes = [
|
||||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
|------------|-------------|------|-------|
|
||||
| FP32 | ✅ | ✅ | ❓ |
|
||||
| FP16 | ✅ | ✅ | ❓ |
|
||||
| BF16 | 🚫 | ✅ | ❓ |
|
||||
| BF16 | ✅ | ✅ | ❓ |
|
||||
| Q4_0 | ✅ | ❓ | ❓ |
|
||||
| Q4_1 | ✅ | ❓ | ❓ |
|
||||
| MXFP4 | 🚫 | ❓ | ❓ |
|
||||
| MXFP4 | ✅ | ❓ | ❓ |
|
||||
| Q5_0 | ✅ | ❓ | ❓ |
|
||||
| Q5_1 | ✅ | ❓ | ❓ |
|
||||
| Q8_0 | ✅ | ❓ | ❓ |
|
||||
@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
|
||||
- 🚫 - acceleration unavailable, will still run using scalar implementation
|
||||
- ❓ - acceleration unknown, please contribute if you can test it yourself
|
||||
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
|
||||
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026.
|
||||
|
||||
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 5)
|
||||
set(GGML_VERSION_PATCH 7)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
@@ -752,6 +752,7 @@ extern "C" {
|
||||
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_view (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
|
||||
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
|
||||
|
||||
@@ -17,11 +17,6 @@
|
||||
//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
|
||||
#define AT_PRINTF(...)
|
||||
|
||||
|
||||
static bool ggml_is_view(const struct ggml_tensor * t) {
|
||||
return t->view_src != NULL;
|
||||
}
|
||||
|
||||
// ops that return true for this function must not use restrict pointers for their backend implementations
|
||||
bool ggml_op_can_inplace(enum ggml_op op) {
|
||||
switch (op) {
|
||||
@@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
||||
GGML_ASSERT(buffer_id >= 0);
|
||||
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||
|
||||
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
|
||||
if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) {
|
||||
hn->allocated = true;
|
||||
assert(hn->addr.offset == 0);
|
||||
|
||||
@@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
||||
|
||||
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
||||
if (p_hn->n_children == 1 && p_hn->n_views == 0) {
|
||||
if (ggml_is_view(parent)) {
|
||||
if (ggml_impl_is_view(parent)) {
|
||||
struct ggml_tensor * view_src = parent->view_src;
|
||||
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
|
||||
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
|
||||
@@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
// GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
|
||||
// control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
|
||||
// itself is never used and should not be considered a dependency
|
||||
if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
|
||||
if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) {
|
||||
struct ggml_tensor * view_src = node->view_src;
|
||||
ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
|
||||
}
|
||||
@@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
|
||||
|
||||
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
|
||||
if (ggml_is_view(parent)) {
|
||||
if (ggml_impl_is_view(parent)) {
|
||||
struct ggml_tensor * view_src = parent->view_src;
|
||||
struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
|
||||
view_src_hn->n_views -= 1;
|
||||
|
||||
@@ -569,27 +569,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
endif()
|
||||
|
||||
# TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+
|
||||
# Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28
|
||||
FetchContent_Declare(KleidiAI_Download
|
||||
URL ${KLEIDIAI_DOWNLOAD_URL}
|
||||
DOWNLOAD_EXTRACT_TIMESTAMP NEW
|
||||
URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
|
||||
|
||||
FetchContent_MakeAvailable(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download
|
||||
SOURCE_DIR KLEIDIAI_SRC
|
||||
POPULATED KLEIDIAI_POPULATED)
|
||||
|
||||
if (NOT KLEIDIAI_POPULATED)
|
||||
message(FATAL_ERROR "KleidiAI source downloaded failed.")
|
||||
FetchContent_Populate(KleidiAI_Download)
|
||||
FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC)
|
||||
endif()
|
||||
|
||||
add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
|
||||
|
||||
# Remove kleidiai target after fetching it
|
||||
if (TARGET kleidiai)
|
||||
set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
|
||||
endif()
|
||||
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/kleidiai/kleidiai.cpp
|
||||
ggml-cpu/kleidiai/kernels.cpp
|
||||
|
||||
@@ -3226,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (svcntb() * 8 == 256) {
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
|
||||
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
|
||||
svbool_t pg = svptrue_pat_b32(SV_VL8);
|
||||
svuint32_t idx = svld1(pg, idx_arr);
|
||||
|
||||
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
|
||||
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
acc_f32_01 = svdup_n_f32(0);
|
||||
acc_f32_23 = svdup_n_f32(0);
|
||||
acc_f32_45 = svdup_n_f32(0);
|
||||
acc_f32_67 = svdup_n_f32(0);
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
// 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
|
||||
int32_t bsums_arr32[4][8];
|
||||
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
int16x8_t v16 = bsums[q8_row];
|
||||
|
||||
// low 4
|
||||
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
|
||||
|
||||
// high 4
|
||||
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
|
||||
}
|
||||
|
||||
svint32_t sb_acc_0 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svint32_t acc_00 = svdup_n_s32(0);
|
||||
svint32_t acc_11 = svdup_n_s32(0);
|
||||
svint32_t acc_22 = svdup_n_s32(0);
|
||||
svint32_t acc_33 = svdup_n_s32(0);
|
||||
svint32_t acc_44 = svdup_n_s32(0);
|
||||
svint32_t acc_55 = svdup_n_s32(0);
|
||||
svint32_t acc_66 = svdup_n_s32(0);
|
||||
svint32_t acc_77 = svdup_n_s32(0);
|
||||
|
||||
svint32_t bias_acc_00 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_22 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_44 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_66 = svdup_n_s32(0);
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
|
||||
svint32_t q4sb_mins_0, q4sb_mins_1;
|
||||
{
|
||||
// 2-superblock I am working on
|
||||
const int offset = sb * 24 + 0 * 12;
|
||||
const uint8_t * scales_in = &q4_ptr[b].scales[offset];
|
||||
|
||||
const int offset1 = sb * 24 + 12;
|
||||
const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
constexpr uint8_t scales_size = 12;
|
||||
|
||||
uint32_t sm[3];
|
||||
memcpy(sm, scales_in, scales_size);
|
||||
|
||||
uint32_t sm1[3];
|
||||
memcpy(sm1, scales_in1, scales_size);
|
||||
|
||||
const uint32_t mins_0_3 = sm[1] & kmask1;
|
||||
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
||||
|
||||
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
|
||||
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
|
||||
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
|
||||
|
||||
/* reinterpret u32 → u8 */
|
||||
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
|
||||
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
|
||||
|
||||
/* widen u8 → u16->u32 (lower half only) */
|
||||
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
|
||||
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
|
||||
|
||||
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
|
||||
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
|
||||
|
||||
uint32_t scales_u32_0 = sm[0] & kmask1;
|
||||
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
||||
uint32_t scales_u32_2 = sm1[0] & kmask1;
|
||||
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t S01 = svdup_n_u32(scales_u32_0);
|
||||
svuint32_t S23 = svdup_n_u32(scales_u32_1);
|
||||
svuint32_t R01 = svdup_n_u32(scales_u32_2);
|
||||
svuint32_t R23 = svdup_n_u32(scales_u32_3);
|
||||
|
||||
svint8_t S01_b = svreinterpret_s8_u32(S01);
|
||||
svint8_t S23_b = svreinterpret_s8_u32(S23);
|
||||
svint8_t R01_b = svreinterpret_s8_u32(R01);
|
||||
svint8_t R23_b = svreinterpret_s8_u32(R23);
|
||||
|
||||
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
|
||||
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
|
||||
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
|
||||
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
|
||||
|
||||
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
|
||||
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
|
||||
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
|
||||
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
|
||||
}
|
||||
|
||||
const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
// predicate for activating higher lanes for 16 int8 elements
|
||||
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
||||
// predicate for activating lower lanes for 16 int8 elements
|
||||
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
||||
|
||||
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
|
||||
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
|
||||
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
|
||||
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
|
||||
|
||||
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
|
||||
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
|
||||
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
|
||||
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
|
||||
sb_acc_0 = svdup_n_s32(0);
|
||||
sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
|
||||
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
|
||||
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
|
||||
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
|
||||
|
||||
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
|
||||
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
|
||||
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
|
||||
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
|
||||
|
||||
if(cp == 0) {
|
||||
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
|
||||
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
|
||||
}
|
||||
if(cp == 1) {
|
||||
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
|
||||
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
|
||||
}
|
||||
if(cp == 2) {
|
||||
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
|
||||
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
|
||||
}
|
||||
if(cp == 3) {
|
||||
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
|
||||
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
|
||||
}
|
||||
}
|
||||
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
|
||||
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
|
||||
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
|
||||
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
|
||||
} // for sb
|
||||
|
||||
|
||||
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
|
||||
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
|
||||
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
|
||||
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
|
||||
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
|
||||
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
|
||||
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
|
||||
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
|
||||
|
||||
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
|
||||
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
|
||||
|
||||
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
|
||||
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
|
||||
|
||||
// Broadcast q8 scalar
|
||||
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
|
||||
|
||||
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
|
||||
|
||||
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
|
||||
|
||||
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
|
||||
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[1]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
|
||||
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[2]);
|
||||
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
|
||||
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[3]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
|
||||
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
|
||||
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
// Predicate for exactly 4 lanes
|
||||
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
|
||||
if (i == 0 && j == 0) {
|
||||
// acc_f32_0 → lower half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, acc_f32_01);
|
||||
} else if (i == 0 && j == 1) {
|
||||
// acc_f32_1 → upper half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
|
||||
} else if (i == 1 && j == 0) {
|
||||
// acc_f32_2
|
||||
svst1_f32(pg4, s + offset, acc_f32_23);
|
||||
} else if (i == 1 && j == 1) {
|
||||
// acc_f32_3
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
|
||||
} else if (i == 2 && j == 0) {
|
||||
// acc_f32_4
|
||||
svst1_f32(pg4, s + offset, acc_f32_45);
|
||||
} else if (i == 2 && j == 1) {
|
||||
// acc_f32_5
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
|
||||
} else if (i == 3 && j == 0) {
|
||||
// acc_f32_6
|
||||
svst1_f32(pg4, s + offset, acc_f32_67);
|
||||
} else if (i == 3 && j == 1) {
|
||||
// acc_f32_7
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
}
|
||||
#endif // SVE compile-time end
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "simd-mappings.h"
|
||||
|
||||
#define GGML_FA_TILE_Q 32
|
||||
#define GGML_FA_TILE_KV 16
|
||||
#define GGML_FA_TILE_Q 64
|
||||
#define GGML_FA_TILE_KV 64
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
|
||||
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
|
||||
const int64_t DV = node->src[2]->ne[0];
|
||||
|
||||
// Tiled flash attention scratch (tile sizes defined in common.h)
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
|
||||
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
|
||||
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
|
||||
|
||||
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
|
||||
// Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
|
||||
@@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
/*.use_ref =*/ cplan->use_ref,
|
||||
};
|
||||
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
|
||||
#ifdef GGML_USE_OPENMP
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan);
|
||||
#else
|
||||
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
|
||||
#endif
|
||||
|
||||
for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
@@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
}
|
||||
}
|
||||
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
|
||||
#ifdef GGML_USE_OPENMP
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan);
|
||||
#else
|
||||
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph);
|
||||
#endif
|
||||
|
||||
ggml_barrier(state->threadpool);
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "ggml-cpu.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "binary-ops.h"
|
||||
#include "simd-gemm.h"
|
||||
#include "ggml.h"
|
||||
#include "unary-ops.h"
|
||||
#include "vec.h"
|
||||
@@ -8389,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
const ggml_type kv_type = k->type;
|
||||
|
||||
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
|
||||
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
|
||||
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
|
||||
const size_t kv_type_size = ggml_type_size(kv_type);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t rk2 = neq2/nek2;
|
||||
@@ -8424,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
|
||||
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
|
||||
|
||||
int ir = ir0;
|
||||
while (ir < ir1) {
|
||||
// q indices for the start of this tile
|
||||
@@ -8452,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
}
|
||||
|
||||
// Per-thread scratch layout:
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
|
||||
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
||||
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
||||
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
|
||||
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
|
||||
|
||||
void * Q_q = base;
|
||||
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
||||
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV;
|
||||
float * K_f32 = V32 + KV_TILE_SZ * DV;
|
||||
|
||||
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
||||
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
@@ -8476,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
|
||||
}
|
||||
// Zero-pad remaining rows
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
|
||||
{
|
||||
float * Q_f32 = (float *)Q_q;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
|
||||
}
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
|
||||
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
||||
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
|
||||
|
||||
// skip the tile entirely if all the masks are -inf
|
||||
if (mask) {
|
||||
bool can_skip = true;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
||||
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
||||
can_skip = false;
|
||||
}
|
||||
}
|
||||
// Pad remaining mask entries with -inf
|
||||
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
|
||||
}
|
||||
}
|
||||
|
||||
if (can_skip) {
|
||||
@@ -8505,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
float s;
|
||||
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
|
||||
KQ[tq * KV_TILE_SZ + tk] = s * scale;
|
||||
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
|
||||
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
|
||||
for (int64_t dk = 0; dk < DK; dk++) {
|
||||
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
|
||||
}
|
||||
} else {
|
||||
const float * k_f32_src = (const float *)k_data;
|
||||
for (int64_t dk = 0; dk < DK; dk++) {
|
||||
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
|
||||
}
|
||||
}
|
||||
}
|
||||
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
|
||||
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
|
||||
|
||||
// Set padded KQ entries to -inf so softmax gives them zero weight
|
||||
if (kv_tile < KV_TILE_SZ) {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
|
||||
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8551,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
||||
}
|
||||
|
||||
// Convert V tile to F32 first (if F16), then do MAD
|
||||
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
|
||||
// TODO: on ARM, native f16 should be faster
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
|
||||
}
|
||||
// V accumulation: VKQ32 += softmax(KQ) * V
|
||||
// Pack V tile to contiguous F32, zero-padded
|
||||
for (int tk = 0; tk < kv_tile; tk++) {
|
||||
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
|
||||
} else {
|
||||
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
|
||||
}
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) {
|
||||
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
|
||||
}
|
||||
}
|
||||
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
|
||||
}
|
||||
|
||||
// sinks (apply only to valid rows in the tile)
|
||||
@@ -8794,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
const bool use_tiled = !use_ref &&
|
||||
bool use_tiled = !use_ref &&
|
||||
(q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
nek1 % KV_TILE_SZ == 0 &&
|
||||
neq1 >= Q_TILE_SZ);
|
||||
|
||||
#ifdef GGML_SIMD
|
||||
use_tiled &= (DV % GGML_F32_EPR == 0);
|
||||
#endif
|
||||
int current_chunk = ith;
|
||||
|
||||
while (current_chunk < nchunk) {
|
||||
|
||||
136
ggml/src/ggml-cpu/simd-gemm.h
Normal file
136
ggml/src/ggml-cpu/simd-gemm.h
Normal file
@@ -0,0 +1,136 @@
|
||||
#pragma once
|
||||
|
||||
// Computes C[M x N] += A[M x K] * B[K x N]
|
||||
|
||||
#include "simd-mappings.h"
|
||||
|
||||
// TODO: add support for sizeless vector types
|
||||
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
|
||||
|
||||
// TODO: untested on avx512
|
||||
// These are in units of GGML_F32_EPR
|
||||
#if defined(__AVX512F__) || defined (__ARM_NEON__)
|
||||
static constexpr int GEMM_RM = 4;
|
||||
static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
|
||||
#elif defined(__AVX2__) || defined(__AVX__)
|
||||
static constexpr int GEMM_RM = 6;
|
||||
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
|
||||
#else
|
||||
static constexpr int GEMM_RM = 2;
|
||||
static constexpr int GEMM_RN = 2;
|
||||
#endif
|
||||
|
||||
template <int RM, int RN>
|
||||
static inline void simd_gemm_ukernel(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int K, int N)
|
||||
{
|
||||
static constexpr int KN = GGML_F32_EPR;
|
||||
|
||||
GGML_F32_VEC acc[RM][RN];
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
for (int r = 0; r < RN; r++) {
|
||||
acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
GGML_F32_VEC Bv[RN];
|
||||
for (int r = 0; r < RN; r++) {
|
||||
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
|
||||
}
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
|
||||
for (int r = 0; r < RN; r++) {
|
||||
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < RM; i++) {
|
||||
for (int r = 0; r < RN; r++) {
|
||||
GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// C[M x N] += A[M x K] * B[K x N]
|
||||
static void simd_gemm(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int M, int K, int N)
|
||||
{
|
||||
static constexpr int KN = GGML_F32_EPR;
|
||||
|
||||
int64_t ii = 0;
|
||||
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
|
||||
int64_t jj = 0;
|
||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj + KN <= N; jj += KN) {
|
||||
simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj < N; jj++) {
|
||||
for (int64_t i = 0; i < GEMM_RM; i++) {
|
||||
float a = C[i * N + jj];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
a += A[i + kk] * B[kk * N + jj];
|
||||
}
|
||||
C[i * N + jj] = a;
|
||||
}
|
||||
}
|
||||
|
||||
A += GEMM_RM * K;
|
||||
C += GEMM_RM * N;
|
||||
}
|
||||
|
||||
// Tail rows: one at a time
|
||||
for (; ii < M; ii++) {
|
||||
int64_t jj = 0;
|
||||
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
|
||||
simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj + KN <= N; jj += KN) {
|
||||
simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
|
||||
}
|
||||
for (; jj < N; jj++) {
|
||||
float a = C[jj];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
a += A[kk] * B[kk * N + jj];
|
||||
}
|
||||
C[jj] = a;
|
||||
}
|
||||
|
||||
A += K;
|
||||
C += N;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(__GNUC__) && !defined(__clang__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#else // scalar path
|
||||
|
||||
static void simd_gemm(
|
||||
float * GGML_RESTRICT C,
|
||||
const float * GGML_RESTRICT A,
|
||||
const float * GGML_RESTRICT B,
|
||||
int M, int K, int N)
|
||||
{
|
||||
for (int64_t i = 0; i < M; i++) {
|
||||
for (int64_t j = 0; j < N; j++) {
|
||||
float sum = C[i * N + j];
|
||||
for (int64_t kk = 0; kk < K; kk++) {
|
||||
sum += A[i * K + kk] * B[kk * N + j];
|
||||
}
|
||||
C[i * N + j] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // GGML_SIMD
|
||||
@@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
||||
float32x4_t tmp = x[0] + vec_reve(x[0]); \
|
||||
res = tmp[0] + tmp[1]; \
|
||||
}
|
||||
#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \
|
||||
{ \
|
||||
float32x4_t v = vec_add(vec_add(s0, s1), \
|
||||
vec_add(s2, s3)); \
|
||||
v = vec_add(v, vec_sld(v, v, 8)); \
|
||||
v = vec_add(v, vec_sld(v, v, 4)); \
|
||||
res += (ggml_float)vec_extract(v, 0); \
|
||||
}
|
||||
|
||||
#define GGML_F32_VEC GGML_F32x4
|
||||
#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO
|
||||
@@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
|
||||
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
// BF16 s390x
|
||||
#define GGML_BF16_STEP 16
|
||||
#define GGML_BF16_EPR 8
|
||||
|
||||
#define GGML_BF16x8 __vector unsigned short
|
||||
#define GGML_BF16x8_ZERO vec_splats((unsigned short)0)
|
||||
#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p))
|
||||
|
||||
#define GGML_BF16_VEC GGML_BF16x8
|
||||
#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO
|
||||
#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD
|
||||
#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO))
|
||||
#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO))
|
||||
#define GGML_BF16_FMA_LO(acc, x, y) \
|
||||
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y))
|
||||
#define GGML_BF16_FMA_HI(acc, x, y) \
|
||||
(acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y))
|
||||
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
|
||||
// compatible with vlen >= 128
|
||||
|
||||
@@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
|
||||
vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
|
||||
sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
|
||||
|
||||
#endif
|
||||
#if defined(__POWER9_VECTOR__)
|
||||
#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__)
|
||||
const int np = (n & ~(GGML_BF16_STEP - 1));
|
||||
if (np > 0) {
|
||||
GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO};
|
||||
|
||||
@@ -63,7 +63,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
|
||||
#if defined(GGML_USE_HIP)
|
||||
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
|
||||
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
|
||||
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
|
||||
@@ -135,7 +135,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
|
||||
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
|
||||
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
|
||||
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
|
||||
|
||||
@@ -2872,6 +2872,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
|
||||
const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
|
||||
const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
|
||||
const std::string delta_net_prefix = "dnet_add";
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
@@ -2902,7 +2903,8 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
|
||||
strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
|
||||
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
|
||||
strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 &&
|
||||
strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) {
|
||||
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
|
||||
// by means of matching node names. See
|
||||
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
|
||||
@@ -4544,6 +4546,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
// TODO: should become:
|
||||
//return ggml_is_contiguous_rows(op->src[0]);
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -2715,14 +2715,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR2_XXS; ++l) {
|
||||
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
|
||||
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
|
||||
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
|
||||
|
||||
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
|
||||
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
||||
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
|
||||
@@ -2733,12 +2733,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
|
||||
const int ls = aux32 >> 28;
|
||||
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
|
||||
const float d = bxi->d;
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
|
||||
x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
|
||||
#else
|
||||
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
|
||||
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
|
||||
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
}
|
||||
}
|
||||
@@ -2776,11 +2776,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR2_XS; ++l) {
|
||||
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
|
||||
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
|
||||
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
|
||||
const uint32_t signs = unpack_ksigns(q2[l] >> 9);
|
||||
|
||||
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
|
||||
@@ -2904,11 +2907,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
|
||||
#pragma unroll
|
||||
for (int l = 0; l < QR3_XXS; ++l) {
|
||||
const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
|
||||
|
||||
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
|
||||
x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
|
||||
|
||||
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
|
||||
// v is a 7 bit int, with the 8th sign being encodable as popcnt
|
||||
// with xor we can "correct" the bit instead of having to mask
|
||||
const uint32_t p = __popc(v) & 1;
|
||||
const uint32_t s = v ^ p << 7;
|
||||
// broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
|
||||
return s * 0x01010101;
|
||||
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
||||
int sumi = 0;
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < 8; k0 += 2) {
|
||||
const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
|
||||
const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
|
||||
const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
|
||||
|
||||
const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
|
||||
const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
|
||||
sumi = ggml_cuda_dp4a(grid0, u0, sumi);
|
||||
|
||||
const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
|
||||
const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
|
||||
sumi = ggml_cuda_dp4a(grid1, u1, sumi);
|
||||
}
|
||||
|
||||
const int ls = aux32 >> 28;
|
||||
sumi = (ls*sumi + sumi/2)/4;
|
||||
const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
|
||||
sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4
|
||||
const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
|
||||
return d * sumi;
|
||||
}
|
||||
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
||||
int sumi1 = 0;
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < 8; l0 += 2) {
|
||||
const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
|
||||
const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));
|
||||
|
||||
const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
|
||||
const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
|
||||
const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
|
||||
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
|
||||
|
||||
if (l0 < 4) {
|
||||
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < 8; l0 += 2) {
|
||||
const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
|
||||
const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
|
||||
|
||||
const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
|
||||
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
|
||||
const int signs0 = __vcmpne4(signs & 0x08040201, 0);
|
||||
const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
|
||||
|
||||
const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
|
||||
|
||||
const int signs1 = __vcmpne4(signs & 0x80402010, 0);
|
||||
const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
|
||||
|
||||
const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
|
||||
|
||||
sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
|
||||
|
||||
@@ -17,121 +17,6 @@
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
|
||||
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
}
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x_hf = Q6_V_vand_QV(bmask, x_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
|
||||
}
|
||||
|
||||
rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// Dot product of FP32 and FP16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
|
||||
const void * restrict y,
|
||||
const void * restrict x0,
|
||||
const void * restrict x1,
|
||||
unsigned int n,
|
||||
float s) {
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
// Load x (fp16)
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
|
||||
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
|
||||
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
|
||||
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
|
||||
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// Dot product of two F16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
@@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
@@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
@@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
|
||||
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
@@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
|
||||
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
|
||||
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
|
||||
@@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
|
||||
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (float)
|
||||
// MAD: y (F32) += x (F16) * s (F32)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
|
||||
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
@@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict
|
||||
}
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32)
|
||||
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y,
|
||||
const void * restrict x0,
|
||||
const void * restrict x1,
|
||||
float s0,
|
||||
float s1,
|
||||
int n) {
|
||||
const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0;
|
||||
const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1;
|
||||
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S0 = hvx_vec_splat_f16(s0);
|
||||
HVX_Vector S1 = hvx_vec_splat_f16(s1);
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
// Multiply x * s -> pair of F32 vectors
|
||||
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
|
||||
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
|
||||
|
||||
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
|
||||
HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
|
||||
|
||||
ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2]));
|
||||
ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1]));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0);
|
||||
HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1);
|
||||
|
||||
HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p));
|
||||
HVX_Vector xs = xs_p_lo;
|
||||
i = 2 * i; // index for ptr_y
|
||||
|
||||
if (nloe >= 32) {
|
||||
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
nloe -= 32; ++i;
|
||||
xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
|
||||
hvx_vec_store_a(&ptr_y[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define FLASH_ATTN_BLOCK_SIZE 128
|
||||
|
||||
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
|
||||
struct htp_fa_context {
|
||||
const struct htp_ops_context * octx;
|
||||
|
||||
struct fastdiv_values src0_div21;
|
||||
struct fastdiv_values src0_div1;
|
||||
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values src3_div2;
|
||||
struct fastdiv_values src3_div3;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t n_blocks;
|
||||
|
||||
size_t size_q_row_padded;
|
||||
size_t size_k_row_padded;
|
||||
size_t size_v_row_padded;
|
||||
|
||||
size_t size_k_block;
|
||||
size_t size_v_block;
|
||||
size_t size_m_block;
|
||||
|
||||
bool is_q_fp32;
|
||||
};
|
||||
|
||||
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) {
|
||||
assert((size_t) dst % 128 == 0);
|
||||
assert((size_t) src % 128 == 0);
|
||||
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
|
||||
|
||||
const uint32_t nvec = n / VLEN_FP32;
|
||||
const uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; ++i) {
|
||||
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
|
||||
}
|
||||
if (nloe) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
|
||||
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_fa_context * factx = (struct htp_fa_context *) data;
|
||||
const struct htp_ops_context * octx = factx->octx;
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const uint32_t neq0 = q->ne[0];
|
||||
const uint32_t neq1 = q->ne[1];
|
||||
@@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
const uint32_t nb2 = dst->nb[2];
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
// total rows in q
|
||||
const uint32_t nr = neq1*neq2*neq3;
|
||||
|
||||
@@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
const uint32_t DV = nev0;
|
||||
|
||||
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
|
||||
const size_t size_q_row_padded = hex_round_up(size_q_row, 128);
|
||||
|
||||
const size_t size_k_row = DK * sizeof(__fp16);
|
||||
const size_t size_v_row = DV * sizeof(__fp16);
|
||||
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
|
||||
|
||||
const size_t size_k_row_padded = hex_round_up(size_k_row, 128);
|
||||
const size_t size_v_row_padded = hex_round_up(size_v_row, 128);
|
||||
|
||||
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
|
||||
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
|
||||
@@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
|
||||
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ++ir) {
|
||||
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
|
||||
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
|
||||
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
|
||||
|
||||
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
|
||||
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
|
||||
const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3);
|
||||
const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2);
|
||||
|
||||
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
|
||||
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
|
||||
const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3);
|
||||
const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2);
|
||||
|
||||
// Fetch Q row
|
||||
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
|
||||
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
|
||||
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
|
||||
const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f;
|
||||
|
||||
float S = 0.0f; // sum
|
||||
float M = -INFINITY; // maximum KQ value
|
||||
HVX_Vector S_vec = hvx_vec_splat_f32(0.0f);
|
||||
HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY);
|
||||
|
||||
// Clear accumulator
|
||||
hvx_splat_f32_a(spad_a, 0, DV);
|
||||
@@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
|
||||
const __fp16 * mp_base = NULL;
|
||||
if (mask) {
|
||||
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
|
||||
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
|
||||
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2);
|
||||
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3);
|
||||
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
|
||||
}
|
||||
|
||||
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
||||
|
||||
// Prefetch first two blocks
|
||||
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
|
||||
for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
|
||||
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
|
||||
uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block;
|
||||
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
|
||||
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
|
||||
uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block;
|
||||
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
|
||||
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
|
||||
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
|
||||
// Mask is 1D contiguous for this row
|
||||
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
|
||||
}
|
||||
}
|
||||
|
||||
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
||||
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
||||
if (factx->is_q_fp32) {
|
||||
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
|
||||
}
|
||||
|
||||
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
|
||||
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
|
||||
for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
||||
@@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
// Inner loop processing the block from VTCM
|
||||
uint32_t ic = 0;
|
||||
|
||||
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
|
||||
// Process in blocks of 32 (VLEN_FP32)
|
||||
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
|
||||
HVX_Vector_x4 scores_x4;
|
||||
@@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
for (int j = 0; j < VLEN_FP32; j += 2) {
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
if (is_q_fp32) {
|
||||
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
}
|
||||
const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale);
|
||||
}
|
||||
|
||||
HVX_Vector scores = *(HVX_Vector *) scores_arr;
|
||||
|
||||
// 2. Softcap
|
||||
if (logit_softcap != 0.0f) {
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
scores = hvx_vec_tanh_f32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
@@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
if (mask) {
|
||||
const __fp16 * mp = m_base + ic;
|
||||
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
||||
|
||||
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
|
||||
|
||||
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
|
||||
|
||||
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
|
||||
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
|
||||
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
scores_x4.v[iv] = scores;
|
||||
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
|
||||
v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max
|
||||
}
|
||||
|
||||
{
|
||||
// 4. Online Softmax Update
|
||||
v_max = hvx_vec_reduce_max_f32(v_max);
|
||||
float m_block = hvx_vec_get_f32(v_max);
|
||||
float M_old = M;
|
||||
float M_new = (m_block > M) ? m_block : M;
|
||||
M = M_new;
|
||||
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
|
||||
HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec));
|
||||
M_vec = M_new_vec;
|
||||
|
||||
const float ms = expf(M_old - M_new);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
|
||||
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
|
||||
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
|
||||
HVX_Vector scores = scores_x4.v[iv];
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
|
||||
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
|
||||
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
|
||||
|
||||
p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
|
||||
|
||||
// 5. Accumulate V
|
||||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector*)p_arr = P;
|
||||
*(HVX_Vector *) p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; ++j) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
|
||||
}
|
||||
}
|
||||
|
||||
p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
|
||||
S = S * ms + hvx_vec_get_f32(p_sum_vec);
|
||||
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
|
||||
}
|
||||
|
||||
// Sync scalars for leftover/next block if needed
|
||||
float M = hvx_vec_get_f32(M_vec);
|
||||
float S = hvx_vec_get_f32(S_vec);
|
||||
|
||||
// Leftover
|
||||
for (; ic < current_block_size; ++ic) {
|
||||
float s_val;
|
||||
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
|
||||
|
||||
if (is_q_fp32) {
|
||||
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
} else {
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
|
||||
}
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
s_val = logit_softcap * tanhf(s_val);
|
||||
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
|
||||
if (factx->logit_softcap != 0.0f) {
|
||||
s_val = factx->logit_softcap * tanhf(s_val);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
@@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
|
||||
const float Mold = M;
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s_val > M) {
|
||||
M = s_val;
|
||||
ms = expf(Mold - M);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
float ms = hvx_vec_get_f32(ms_vec);
|
||||
S = S * ms + vs;
|
||||
} else {
|
||||
vs = expf(s_val - M);
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
|
||||
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
||||
S += vs;
|
||||
}
|
||||
|
||||
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
|
||||
const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;
|
||||
|
||||
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
|
||||
|
||||
S = S * ms + vs;
|
||||
}
|
||||
M_vec = hvx_vec_splat_f32(M);
|
||||
S_vec = hvx_vec_splat_f32(S);
|
||||
|
||||
// Issue DMA for next+1 block (if exists)
|
||||
if (ib + 2 < n_blocks) {
|
||||
if (ib + 2 < factx->n_blocks) {
|
||||
const uint32_t next_ib = ib + 2;
|
||||
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
|
||||
|
||||
// K
|
||||
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
|
||||
dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size);
|
||||
|
||||
// V
|
||||
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
|
||||
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
|
||||
dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size);
|
||||
|
||||
// Mask
|
||||
if (mask) {
|
||||
@@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
|
||||
// sinks
|
||||
float M = hvx_vec_get_f32(M_vec);
|
||||
float S = hvx_vec_get_f32(S_vec);
|
||||
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M) {
|
||||
ms = expf(M - s);
|
||||
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
|
||||
} else {
|
||||
vs = expf(s - M);
|
||||
}
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(M - s);
|
||||
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
|
||||
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);
|
||||
|
||||
S = S * ms + vs;
|
||||
float ms = hvx_vec_get_f32(ms_vec);
|
||||
S = S * ms + vs;
|
||||
} else {
|
||||
HVX_Vector diff_vec = hvx_vec_splat_f32(s - M);
|
||||
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
|
||||
S += vs;
|
||||
}
|
||||
}
|
||||
|
||||
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
|
||||
@@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
||||
}
|
||||
}
|
||||
|
||||
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = data;
|
||||
flash_attn_ext_f16_thread(octx, i, n);
|
||||
}
|
||||
|
||||
int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * q = &octx->src0;
|
||||
const struct htp_tensor * k = &octx->src1;
|
||||
const struct htp_tensor * v = &octx->src2;
|
||||
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
// Check support
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
|
||||
k->type != HTP_TYPE_F16 ||
|
||||
v->type != HTP_TYPE_F16) {
|
||||
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
factx.src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
|
||||
factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
|
||||
if (mask) {
|
||||
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
factx.src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
factx.src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
|
||||
size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128);
|
||||
factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128);
|
||||
factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128);
|
||||
|
||||
size_t size_q_block = size_q_row_padded * 1; // single row for now
|
||||
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
size_t size_q_block = factx.size_q_row_padded * 1; // single row for now
|
||||
factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
|
||||
factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
|
||||
factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
factx.scale = scale;
|
||||
factx.max_bias = max_bias;
|
||||
factx.logit_softcap = logit_softcap;
|
||||
|
||||
uint32_t n_head = q->ne[2];
|
||||
factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2);
|
||||
factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2);
|
||||
|
||||
size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
|
||||
|
||||
octx->src0_spad.size_per_thread = size_q_block * 1;
|
||||
octx->src1_spad.size_per_thread = size_k_block * 2;
|
||||
octx->src2_spad.size_per_thread = size_v_block * 2;
|
||||
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
|
||||
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
|
||||
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
|
||||
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
|
||||
octx->dst_spad.size_per_thread = size_vkq_acc;
|
||||
|
||||
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
|
||||
@@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
|
||||
@@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool ggml_impl_is_view(const struct ggml_tensor * t) {
|
||||
return t->view_src != NULL;
|
||||
}
|
||||
|
||||
static inline float ggml_compute_softplus_f32(float input) {
|
||||
return (input > 20.0f) ? input : logf(1 + expf(input));
|
||||
}
|
||||
|
||||
@@ -273,6 +273,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_GLU:
|
||||
case GGML_OP_SCALE:
|
||||
|
||||
@@ -1067,8 +1067,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_ADD_ID:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ACC:
|
||||
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return true;
|
||||
|
||||
@@ -620,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
|
||||
|
||||
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
|
||||
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
|
||||
@@ -671,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
ggml_metal_kargs_bin args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
/*.ne00 =*/ ne10,
|
||||
/*.ne01 =*/ ne11,
|
||||
/*.ne02 =*/ ne12,
|
||||
/*.ne03 =*/ ne13,
|
||||
/*.nb00 =*/ nb00,
|
||||
/*.nb01 =*/ pnb1,
|
||||
/*.nb02 =*/ pnb2,
|
||||
@@ -687,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.ne0 =*/ ne10,
|
||||
/*.ne1 =*/ ne11,
|
||||
/*.ne2 =*/ ne12,
|
||||
/*.ne3 =*/ ne13,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ pnb1,
|
||||
/*.nb2 =*/ pnb2,
|
||||
@@ -707,7 +707,13 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
|
||||
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
int nth = 1;
|
||||
|
||||
while (2*nth < args.ne0 && nth < nth_max) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
|
||||
|
||||
|
||||
@@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||
#define VK_VENDOR_ID_APPLE 0x106b
|
||||
#define VK_VENDOR_ID_INTEL 0x8086
|
||||
#define VK_VENDOR_ID_NVIDIA 0x10de
|
||||
#define VK_VENDOR_ID_QUALCOMM 0x5143
|
||||
|
||||
#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
|
||||
|
||||
@@ -687,6 +688,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
vk_pipeline pipeline_set_f32;
|
||||
|
||||
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
|
||||
vk_pipeline pipeline_add[2][2][2];
|
||||
@@ -4080,7 +4082,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -4181,7 +4183,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -5641,6 +5644,10 @@ static void ggml_vk_instance_init() {
|
||||
driver_priorities[vk::DriverId::eMesaNvk] = 2;
|
||||
#endif
|
||||
break;
|
||||
case VK_VENDOR_ID_QUALCOMM:
|
||||
driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
|
||||
driver_priorities[vk::DriverId::eMesaTurnip] = 2;
|
||||
break;
|
||||
}
|
||||
driver_priorities[vk::DriverId::eMesaDozen] = 100;
|
||||
|
||||
@@ -8422,6 +8429,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
||||
const uint32_t acctype = f32acc ? 4 : 2;
|
||||
const uint32_t f16vec4 = 8;
|
||||
|
||||
const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
|
||||
|
||||
const uint32_t qstride = hsk_pad / 4 + 2;
|
||||
const uint32_t Qf = Br * qstride * f16vec4;
|
||||
|
||||
@@ -8438,7 +8447,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
||||
|
||||
const uint32_t slope = Br * acctype;
|
||||
|
||||
const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
|
||||
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
|
||||
@@ -8815,6 +8824,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_acc_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SET:
|
||||
if (src0->type == src1->type && src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
|
||||
return ctx->device->pipeline_set_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -9806,7 +9821,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||
int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
|
||||
int offset = dst->op_params[3] / src0_type_size; // offset in bytes
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
@@ -10624,8 +10639,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
|
||||
}
|
||||
|
||||
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
const float * op_params = (const float *)dst->op_params;
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
@@ -12500,6 +12517,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
|
||||
break;
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_SET:
|
||||
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
@@ -14896,8 +14914,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
return true;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous_rows(op->src[0]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -14960,7 +14980,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
}
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ACC:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SET:
|
||||
return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
|
||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
|
||||
case GGML_OP_CONCAT:
|
||||
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
|
||||
case GGML_OP_ADD1:
|
||||
@@ -15611,6 +15634,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
} else if (tensor->op == GGML_OP_ACC) {
|
||||
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
|
||||
} else if (tensor->op == GGML_OP_SET) {
|
||||
tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
|
||||
} else if (tensor->op == GGML_OP_NORM) {
|
||||
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
||||
} else if (tensor->op == GGML_OP_GROUP_NORM) {
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
// false for SET, true for ACC
|
||||
layout(constant_id = 1) const bool ACC = true;
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
@@ -23,7 +26,11 @@ void main() {
|
||||
uint i00, i01, i02, i03;
|
||||
|
||||
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
|
||||
if (ACC) {
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
|
||||
} else {
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
|
||||
}
|
||||
} else {
|
||||
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
|
||||
}
|
||||
|
||||
@@ -130,6 +130,7 @@ void main() {
|
||||
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
@@ -137,12 +138,25 @@ void main() {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
float Sf[Br][cols_per_thread];
|
||||
@@ -260,6 +274,9 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
// prevent race on tmpsh
|
||||
barrier();
|
||||
|
||||
// reduce across threads
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
|
||||
@@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
shared float tmpsh[row_split];
|
||||
|
||||
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
shared f16vec4 Qf[Br * qstride];
|
||||
|
||||
@@ -213,6 +215,19 @@ void main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -176,7 +176,14 @@ void main() {
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
||||
// Don't clamp against nem1 when GQA is enabled
|
||||
@@ -184,7 +191,14 @@ void main() {
|
||||
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
||||
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
||||
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
||||
|
||||
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
||||
// skip the block if the mask is entirely -inf
|
||||
coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
||||
if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
@@ -8,19 +8,22 @@
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint i3 = row / (p.ne11 * p.ne12);
|
||||
const uint i3_offset = i3 * p.ne12 * p.ne11;
|
||||
const uint i2 = (row - i3_offset) / p.ne11;
|
||||
const uint i2_offset = i2 * p.ne11;
|
||||
const uint i1 = row - i3_offset - i2_offset;
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
@@ -35,7 +38,7 @@ void main() {
|
||||
|
||||
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1496,6 +1496,10 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
|
||||
(t0->nb[3] == t1->nb[3]);
|
||||
}
|
||||
|
||||
bool ggml_is_view(const struct ggml_tensor * t) {
|
||||
return ggml_impl_is_view(t);
|
||||
}
|
||||
|
||||
// check if t1 can be represented as a repetition of t0
|
||||
bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
@@ -3830,6 +3830,7 @@ class VisionProjectorType:
|
||||
MUSIC_FLAMINGO = "musicflamingo" # audio
|
||||
GLM4V = "glm4v"
|
||||
YOUTUVL = "youtuvl"
|
||||
NEMOTRON_V2_VL = "nemotron_v2_vl"
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
|
||||
@@ -1346,6 +1346,7 @@ class TensorNameMap:
|
||||
"model.vision_tower.embeddings.cls_token", # Intern-S1
|
||||
"vision_model.class_embedding", # llama 4
|
||||
"model.vision.patch_embedding.cls_embedding", # cogvlm
|
||||
"vision_model.radio_model.model.patch_generator.cls_token.token", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
|
||||
@@ -1360,6 +1361,7 @@ class TensorNameMap:
|
||||
"vision_tower.patch_embed.proj", # kimi-vl
|
||||
"model.vision.patch_embedding.proj", # cogvlm
|
||||
"siglip2.vision_model.embeddings.patch_embedding",
|
||||
"vision_model.radio_model.model.patch_generator.embedder", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_NORM: (
|
||||
@@ -1376,12 +1378,14 @@ class TensorNameMap:
|
||||
"visual.pos_embed", # qwen3vl
|
||||
"model.vision.patch_embedding.position_embedding", # cogvlm
|
||||
"visual.embeddings.position_embedding", # glm4v
|
||||
"vision_model.radio_model.model.patch_generator.pos_embed", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_QKV: (
|
||||
"visual.blocks.{bid}.attn.qkv", # qwen3vl
|
||||
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
|
||||
"vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5
|
||||
"vision_tower.encoder.blocks.{bid}.wqkv", # Kimi-K2.5
|
||||
"vision_model.radio_model.model.blocks.{bid}.attn.qkv", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: (
|
||||
@@ -1446,6 +1450,7 @@ class TensorNameMap:
|
||||
"vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
|
||||
"model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm
|
||||
"siglip2.vision_model.encoder.layers.{bid}.layer_norm1",
|
||||
"vision_model.radio_model.model.blocks.{bid}.norm1", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_O: (
|
||||
@@ -1462,6 +1467,7 @@ class TensorNameMap:
|
||||
"vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
|
||||
"model.vision.transformer.layers.{bid}.attention.dense", # cogvlm
|
||||
"siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl
|
||||
"vision_model.radio_model.model.blocks.{bid}.attn.proj", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
|
||||
@@ -1477,6 +1483,7 @@ class TensorNameMap:
|
||||
"vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
|
||||
"model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm
|
||||
"siglip2.vision_model.encoder.layers.{bid}.layer_norm2",
|
||||
"vision_model.radio_model.model.blocks.{bid}.norm2", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: (
|
||||
@@ -1493,6 +1500,7 @@ class TensorNameMap:
|
||||
"vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
|
||||
"model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm
|
||||
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc1",
|
||||
"vision_model.radio_model.model.blocks.{bid}.mlp.fc1", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_GATE: (
|
||||
@@ -1515,6 +1523,7 @@ class TensorNameMap:
|
||||
"vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
|
||||
"model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm
|
||||
"siglip2.vision_model.encoder.layers.{bid}.mlp.fc2",
|
||||
"vision_model.radio_model.model.blocks.{bid}.mlp.fc2", # Nemotron Nano v2 VL
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_LAYER_SCALE_1: (
|
||||
|
||||
@@ -656,21 +656,12 @@ extern "C" {
|
||||
|
||||
// The following functions operate on a llama_context, hence the naming: llama_verb_...
|
||||
|
||||
// Add a loaded LoRA adapter to given context
|
||||
// This will not modify model's weight
|
||||
LLAMA_API int32_t llama_set_adapter_lora(
|
||||
// Set LoRa adapters on the context. Will only modify if the adapters currently in context are different.
|
||||
LLAMA_API int32_t llama_set_adapters_lora(
|
||||
struct llama_context * ctx,
|
||||
struct llama_adapter_lora * adapter,
|
||||
float scale);
|
||||
|
||||
// Remove a specific LoRA adapter from given context
|
||||
// Return -1 if the adapter is not present in the context
|
||||
LLAMA_API int32_t llama_rm_adapter_lora(
|
||||
struct llama_context * ctx,
|
||||
struct llama_adapter_lora * adapter);
|
||||
|
||||
// Remove all LoRA adapters from given context
|
||||
LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
|
||||
struct llama_adapter_lora ** adapters,
|
||||
size_t n_adapters,
|
||||
float * scales);
|
||||
|
||||
// Apply a loaded control vector to a llama_context, or if data is NULL, clear
|
||||
// the currently loaded vector.
|
||||
@@ -678,7 +669,7 @@ extern "C" {
|
||||
// to an n_embd x n_layers buffer starting from layer 1.
|
||||
// il_start and il_end are the layer range the vector should apply to (both inclusive)
|
||||
// See llama_control_vector_load in common to load a control vector.
|
||||
LLAMA_API int32_t llama_apply_adapter_cvec(
|
||||
LLAMA_API int32_t llama_set_adapter_cvec(
|
||||
struct llama_context * ctx,
|
||||
const float * data,
|
||||
size_t len,
|
||||
|
||||
@@ -1 +1 @@
|
||||
a8db410a252c8c8f2d120c6f2e7133ebe032f35d
|
||||
d6754f3d0e6d0acd21c12442353c9fd2f94188e7
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import urllib.request
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "f80864ca031932351abef49b74097c67f14719c6"
|
||||
HTTPLIB_VERSION = "d4180e923f846b44a3d30acd938438d6e64fc9f6"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
@@ -14,7 +17,8 @@ vendor = {
|
||||
# "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
"https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h",
|
||||
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "vendor/cpp-httplib/httplib.h",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py",
|
||||
f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/LICENSE": "vendor/cpp-httplib/LICENSE",
|
||||
|
||||
"https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h",
|
||||
@@ -24,19 +28,16 @@ for url, filename in vendor.items():
|
||||
print(f"downloading {url} to {filename}") # noqa: NP100
|
||||
urllib.request.urlretrieve(url, filename)
|
||||
|
||||
# split cpp/h files for httplib
|
||||
# see: https://github.com/yhirose/cpp-httplib/blob/master/split.py
|
||||
if 'httplib.h' in filename:
|
||||
border = '// ----------------------------------------------------------------------------'
|
||||
with open(filename, 'r') as f:
|
||||
content = f.read()
|
||||
header, implementation, footer = content.split(border, 2)
|
||||
fname_cpp = filename.replace('.h', '.cpp')
|
||||
with open(filename, 'w') as fh:
|
||||
fh.write(header)
|
||||
fh.write(footer)
|
||||
with open(fname_cpp, 'w') as fc:
|
||||
fc.write('#include "httplib.h"\n')
|
||||
fc.write('namespace httplib {\n')
|
||||
fc.write(implementation.replace('\ninline ', '\n'))
|
||||
fc.write('} // namespace httplib\n')
|
||||
print("Splitting httplib.h...") # noqa: NP100
|
||||
try:
|
||||
subprocess.check_call([
|
||||
sys.executable, "split.py",
|
||||
"--extension", "cpp",
|
||||
"--out", "vendor/cpp-httplib"
|
||||
])
|
||||
except Exception as e:
|
||||
print(f"Error: {e}") # noqa: NP100
|
||||
sys.exit(1)
|
||||
finally:
|
||||
os.remove("split.py")
|
||||
os.remove("httplib.h")
|
||||
|
||||
@@ -57,13 +57,14 @@ add_library(llama
|
||||
models/deci.cpp
|
||||
models/deepseek.cpp
|
||||
models/deepseek2.cpp
|
||||
models/delta-net-base.cpp
|
||||
models/dots1.cpp
|
||||
models/dream.cpp
|
||||
models/ernie4-5-moe.cpp
|
||||
models/ernie4-5.cpp
|
||||
models/exaone-moe.cpp
|
||||
models/exaone.cpp
|
||||
models/exaone4.cpp
|
||||
models/exaone-moe.cpp
|
||||
models/falcon-h1.cpp
|
||||
models/falcon.cpp
|
||||
models/gemma-embedding.cpp
|
||||
@@ -91,10 +92,12 @@ add_library(llama
|
||||
models/llama-iswa.cpp
|
||||
models/llama.cpp
|
||||
models/maincoder.cpp
|
||||
models/mamba-base.cpp
|
||||
models/mamba.cpp
|
||||
models/mimo2-iswa.cpp
|
||||
models/minicpm3.cpp
|
||||
models/minimax-m2.cpp
|
||||
models/mistral3.cpp
|
||||
models/modern-bert.cpp
|
||||
models/mpt.cpp
|
||||
models/nemotron-h.cpp
|
||||
@@ -118,12 +121,12 @@ add_library(llama
|
||||
models/qwen2moe.cpp
|
||||
models/qwen2vl.cpp
|
||||
models/qwen3.cpp
|
||||
models/qwen3vl.cpp
|
||||
models/qwen3vl-moe.cpp
|
||||
models/qwen3moe.cpp
|
||||
models/qwen3next.cpp
|
||||
models/qwen35.cpp
|
||||
models/qwen35moe.cpp
|
||||
models/qwen3moe.cpp
|
||||
models/qwen3next.cpp
|
||||
models/qwen3vl-moe.cpp
|
||||
models/qwen3vl.cpp
|
||||
models/refact.cpp
|
||||
models/rnd1.cpp
|
||||
models/rwkv6-base.cpp
|
||||
@@ -142,8 +145,6 @@ add_library(llama
|
||||
models/t5-enc.cpp
|
||||
models/wavtokenizer-dec.cpp
|
||||
models/xverse.cpp
|
||||
models/mistral3.cpp
|
||||
models/graph-context-mamba.cpp
|
||||
)
|
||||
|
||||
set_target_properties(llama PROPERTIES
|
||||
|
||||
@@ -39,6 +39,8 @@ private:
|
||||
std::vector<ggml_tensor *> tensors; // per layer
|
||||
};
|
||||
|
||||
using llama_adapter_cvec_ptr = std::shared_ptr<llama_adapter_cvec>;
|
||||
|
||||
//
|
||||
// llama_adapter_lora
|
||||
//
|
||||
@@ -84,3 +86,4 @@ struct llama_adapter_lora {
|
||||
};
|
||||
|
||||
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
|
||||
using llama_adapter_loras_ptr = std::unique_ptr<llama_adapter_loras>;
|
||||
|
||||
@@ -22,6 +22,8 @@ llama_context::llama_context(
|
||||
const llama_model & model,
|
||||
llama_context_params params) :
|
||||
model(model),
|
||||
cvec(std::make_unique<llama_adapter_cvec>()),
|
||||
loras(std::make_unique<llama_adapter_loras>()),
|
||||
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
||||
// TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
|
||||
// may need to be backend-dependent
|
||||
@@ -878,6 +880,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
|
||||
}
|
||||
} catch (const std::exception & err) {
|
||||
// fallback to full vocab list
|
||||
GGML_UNUSED(err);
|
||||
}
|
||||
|
||||
return sampling.token_ids_full_vocab.data();
|
||||
@@ -1057,51 +1060,43 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_context::set_adapter_lora(
|
||||
llama_adapter_lora * adapter,
|
||||
float scale) {
|
||||
LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
|
||||
void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
|
||||
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
|
||||
|
||||
if (auto it = loras.find(adapter); it != loras.end()) {
|
||||
if (it->second == scale) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
loras[adapter] = scale;
|
||||
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
bool llama_context::rm_adapter_lora(
|
||||
llama_adapter_lora * adapter) {
|
||||
LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter);
|
||||
|
||||
auto it = loras.find(adapter);
|
||||
if (it != loras.end()) {
|
||||
loras.erase(it);
|
||||
|
||||
sched_need_reserve = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void llama_context::clear_adapter_lora() {
|
||||
LLAMA_LOG_DEBUG("%s: call\n", __func__);
|
||||
|
||||
if (loras.empty()) {
|
||||
if (adapters_lora_are_same(adapters, n_adapters, scales)) {
|
||||
return;
|
||||
}
|
||||
|
||||
loras.clear();
|
||||
loras.reset(new llama_adapter_loras());
|
||||
|
||||
for (size_t i = 0; i < n_adapters; i ++) {
|
||||
if (scales[i] != 0.0f) {
|
||||
loras->insert({adapters[i], scales[i]});
|
||||
}
|
||||
}
|
||||
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
bool llama_context::apply_adapter_cvec(
|
||||
bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
|
||||
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
|
||||
|
||||
if (n_adapters != loras->size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < n_adapters; i ++) {
|
||||
auto it = loras->find(adapters[i]);
|
||||
|
||||
if (it == loras->end() || it->second != scales[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_context::set_adapter_cvec(
|
||||
const float * data,
|
||||
size_t len,
|
||||
int32_t n_embd,
|
||||
@@ -1111,7 +1106,7 @@ bool llama_context::apply_adapter_cvec(
|
||||
|
||||
// TODO: should we reserve?
|
||||
|
||||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||
return cvec->apply(model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||
@@ -1817,7 +1812,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
//
|
||||
|
||||
uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & vocab = model.vocab;
|
||||
|
||||
@@ -1901,11 +1895,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
|
||||
offset += embd.size * sizeof(float);
|
||||
|
||||
sampling.logits = {nullptr, 0};
|
||||
sampling.probs = {nullptr, 0};
|
||||
sampling.sampled = {nullptr, 0};
|
||||
sampling.candidates = {nullptr, 0};
|
||||
|
||||
if (has_sampling) {
|
||||
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
|
||||
offset += sampling.logits.size * sizeof(float);
|
||||
@@ -1931,6 +1920,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
|
||||
|
||||
std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
|
||||
} else {
|
||||
sampling.logits = {nullptr, 0};
|
||||
sampling.probs = {nullptr, 0};
|
||||
sampling.sampled = {nullptr, 0};
|
||||
sampling.candidates = {nullptr, 0};
|
||||
|
||||
sampling.logits_count.clear();
|
||||
sampling.probs_count.clear();
|
||||
sampling.candidates_count.clear();
|
||||
}
|
||||
|
||||
// set all ids as invalid (negative)
|
||||
@@ -1961,37 +1959,30 @@ void llama_context::output_reorder() {
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.logits.has_data()) {
|
||||
if (!sampling.samplers.empty()) {
|
||||
assert(sampling.logits.size > 0);
|
||||
assert(sampling.probs.size > 0);
|
||||
assert(sampling.candidates.size > 0);
|
||||
assert(sampling.sampled.size > 0);
|
||||
assert(sampling.logits_count.size() > 0);
|
||||
assert(sampling.probs_count.size() > 0);
|
||||
assert(sampling.candidates_count.size() > 0);
|
||||
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.probs.has_data()) {
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.candidates.has_data()) {
|
||||
for (uint64_t k = 0; k < n_vocab; ++k) {
|
||||
std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
|
||||
if (sampling.sampled.has_data()) {
|
||||
std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.logits_count.empty()) {
|
||||
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.probs_count.empty()) {
|
||||
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
||||
}
|
||||
|
||||
if (!sampling.candidates_count.empty()) {
|
||||
std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
|
||||
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
|
||||
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
|
||||
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
|
||||
}
|
||||
}
|
||||
@@ -2092,8 +2083,8 @@ llm_graph_params llama_context::graph_params(
|
||||
/*.gtype =*/ gtype,
|
||||
/*.sched =*/ sched.get(),
|
||||
/*.backend_cpu =*/ backend_cpu,
|
||||
/*.cvec =*/ &cvec,
|
||||
/*.loras =*/ &loras,
|
||||
/*.cvec =*/ cvec.get(),
|
||||
/*.loras =*/ loras.get(),
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.samplers =*/ sampling.samplers,
|
||||
@@ -3209,35 +3200,28 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
|
||||
|
||||
// llama adapter API
|
||||
|
||||
int32_t llama_set_adapter_lora(
|
||||
int32_t llama_set_adapters_lora(
|
||||
llama_context * ctx,
|
||||
llama_adapter_lora * adapter,
|
||||
float scale) {
|
||||
ctx->set_adapter_lora(adapter, scale);
|
||||
llama_adapter_lora ** adapters,
|
||||
size_t n_adapters,
|
||||
float * scales) {
|
||||
if (adapters == nullptr || scales == nullptr) {
|
||||
GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call");
|
||||
}
|
||||
|
||||
ctx->set_adapters_lora(adapters, n_adapters, scales);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t llama_rm_adapter_lora(
|
||||
llama_context * ctx,
|
||||
llama_adapter_lora * adapter) {
|
||||
bool res = ctx->rm_adapter_lora(adapter);
|
||||
|
||||
return res ? 0 : -1;
|
||||
}
|
||||
|
||||
void llama_clear_adapter_lora(llama_context * ctx) {
|
||||
ctx->clear_adapter_lora();
|
||||
}
|
||||
|
||||
int32_t llama_apply_adapter_cvec(
|
||||
int32_t llama_set_adapter_cvec(
|
||||
llama_context * ctx,
|
||||
const float * data,
|
||||
size_t len,
|
||||
int32_t n_embd,
|
||||
int32_t il_start,
|
||||
int32_t il_end) {
|
||||
bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end);
|
||||
const float * data,
|
||||
size_t len,
|
||||
int32_t n_embd,
|
||||
int32_t il_start,
|
||||
int32_t il_end) {
|
||||
bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end);
|
||||
|
||||
return res ? 0 : -1;
|
||||
}
|
||||
|
||||
@@ -105,16 +105,11 @@ struct llama_context {
|
||||
void set_causal_attn(bool value);
|
||||
void set_warmup(bool value);
|
||||
|
||||
void set_adapter_lora(
|
||||
llama_adapter_lora * adapter,
|
||||
float scale);
|
||||
void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);
|
||||
|
||||
bool rm_adapter_lora(
|
||||
llama_adapter_lora * adapter);
|
||||
bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales);
|
||||
|
||||
void clear_adapter_lora();
|
||||
|
||||
bool apply_adapter_cvec(
|
||||
bool set_adapter_cvec(
|
||||
const float * data,
|
||||
size_t len,
|
||||
int32_t n_embd,
|
||||
@@ -261,33 +256,36 @@ private:
|
||||
|
||||
const llama_model & model;
|
||||
|
||||
llama_cparams cparams;
|
||||
llama_adapter_cvec cvec;
|
||||
llama_adapter_loras loras;
|
||||
llama_cparams cparams;
|
||||
|
||||
llama_adapter_cvec_ptr cvec;
|
||||
llama_adapter_loras_ptr loras;
|
||||
|
||||
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
||||
|
||||
std::unique_ptr<llama_memory_i> memory;
|
||||
|
||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||
struct buffer_view<float> logits = {nullptr, 0};
|
||||
buffer_view<float> logits = {nullptr, 0};
|
||||
|
||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
struct buffer_view<float> embd = {nullptr, 0};
|
||||
buffer_view<float> embd = {nullptr, 0};
|
||||
|
||||
struct sampling_info {
|
||||
// !samplers.empty() to check if any samplers are active
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
struct buffer_view<float> logits = {nullptr, 0};
|
||||
struct buffer_view<llama_token> sampled = {nullptr, 0};
|
||||
struct buffer_view<float> probs = {nullptr, 0};
|
||||
struct buffer_view<llama_token> candidates = {nullptr, 0};
|
||||
buffer_view<float> logits = {nullptr, 0};
|
||||
buffer_view<llama_token> sampled = {nullptr, 0};
|
||||
buffer_view<float> probs = {nullptr, 0};
|
||||
buffer_view<llama_token> candidates = {nullptr, 0};
|
||||
|
||||
std::vector<uint32_t> logits_count;
|
||||
std::vector<uint32_t> probs_count;
|
||||
std::vector<uint32_t> candidates_count;
|
||||
|
||||
// optimization
|
||||
std::vector<llama_token> token_ids_full_vocab;
|
||||
};
|
||||
|
||||
|
||||
@@ -17,6 +17,41 @@
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
// dedup helpers
|
||||
|
||||
static ggml_tensor * build_kq_mask(
|
||||
ggml_context * ctx,
|
||||
const llama_kv_cache_context * mctx,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_cparams & cparams) {
|
||||
const auto n_kv = mctx->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
}
|
||||
|
||||
static bool can_reuse_kq_mask(
|
||||
ggml_tensor * kq_mask,
|
||||
const llama_kv_cache_context * mctx,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_cparams & cparams) {
|
||||
const auto n_kv = mctx->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= (kq_mask->ne[0] == n_kv);
|
||||
res &= (kq_mask->ne[1] == n_tokens/n_stream);
|
||||
res &= (kq_mask->ne[2] == 1);
|
||||
res &= (kq_mask->ne[3] == n_stream);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// impl
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
@@ -403,8 +438,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
||||
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -424,8 +458,7 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
|
||||
|
||||
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -455,11 +488,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
|
||||
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
||||
res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -521,8 +551,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
@@ -565,8 +594,7 @@ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
|
||||
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
@@ -625,8 +653,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
@@ -634,8 +661,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
|
||||
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
@@ -1891,14 +1917,11 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -1983,13 +2006,9 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
@@ -2188,15 +2207,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||
|
||||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
||||
|
||||
@@ -2207,12 +2222,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
||||
|
||||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
||||
|
||||
@@ -2374,27 +2387,21 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
||||
|
||||
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = attn_ctx->get_base()->get_n_kv();
|
||||
|
||||
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
||||
ggml_set_input(inp_attn->self_kq_mask);
|
||||
|
||||
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
const auto n_kv = attn_ctx->get_swa()->get_n_kv();
|
||||
|
||||
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
||||
ggml_set_input(inp_attn->self_kq_mask_swa);
|
||||
|
||||
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
||||
|
||||
@@ -504,6 +504,8 @@ struct llama_mmap::impl {
|
||||
}
|
||||
}
|
||||
#elif defined(_WIN32)
|
||||
HANDLE hMapping = nullptr;
|
||||
|
||||
impl(struct llama_file * file, size_t prefetch, bool numa) {
|
||||
GGML_UNUSED(numa);
|
||||
|
||||
@@ -511,7 +513,7 @@ struct llama_mmap::impl {
|
||||
|
||||
HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id());
|
||||
|
||||
HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
|
||||
|
||||
if (hMapping == NULL) {
|
||||
DWORD error = GetLastError();
|
||||
@@ -520,9 +522,9 @@ struct llama_mmap::impl {
|
||||
|
||||
addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
|
||||
DWORD error = GetLastError();
|
||||
CloseHandle(hMapping);
|
||||
|
||||
if (addr == NULL) {
|
||||
CloseHandle(hMapping);
|
||||
throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
|
||||
}
|
||||
|
||||
@@ -554,9 +556,17 @@ struct llama_mmap::impl {
|
||||
}
|
||||
|
||||
~impl() {
|
||||
if (!UnmapViewOfFile(addr)) {
|
||||
LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
if (hMapping) {
|
||||
if (addr) {
|
||||
if (!UnmapViewOfFile(addr)) {
|
||||
LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
if (!CloseHandle(hMapping)) {
|
||||
LLAMA_LOG_WARN("warning: CloseHandle failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
||||
@@ -308,6 +308,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM:
|
||||
case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE:
|
||||
case LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM:
|
||||
regex_exprs = {
|
||||
"\\p{N}{1,3}",
|
||||
"[一-龥-ゟ゠-ヿ]+",
|
||||
@@ -422,6 +423,14 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_TINY_AYA:
|
||||
regex_exprs = {
|
||||
// original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)"
|
||||
"\\d{1,3}(?=(?:\\d{3})*\\b)",
|
||||
// original regex from tokenizer.json: "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_KIMI_K2:
|
||||
regex_exprs = {
|
||||
// K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp
|
||||
@@ -2005,10 +2014,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "megrez") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
||||
} else if (
|
||||
tokenizer_pre == "gpt-4o" ||
|
||||
tokenizer_pre == "llama4") {
|
||||
tokenizer_pre == "gpt-4o" ||
|
||||
tokenizer_pre == "llama4") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "tiny_aya") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "superbpe") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE;
|
||||
@@ -2039,6 +2052,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
tokenizer_pre == "hunyuan-dense") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "joyai-llm") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM;
|
||||
clean_spaces = false;
|
||||
} else if (
|
||||
tokenizer_pre == "kimi-k2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
|
||||
|
||||
@@ -55,6 +55,8 @@ enum llama_vocab_pre_type {
|
||||
LLAMA_VOCAB_PRE_TYPE_YOUTU = 44,
|
||||
LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45,
|
||||
LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46,
|
||||
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
|
||||
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
|
||||
};
|
||||
|
||||
struct LLM_KV;
|
||||
|
||||
333
src/models/delta-net-base.cpp
Normal file
333
src/models/delta-net-base.cpp
Normal file
@@ -0,0 +1,333 @@
|
||||
#include "models.h"
|
||||
|
||||
#define CHUNK_SIZE 64
|
||||
|
||||
// utility to get one slice from the third dimension
|
||||
// input dim: [x, y, c, b]
|
||||
// output dim: [x, y, 1, b]
|
||||
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
|
||||
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
|
||||
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
|
||||
}
|
||||
|
||||
llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il) {
|
||||
const int64_t S_k = q->ne[0];
|
||||
const int64_t H_k = q->ne[1];
|
||||
const int64_t n_tokens = q->ne[2];
|
||||
const int64_t n_seqs = q->ne[3];
|
||||
|
||||
const int64_t S_v = v->ne[0];
|
||||
const int64_t H_v = v->ne[1];
|
||||
|
||||
GGML_ASSERT(S_k == S_v);
|
||||
GGML_ASSERT(H_v % H_k == 0);
|
||||
|
||||
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
|
||||
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
|
||||
GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
|
||||
|
||||
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
|
||||
GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
||||
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
||||
|
||||
const float scale = 1.0f / sqrtf(S_k);
|
||||
|
||||
q = ggml_scale(ctx0, q, scale);
|
||||
|
||||
cb(q, "q_in", il);
|
||||
cb(k, "k_in", il);
|
||||
cb(v, "v_in", il);
|
||||
cb(b, "b_in", il);
|
||||
cb(g, "g_in", il);
|
||||
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
|
||||
g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs]
|
||||
b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs]
|
||||
|
||||
const int CS = CHUNK_SIZE;
|
||||
|
||||
const int pad = (CS - n_tokens % CS) % CS;
|
||||
const int n_chunks = (n_tokens + pad) / CS;
|
||||
|
||||
q = ggml_pad(ctx0, q, 0, pad, 0, 0);
|
||||
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
|
||||
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
|
||||
g = ggml_pad(ctx0, g, 0, pad, 0, 0);
|
||||
b = ggml_pad(ctx0, b, 0, pad, 0, 0);
|
||||
|
||||
ggml_tensor * v_b = ggml_mul(ctx0, v, b);
|
||||
ggml_tensor * k_b = ggml_mul(ctx0, k, b);
|
||||
|
||||
cb(v_b, "v_b", il);
|
||||
cb(k_b, "k_b", il);
|
||||
|
||||
q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs);
|
||||
k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs);
|
||||
k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
|
||||
v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs);
|
||||
v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
|
||||
|
||||
g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
|
||||
b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
|
||||
|
||||
// [CS, 1, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
|
||||
cb(g_cs, "g_cs", il);
|
||||
|
||||
ggml_tensor * g_cs_i = g_cs;
|
||||
ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
|
||||
|
||||
g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
|
||||
|
||||
// [CS, CS, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * decay_mask;
|
||||
decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
|
||||
decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
|
||||
decay_mask = ggml_exp(ctx0, decay_mask);
|
||||
cb(decay_mask, "decay_mask", il);
|
||||
|
||||
// [CS, CS, n_chunks, H_k * n_seqs]
|
||||
ggml_tensor * kb;
|
||||
kb = ggml_mul_mat(ctx0, k, k_b);
|
||||
kb = ggml_mul (ctx0, kb, decay_mask);
|
||||
|
||||
// [CS, CS, n_chunks, H_k * n_seqs]
|
||||
ggml_tensor * attn;
|
||||
attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
|
||||
|
||||
ggml_tensor * identity;
|
||||
identity = ggml_view_1d(ctx0, attn, CS, 0);
|
||||
identity = ggml_fill (ctx0, identity, 1.0f);
|
||||
identity = ggml_diag (ctx0, identity);
|
||||
|
||||
ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
|
||||
cb(lhs, "dnet_add_ch_lhs", il);
|
||||
|
||||
attn = ggml_neg(ctx0, attn);
|
||||
|
||||
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
|
||||
attn = ggml_add(ctx0, lin_solve, identity);
|
||||
cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
|
||||
|
||||
// [S_v, CS, n_chunks, H_v * n_seqs]
|
||||
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
|
||||
|
||||
// [CS, 1, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
|
||||
|
||||
k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
|
||||
|
||||
// [CS, S_k, n_chunks, H_k * n_seqs]
|
||||
ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
|
||||
cb(kbg, "k_beta_g_exp", il);
|
||||
|
||||
// [S_k, CS, n_chunks, H_k * n_seqs]
|
||||
ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
|
||||
cb(k_cd, "k_cumdecay", il);
|
||||
|
||||
// [S_k, CS, n_chunks, H_k * n_seqs]
|
||||
ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
|
||||
ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
|
||||
|
||||
// [CS, CS, n_chunks, H_k * n_seqs]
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
kq = ggml_mul(ctx0, kq, decay_mask);
|
||||
kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
// vectorized calculation of key_gdiff
|
||||
// improved from the chunked version:
|
||||
// g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
|
||||
// g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
|
||||
// key_gdiff = key * g_diff.unsqueeze(-1)
|
||||
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
|
||||
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
|
||||
|
||||
// get last element in g_cumsum along CS dimension (ne0)
|
||||
// example: [[x, y, z, ..., last], ...] -> [[last], ...]
|
||||
// [1, 1, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
|
||||
g_cs->nb[1],
|
||||
g_cs->nb[2],
|
||||
g_cs->nb[3],
|
||||
ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
|
||||
cb(g_last, "g_last", il);
|
||||
|
||||
// TODO: remove this cont when CUDA supports non-cont unary ops
|
||||
g_last = ggml_cont(ctx0, g_last);
|
||||
|
||||
// [1, 1, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
|
||||
cb(g_last_exp, "g_last_exp", il);
|
||||
|
||||
// [CS, 1, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
|
||||
cb(g_diff, "g_diff", il);
|
||||
|
||||
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
|
||||
ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
|
||||
|
||||
// [S_k, CS, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
|
||||
cb(kg, "key_gdiff", il);
|
||||
|
||||
// [CS, S_k, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
|
||||
cb(kg_t, "key_gdiff_t", il);
|
||||
|
||||
ggml_tensor * s_t = ggml_transpose(ctx0, s);
|
||||
s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
|
||||
cb(s_t, "dnet_add_ch_state", il);
|
||||
|
||||
// [CS, S_v, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
|
||||
|
||||
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
|
||||
ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs]
|
||||
ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs]
|
||||
ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs]
|
||||
ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs]
|
||||
ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs]
|
||||
|
||||
// [CS, S_v, 1, H_v * n_seqs]
|
||||
ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
|
||||
cb(v_t_p, "v_prime", il);
|
||||
|
||||
// [CS, S_v, 1, H_v * n_seqs]
|
||||
ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
|
||||
cb(v_t_new, "v_t_new", il);
|
||||
|
||||
// [S_v, CS, 1, H_v * n_seqs]
|
||||
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
|
||||
cb(v_attn, "v_attn", il);
|
||||
|
||||
// [S_v, CS, 1, H_v * n_seqs]
|
||||
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
|
||||
cb(attn_inter, "attn_inter", il);
|
||||
|
||||
// [S_v, CS, 1, H_v * n_seqs]
|
||||
ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
|
||||
cb(o_ch, "dnet_add_ch_attn_out", il);
|
||||
|
||||
v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
|
||||
|
||||
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
|
||||
// TODO: head broadcast might not work here - probably will need a transpose
|
||||
ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
|
||||
|
||||
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
|
||||
ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
|
||||
s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
|
||||
s_t = ggml_add(ctx0, s_t, kgv);
|
||||
cb(s_t, "dnet_add_ch_state", il);
|
||||
}
|
||||
|
||||
s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
|
||||
|
||||
// truncate padded tokens
|
||||
ggml_tensor * o = ggml_view_4d(ctx0, v,
|
||||
S_v, n_tokens, H_v, n_seqs,
|
||||
ggml_row_size(v->type, S_v),
|
||||
ggml_row_size(v->type, S_v * CS * n_chunks),
|
||||
ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
|
||||
|
||||
o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
|
||||
s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
|
||||
|
||||
return {o, s};
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b, // beta
|
||||
ggml_tensor * s, // state
|
||||
int il) {
|
||||
const int64_t S_k = q->ne[0];
|
||||
const int64_t H_k = q->ne[1];
|
||||
const int64_t n_tokens = q->ne[2];
|
||||
const int64_t n_seqs = q->ne[3];
|
||||
|
||||
const int64_t S_v = v->ne[0];
|
||||
const int64_t H_v = v->ne[1];
|
||||
|
||||
GGML_ASSERT(n_tokens == 1);
|
||||
|
||||
GGML_ASSERT(S_k == S_v);
|
||||
GGML_ASSERT(H_v % H_k == 0);
|
||||
|
||||
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
|
||||
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
|
||||
GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
|
||||
|
||||
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
|
||||
GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
|
||||
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
|
||||
|
||||
const float scale = 1.0f / sqrtf(S_k);
|
||||
|
||||
q = ggml_scale(ctx0, q, scale);
|
||||
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
|
||||
|
||||
cb(q, "q_in", il);
|
||||
cb(k, "k_in", il);
|
||||
cb(v, "v_in", il);
|
||||
cb(b, "b_in", il);
|
||||
cb(g, "g_in", il);
|
||||
|
||||
g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
|
||||
b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
|
||||
|
||||
// [S_v, S_v, H_v, n_seqs]
|
||||
g = ggml_exp(ctx0, g);
|
||||
s = ggml_mul(ctx0, s, g);
|
||||
|
||||
ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
|
||||
|
||||
// [1, S_v, H_v, n_seqs]
|
||||
ggml_tensor * sk;
|
||||
sk = ggml_mul (ctx0, s_t, k);
|
||||
sk = ggml_sum_rows(ctx0, sk);
|
||||
|
||||
// [S_v, 1, H_v, n_seqs]
|
||||
ggml_tensor * d;
|
||||
d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
|
||||
d = ggml_mul(ctx0, d, b);
|
||||
|
||||
// [1, S_v, H_v, n_seqs]
|
||||
ggml_tensor * d_t;
|
||||
d_t = ggml_transpose(ctx0, d);
|
||||
|
||||
// [S_v, S_v, H_v, n_seqs]
|
||||
ggml_tensor * kd;
|
||||
k = ggml_repeat(ctx0, k, s);
|
||||
kd = ggml_mul (ctx0, k, d_t);
|
||||
|
||||
s_t = ggml_add(ctx0, s_t, kd);
|
||||
|
||||
cb(s_t, "dnet_add_ar_state", il);
|
||||
|
||||
ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
|
||||
ggml_tensor * o = ggml_sum_rows(ctx0, s_q);
|
||||
|
||||
o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
|
||||
s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
|
||||
|
||||
return {o, s};
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params) {
|
||||
llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
|
||||
llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params) {
|
||||
llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
|
||||
llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#include "models.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#define CHUNK_SIZE 64
|
||||
|
||||
// Causal Conv1d function for Q,K,V
|
||||
@@ -65,7 +67,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t
|
||||
}
|
||||
|
||||
llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params), model(model) {
|
||||
llm_build_mamba_base(params), model(model) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#include "models.h"
|
||||
|
||||
llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
|
||||
llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {}
|
||||
|
||||
ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp,
|
||||
ggml_tensor * cur,
|
||||
const llama_model & model,
|
||||
const llama_ubatch & ubatch,
|
||||
@@ -143,7 +145,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp,
|
||||
ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
|
||||
ggml_tensor * cur,
|
||||
const llama_model & model,
|
||||
const llama_ubatch & ubatch,
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
|
||||
llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
|
||||
@@ -1,23 +1,51 @@
|
||||
#pragma once
|
||||
|
||||
#include "../llama-model.h"
|
||||
#include "../llama-graph.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-graph.h"
|
||||
|
||||
// TODO: remove in follow-up PR - move to .cpp files
|
||||
#include "../llama-memory-recurrent.h"
|
||||
// note: almost all graphs require atleast sqrtf, so include cmath globally
|
||||
#include <cmath>
|
||||
|
||||
struct llm_graph_context_mamba : public llm_graph_context {
|
||||
llm_graph_context_mamba(const llm_graph_params & params);
|
||||
//
|
||||
// base classes
|
||||
//
|
||||
|
||||
virtual ~llm_graph_context_mamba() = default;
|
||||
struct llm_build_mamba_base : public llm_graph_context {
|
||||
llm_build_mamba_base(const llm_graph_params & params);
|
||||
|
||||
virtual ~llm_build_mamba_base() = default;
|
||||
|
||||
ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
|
||||
ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const;
|
||||
|
||||
};
|
||||
|
||||
// Base class for RWKV-related models
|
||||
struct llm_build_delta_net_base : public llm_graph_context {
|
||||
llm_build_delta_net_base(const llm_graph_params & params);
|
||||
|
||||
virtual ~llm_build_delta_net_base() = default;
|
||||
|
||||
// returns pair of output and new state
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
|
||||
// returns pair of output and new state
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * b,
|
||||
ggml_tensor * s,
|
||||
int il);
|
||||
};
|
||||
|
||||
struct llm_build_rwkv6_base : public llm_graph_context {
|
||||
const llama_model & model;
|
||||
|
||||
@@ -58,6 +86,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
||||
int il) const;
|
||||
};
|
||||
|
||||
//
|
||||
// models
|
||||
//
|
||||
|
||||
struct llm_build_afmoe : public llm_graph_context {
|
||||
llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
@@ -175,7 +207,7 @@ struct llm_build_falcon : public llm_graph_context {
|
||||
llm_build_falcon(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_falcon_h1 : public llm_graph_context_mamba {
|
||||
struct llm_build_falcon_h1 : public llm_build_mamba_base {
|
||||
llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
@@ -253,7 +285,7 @@ private:
|
||||
const int il);
|
||||
};
|
||||
|
||||
struct llm_build_granite_hybrid : public llm_graph_context_mamba {
|
||||
struct llm_build_granite_hybrid : public llm_build_mamba_base {
|
||||
llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il);
|
||||
ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn,
|
||||
@@ -284,11 +316,12 @@ struct llm_build_jais : public llm_graph_context {
|
||||
llm_build_jais(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_jamba : public llm_graph_context_mamba {
|
||||
struct llm_build_jamba : public llm_build_mamba_base {
|
||||
llm_build_jamba(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_kimi_linear : public llm_graph_context_mamba {
|
||||
// TODO: derive llm_build_delta_net_base instead
|
||||
struct llm_build_kimi_linear : public llm_build_mamba_base {
|
||||
llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
|
||||
@@ -347,7 +380,7 @@ struct llm_build_maincoder : public llm_graph_context {
|
||||
llm_build_maincoder(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_mamba : public llm_graph_context_mamba {
|
||||
struct llm_build_mamba : public llm_build_mamba_base {
|
||||
llm_build_mamba(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
@@ -379,11 +412,11 @@ struct llm_build_nemotron : public llm_graph_context {
|
||||
llm_build_nemotron(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_nemotron_h : public llm_graph_context_mamba {
|
||||
struct llm_build_nemotron_h : public llm_build_mamba_base {
|
||||
llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params);
|
||||
ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il);
|
||||
ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il);
|
||||
ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn,
|
||||
const llama_model & model, const int64_t n_embd_head, const int il);
|
||||
const llama_model & model, int64_t n_embd_head, int il);
|
||||
};
|
||||
|
||||
struct llm_build_neo_bert : public llm_graph_context {
|
||||
@@ -428,7 +461,7 @@ struct llm_build_phi3 : public llm_graph_context {
|
||||
llm_build_phi3(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_plamo2 : public llm_graph_context_mamba {
|
||||
struct llm_build_plamo2 : public llm_build_mamba_base {
|
||||
llm_build_plamo2(const llama_model & model, const llm_graph_params & params);
|
||||
private:
|
||||
ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
|
||||
@@ -477,7 +510,7 @@ struct llm_build_qwen3vlmoe : public llm_graph_context {
|
||||
llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_qwen3next : public llm_graph_context_mamba {
|
||||
struct llm_build_qwen3next : public llm_build_delta_net_base {
|
||||
llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
|
||||
private:
|
||||
ggml_tensor * build_layer_attn(
|
||||
@@ -489,38 +522,12 @@ private:
|
||||
ggml_tensor * build_layer_attn_linear(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
ggml_tensor * diag_mask,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_layer_ffn(
|
||||
ggml_tensor * cur,
|
||||
int il);
|
||||
|
||||
// returns pair of output and new state
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * beta,
|
||||
ggml_tensor * state,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
ggml_tensor * diag_mask,
|
||||
int il);
|
||||
|
||||
// returns pair of output and new state
|
||||
std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * beta,
|
||||
ggml_tensor * state,
|
||||
int il);
|
||||
|
||||
ggml_tensor * build_norm_gated(
|
||||
ggml_tensor * input,
|
||||
ggml_tensor * weights,
|
||||
@@ -535,7 +542,8 @@ private:
|
||||
const llama_model & model;
|
||||
};
|
||||
|
||||
struct llm_build_qwen35 : public llm_graph_context_mamba {
|
||||
// TODO: derive llm_build_delta_net_base instead
|
||||
struct llm_build_qwen35 : public llm_graph_context {
|
||||
llm_build_qwen35(const llama_model & model, const llm_graph_params & params);
|
||||
private:
|
||||
ggml_tensor * build_layer_attn(
|
||||
@@ -553,6 +561,7 @@ private:
|
||||
ggml_tensor * diag_mask,
|
||||
int il);
|
||||
|
||||
|
||||
ggml_tensor * build_layer_ffn(
|
||||
ggml_tensor * cur,
|
||||
int il);
|
||||
@@ -594,7 +603,8 @@ private:
|
||||
const llama_model & model;
|
||||
};
|
||||
|
||||
struct llm_build_qwen35moe : public llm_graph_context_mamba {
|
||||
// TODO: derive llm_build_delta_net_base instead
|
||||
struct llm_build_qwen35moe : public llm_graph_context {
|
||||
llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params);
|
||||
private:
|
||||
ggml_tensor * build_layer_attn(
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
|
||||
|
||||
llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params) {
|
||||
llm_build_mamba_base(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
@@ -65,8 +63,8 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_
|
||||
ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur,
|
||||
llm_graph_input_attn_kv * inp_attn,
|
||||
const llama_model & model,
|
||||
const int64_t n_embd_head,
|
||||
const int il) {
|
||||
int64_t n_embd_head,
|
||||
int il) {
|
||||
// compute Q and K
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
@@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) {
|
||||
ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) {
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
#include "models.h"
|
||||
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params) {
|
||||
llm_build_mamba_base(params) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#include "ggml.h"
|
||||
#include "models.h"
|
||||
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#define CHUNK_SIZE 64
|
||||
|
||||
llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params), model(model) {
|
||||
llm_graph_context(params), model(model) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#include "ggml.h"
|
||||
#include "models.h"
|
||||
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#define CHUNK_SIZE 64
|
||||
|
||||
llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params), model(model) {
|
||||
llm_graph_context(params), model(model) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#include "ggml.h"
|
||||
#include "models.h"
|
||||
|
||||
#define CHUNK_SIZE 64
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params), model(model) {
|
||||
llm_build_delta_net_base(params), model(model) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -16,17 +15,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
ggml_tensor * causal_mask =
|
||||
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
|
||||
GGML_TRI_TYPE_LOWER);
|
||||
|
||||
ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
|
||||
ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
|
||||
|
||||
ggml_build_forward_expand(gf, causal_mask);
|
||||
ggml_build_forward_expand(gf, identity);
|
||||
ggml_build_forward_expand(gf, diag_mask);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
@@ -36,7 +24,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
|
||||
// Determine layer type and build appropriate attention mechanism
|
||||
if (hparams.is_recurrent(il)) {
|
||||
// Linear attention layer (gated delta net)
|
||||
cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
|
||||
cur = build_layer_attn_linear(inp->get_recr(), cur, il);
|
||||
} else {
|
||||
// Full attention layer
|
||||
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
|
||||
@@ -94,354 +82,6 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t
|
||||
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * beta,
|
||||
ggml_tensor * state,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
ggml_tensor * diag_mask,
|
||||
int il) {
|
||||
const int64_t S_k = q->ne[0];
|
||||
const int64_t H_k = q->ne[1];
|
||||
const int64_t n_tokens = q->ne[2];
|
||||
const int64_t n_seqs = q->ne[3];
|
||||
|
||||
const int64_t S_v = v->ne[0];
|
||||
const int64_t H_v = v->ne[1];
|
||||
|
||||
GGML_ASSERT(v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(k->ne[2] == n_tokens);
|
||||
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
|
||||
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
|
||||
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
|
||||
|
||||
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
|
||||
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
|
||||
|
||||
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
|
||||
|
||||
const float eps_norm = hparams.f_norm_rms_eps;
|
||||
|
||||
q = ggml_l2_norm(ctx0, q, eps_norm);
|
||||
k = ggml_l2_norm(ctx0, k, eps_norm);
|
||||
|
||||
const float scale = 1.0f / sqrtf(S_v);
|
||||
|
||||
q = ggml_scale(ctx0, q, scale);
|
||||
|
||||
beta = ggml_sigmoid(ctx0, beta);
|
||||
|
||||
cb(q, "q_in", il);
|
||||
cb(k, "k_in", il);
|
||||
cb(v, "v_in", il);
|
||||
cb(beta, "beta_in", il);
|
||||
cb(g, "g_in", il);
|
||||
|
||||
q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
|
||||
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
|
||||
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
|
||||
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
|
||||
|
||||
beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
|
||||
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
|
||||
|
||||
cb(q, "q_perm", il);
|
||||
cb(k, "k_perm", il);
|
||||
cb(v, "v_perm", il);
|
||||
cb(beta, "beta_perm", il);
|
||||
cb(g, "g_perm", il);
|
||||
cb(state, "state_in", il);
|
||||
|
||||
GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
|
||||
GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
|
||||
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
|
||||
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
|
||||
|
||||
// Do padding
|
||||
const int64_t chunk_size = CHUNK_SIZE;
|
||||
|
||||
const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
|
||||
const int64_t n_chunks = (n_tokens + pad) / chunk_size;
|
||||
|
||||
q = ggml_pad(ctx0, q, 0, pad, 0, 0);
|
||||
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
|
||||
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
|
||||
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
|
||||
beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
|
||||
|
||||
cb(q, "q_pad", il);
|
||||
cb(k, "k_pad", il);
|
||||
cb(v, "v_pad", il);
|
||||
cb(beta, "beta_pad", il);
|
||||
cb(g, "g_pad", il);
|
||||
|
||||
ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
|
||||
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
|
||||
|
||||
cb(v_beta, "v_beta", il);
|
||||
cb(k_beta, "k_beta", il);
|
||||
|
||||
q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
|
||||
k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
|
||||
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
|
||||
v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
|
||||
v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
|
||||
|
||||
g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
|
||||
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
|
||||
|
||||
ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
|
||||
cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
|
||||
|
||||
ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
|
||||
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
|
||||
|
||||
ggml_tensor * gcs_j_broadcast =
|
||||
ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
|
||||
|
||||
ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
|
||||
cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
|
||||
|
||||
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
|
||||
decay_mask = ggml_exp(ctx0, decay_mask);
|
||||
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
|
||||
|
||||
ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
|
||||
|
||||
ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
|
||||
ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
|
||||
cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
|
||||
|
||||
ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
|
||||
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
|
||||
|
||||
ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
|
||||
attn = ggml_mul(ctx0, lin_solve, causal_mask);
|
||||
attn = ggml_add(ctx0, attn, identity);
|
||||
cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
|
||||
|
||||
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
|
||||
|
||||
ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
|
||||
ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
|
||||
|
||||
ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
|
||||
cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
|
||||
|
||||
ggml_tensor * k_cumdecay =
|
||||
ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
|
||||
cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
|
||||
|
||||
ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
|
||||
attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
|
||||
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
|
||||
cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
|
||||
|
||||
|
||||
// vectorized calculation of key_gdiff
|
||||
// improved from the chunked version:
|
||||
// g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
|
||||
// g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
|
||||
// key_gdiff = key * g_diff.unsqueeze(-1)
|
||||
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
|
||||
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
|
||||
|
||||
// get last element in g_cumsum along chunk_size dimension (ne0)
|
||||
// example: [[x, y, z, ..., last], ...] -> [[last], ...]
|
||||
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
|
||||
g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
|
||||
(g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
|
||||
g_last = ggml_cont(ctx0, g_last);
|
||||
cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
|
||||
|
||||
ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
|
||||
cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
|
||||
|
||||
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
|
||||
cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
|
||||
|
||||
ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
|
||||
ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
|
||||
1, chunk_size, n_chunks, g_diff_exp->ne[3]);
|
||||
|
||||
ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
|
||||
cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
|
||||
|
||||
ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
|
||||
cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
|
||||
|
||||
|
||||
// state to be updated per chunk
|
||||
ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
|
||||
cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
|
||||
|
||||
// shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
|
||||
ggml_tensor * core_attn_out = nullptr;
|
||||
|
||||
for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
|
||||
// shape: (S_k, chunk_size, 1, H_k * n_seqs)
|
||||
ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
|
||||
|
||||
// shape: (S_v, chunk_size, 1, H_v * n_seqs)
|
||||
ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
|
||||
|
||||
// shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
|
||||
ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
|
||||
|
||||
// shape: (chunk_size, 1, H_v * n_seqs)
|
||||
ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
|
||||
|
||||
// attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
// replaced by precomputed attn_kq
|
||||
ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
|
||||
cb(attn_chunk, "attn_chunk", il);
|
||||
|
||||
ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
|
||||
|
||||
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
|
||||
cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
|
||||
|
||||
// v_new = v_i - v_prime
|
||||
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
|
||||
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
|
||||
cb(v_new, "v_new_chunk", il);
|
||||
|
||||
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
|
||||
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
|
||||
cb(attn_inter, "attn_inter_chunk", il);
|
||||
|
||||
// core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
|
||||
cb(v_attn, "v_attn_chunk", il);
|
||||
|
||||
ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
|
||||
cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
|
||||
|
||||
core_attn_out = core_attn_out == nullptr
|
||||
? core_attn_out_chunk
|
||||
: ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
|
||||
|
||||
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
|
||||
ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
|
||||
//ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
|
||||
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
|
||||
|
||||
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
|
||||
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
|
||||
new_state = ggml_add(ctx0,
|
||||
ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
|
||||
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
|
||||
}
|
||||
|
||||
// truncate padded tokens
|
||||
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
|
||||
S_v, n_tokens, H_v, n_seqs,
|
||||
ggml_row_size(core_attn_out->type, S_v),
|
||||
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
|
||||
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
|
||||
output_tokens = ggml_cont(ctx0, output_tokens);
|
||||
cb(output_tokens, "output_tokens", il);
|
||||
|
||||
// permute back to (S_v, H_v, n_tokens, n_seqs)
|
||||
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
|
||||
output_tokens = ggml_cont(ctx0, output_tokens);
|
||||
|
||||
return {output_tokens, new_state};
|
||||
}
|
||||
|
||||
std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
|
||||
ggml_tensor * q,
|
||||
ggml_tensor * k,
|
||||
ggml_tensor * v,
|
||||
ggml_tensor * g,
|
||||
ggml_tensor * beta,
|
||||
ggml_tensor * state,
|
||||
int il) {
|
||||
const int64_t S_k = q->ne[0];
|
||||
const int64_t H_k = q->ne[1];
|
||||
const int64_t n_tokens = q->ne[2];
|
||||
const int64_t n_seqs = q->ne[3];
|
||||
|
||||
const int64_t S_v = v->ne[0];
|
||||
const int64_t H_v = v->ne[1];
|
||||
|
||||
GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing
|
||||
GGML_ASSERT(v->ne[2] == n_tokens);
|
||||
GGML_ASSERT(k->ne[2] == n_tokens);
|
||||
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
|
||||
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
|
||||
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
|
||||
|
||||
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
|
||||
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
|
||||
|
||||
GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case
|
||||
|
||||
const float eps_norm = hparams.f_norm_rms_eps;
|
||||
|
||||
q = ggml_l2_norm(ctx0, q, eps_norm);
|
||||
k = ggml_l2_norm(ctx0, k, eps_norm);
|
||||
|
||||
const float scale = 1.0f / sqrtf(S_v);
|
||||
|
||||
q = ggml_scale(ctx0, q, scale);
|
||||
beta = ggml_sigmoid(ctx0, beta);
|
||||
|
||||
cb(q, "q_in", il);
|
||||
cb(k, "k_in", il);
|
||||
cb(v, "v_in", il);
|
||||
cb(beta, "beta_in", il);
|
||||
cb(g, "g_in", il);
|
||||
|
||||
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
|
||||
|
||||
ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
|
||||
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
|
||||
|
||||
// Apply exponential to g_t
|
||||
g_t = ggml_exp(ctx0, g_t);
|
||||
|
||||
// Apply the gated delta rule for the single timestep
|
||||
// last_recurrent_state = last_recurrent_state * g_t
|
||||
state = ggml_mul(ctx0, state, g_t);
|
||||
|
||||
// kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
|
||||
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
|
||||
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
|
||||
// we need to sum over dim=-2, so we transpose, sum, then transpose again
|
||||
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
|
||||
|
||||
// v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
|
||||
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
|
||||
// delta = (v_t - kv_mem) * beta_t
|
||||
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
|
||||
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);
|
||||
|
||||
// last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
|
||||
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
|
||||
state = ggml_add(ctx0, state, k_t_delta);
|
||||
|
||||
// Compute the attention output
|
||||
// core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
|
||||
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
|
||||
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
|
||||
// again, since it's over dim = -2, transpose, sum, transpose back
|
||||
ggml_tensor * core_attn_out =
|
||||
ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
|
||||
|
||||
// core_attn_out should be [S_v, 1, H_v, n_seqs] after this
|
||||
cb(core_attn_out, "output_tokens", il);
|
||||
cb(state, "new_state", il);
|
||||
|
||||
return {core_attn_out, state};
|
||||
}
|
||||
|
||||
ggml_tensor * llm_build_qwen3next::build_norm_gated(
|
||||
ggml_tensor * input,
|
||||
ggml_tensor * weights,
|
||||
@@ -472,39 +112,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
|
||||
// Split Q projection into query and gate
|
||||
// The split should be along dimension 0 (the feature dimension)
|
||||
ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
|
||||
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
|
||||
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
|
||||
cb(Qcur, "Qcur_view", il);
|
||||
|
||||
ggml_tensor * gate =
|
||||
ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
|
||||
Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(gate, "gate", il);
|
||||
|
||||
// Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
|
||||
Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
cb(Qcur, "Qcur_reshaped", il);
|
||||
|
||||
// Apply Q normalization
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// Apply K normalization
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
// Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
|
||||
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
|
||||
cb(gate, "gate_reshaped", il);
|
||||
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// Apply RoPE
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
@@ -519,7 +149,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
// Attention computation
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
cur = build_attn(inp,
|
||||
@@ -527,10 +156,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_pregate", il);
|
||||
|
||||
ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
|
||||
cb(gate_sigmoid, "gate_sigmoid", il);
|
||||
// TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
|
||||
gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, gate_sigmoid);
|
||||
gate = ggml_sigmoid(ctx0, gate);
|
||||
cb(gate, "gate_sigmoid", il);
|
||||
|
||||
gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, gate);
|
||||
cb(cur, "attn_gated", il);
|
||||
|
||||
cur = build_lora_mm(model.layers[il].wo, cur);
|
||||
@@ -560,7 +194,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
|
||||
cb(z, "z", il);
|
||||
|
||||
return { qkv_mixed, z };
|
||||
|
||||
} else {
|
||||
// legacy (slower) path
|
||||
ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
|
||||
@@ -624,9 +257,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
|
||||
ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * causal_mask,
|
||||
ggml_tensor * identity,
|
||||
ggml_tensor * diag_mask,
|
||||
int il) {
|
||||
const auto * mctx_cur = inp->mctx;
|
||||
|
||||
@@ -671,7 +301,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
|
||||
cb(a, "a", il);
|
||||
|
||||
ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
|
||||
// TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
|
||||
b = ggml_cont(ctx0, b);
|
||||
|
||||
ggml_tensor * beta = ggml_sigmoid(ctx0, b);
|
||||
|
||||
beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
|
||||
|
||||
// Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
|
||||
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
|
||||
@@ -679,6 +314,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
|
||||
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
|
||||
cb(alpha_softplus, "a_softplus", il);
|
||||
|
||||
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
|
||||
cb(gate, "gate", il);
|
||||
|
||||
@@ -686,8 +322,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
|
||||
// bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
|
||||
|
||||
// Build the convolution states tensor
|
||||
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
cb(conv_states, "conv_states", il);
|
||||
@@ -696,11 +330,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
|
||||
const int64_t conv_kernel_size = conv_kernel->ne[0];
|
||||
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
|
||||
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
|
||||
|
||||
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
|
||||
cb(conv_states, "conv_states_reshaped", il);
|
||||
|
||||
qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
|
||||
cb(qkv_mixed, "qkv_mixed_permuted", il);
|
||||
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
|
||||
cb(qkv_mixed, "qkv_mixed_transposed", il);
|
||||
|
||||
ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
|
||||
cb(conv_input, "conv_input", il);
|
||||
@@ -720,7 +355,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
|
||||
cb(conv_states_all, "conv_states_updated", il);
|
||||
|
||||
// Apply SSM convolution
|
||||
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
|
||||
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
|
||||
cb(state, "state_predelta", il);
|
||||
|
||||
ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
|
||||
cb(conv_output_proper, "conv_output_raw", il);
|
||||
|
||||
@@ -734,26 +372,36 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
|
||||
|
||||
// Extract the convolved Q, K, V from conv_output
|
||||
ggml_tensor * q_conv =
|
||||
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
|
||||
ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
|
||||
ggml_row_size(conv_qkv_mix->type, head_k_dim),
|
||||
nb1_qkv,
|
||||
nb1_qkv * n_seq_tokens,
|
||||
0);
|
||||
|
||||
ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
|
||||
ggml_row_size(conv_qkv_mix->type, head_k_dim),
|
||||
nb1_qkv,
|
||||
nb1_qkv * n_seq_tokens,
|
||||
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
|
||||
|
||||
ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
|
||||
ggml_row_size(conv_qkv_mix->type, head_v_dim),
|
||||
nb1_qkv,
|
||||
nb1_qkv * n_seq_tokens,
|
||||
ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
|
||||
|
||||
cb(q_conv, "q_conv", il);
|
||||
ggml_tensor * k_conv =
|
||||
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
|
||||
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
|
||||
cb(k_conv, "k_conv", il);
|
||||
ggml_tensor * v_conv =
|
||||
ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
|
||||
2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
|
||||
cb(v_conv, "v_conv", il);
|
||||
|
||||
// Unsqueeze them
|
||||
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
|
||||
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
|
||||
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
const float eps_norm = hparams.f_norm_rms_eps;
|
||||
|
||||
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
|
||||
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
|
||||
cb(state, "state_predelta", il);
|
||||
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
|
||||
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
|
||||
|
||||
//q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
|
||||
//k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
|
||||
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// if head keys and value keys are different, repeat to force tensors into matching shapes
|
||||
if (num_k_heads != num_v_heads) {
|
||||
@@ -786,7 +434,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
if (n_seq_tokens == 1) {
|
||||
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
} else {
|
||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
|
||||
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
|
||||
}
|
||||
ggml_tensor * output = attn_out.first;
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
@@ -795,19 +443,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
|
||||
// Update the recurrent states
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0, new_state,
|
||||
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
|
||||
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
|
||||
|
||||
// Reshape both attn_out_final and z to 2D tensors for normalization
|
||||
// attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
|
||||
ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
|
||||
ggml_cpy(ctx0, new_state,
|
||||
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
|
||||
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
|
||||
|
||||
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
|
||||
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
|
||||
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
|
||||
|
||||
// Apply gated normalization: self.norm(core_attn_out, z)
|
||||
ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
|
||||
ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
|
||||
|
||||
// Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
|
||||
ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
|
||||
@@ -818,7 +462,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
|
||||
cb(cur, "linear_attn_out", il);
|
||||
|
||||
// Reshape back to original dimensions
|
||||
cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
|
||||
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
@@ -839,7 +484,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
|
||||
if (model.layers[il].ffn_up_shexp != nullptr) {
|
||||
ggml_tensor * ffn_shexp =
|
||||
build_ffn(cur,
|
||||
model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_up_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_gate_shexp, NULL, NULL,
|
||||
model.layers[il].ffn_down_shexp, NULL, NULL,
|
||||
NULL,
|
||||
@@ -852,11 +497,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
|
||||
ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
|
||||
cb(shared_gate, "shared_expert_gate", il);
|
||||
|
||||
// Apply sigmoid to the gate
|
||||
shared_gate = ggml_sigmoid(ctx0, shared_gate);
|
||||
cb(shared_gate, "shared_expert_gate_sigmoid", il);
|
||||
|
||||
// Apply the gate to the shared expert output
|
||||
ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
|
||||
cb(ffn_shexp, "ffn_shexp_gated", il);
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model) {}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "models.h"
|
||||
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params),
|
||||
model(model) {}
|
||||
|
||||
@@ -769,6 +769,12 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
|
||||
} else if (regex_expr == "\\p{AFMoE_digits}") {
|
||||
// AFMOE digit pattern - use custom implementation for proper splitting
|
||||
bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
|
||||
} else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") {
|
||||
// tiny_aya digit grouping pattern from tokenizer.json:
|
||||
// {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"}
|
||||
// Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567)
|
||||
// TODO: Revisit this regex, incase there are any subtle tokenization differences with the original regex.
|
||||
bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets);
|
||||
}
|
||||
|
||||
return bpe_offsets;
|
||||
|
||||
@@ -5821,20 +5821,27 @@ struct test_l2_norm : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
bool v;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
return VARS_TO_STR4(type, ne, eps, v);
|
||||
}
|
||||
|
||||
test_l2_norm(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 64, 320, 1},
|
||||
float eps = 1e-12f)
|
||||
: type(type), ne(ne), eps(eps) {}
|
||||
float eps = 1e-12f,
|
||||
bool v = false)
|
||||
: type(type), ne(ne), eps(eps), v(v) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
if (v) {
|
||||
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
ggml_set_name(a, "view of a");
|
||||
}
|
||||
|
||||
ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
|
||||
ggml_set_name(out, "out");
|
||||
|
||||
@@ -7596,7 +7603,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
|
||||
}
|
||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8293,7 +8301,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
//for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
|
||||
for (int kv : { 113, 512, 1024, }) {
|
||||
if (nr2 != 1 && kv != 512) continue;
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (int nb : { 1, 3, 32, 75, }) {
|
||||
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
|
||||
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
|
||||
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
|
||||
@@ -20,6 +20,7 @@ add_library(mtmd
|
||||
models/internvl.cpp
|
||||
models/kimivl.cpp
|
||||
models/kimik25.cpp
|
||||
models/nemotron-v2-vl.cpp
|
||||
models/llama4.cpp
|
||||
models/llava.cpp
|
||||
models/minicpmv.cpp
|
||||
|
||||
@@ -236,6 +236,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_GLM4V,
|
||||
PROJECTOR_TYPE_YOUTUVL,
|
||||
PROJECTOR_TYPE_KIMIK25,
|
||||
PROJECTOR_TYPE_NEMOTRON_V2_VL,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -270,6 +271,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_GLM4V, "glm4v"},
|
||||
{ PROJECTOR_TYPE_YOUTUVL, "youtuvl"},
|
||||
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
|
||||
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
||||
@@ -15,6 +15,7 @@ enum ffn_op_type {
|
||||
FFN_GELU_ERF,
|
||||
FFN_SILU,
|
||||
FFN_GELU_QUICK,
|
||||
FFN_RELU_SQR,
|
||||
};
|
||||
|
||||
enum norm_type {
|
||||
|
||||
@@ -559,6 +559,12 @@ ggml_tensor * clip_graph::build_ffn(
|
||||
cur = ggml_gelu_quick(ctx0, cur);
|
||||
cb(cur, "ffn_gelu_quick", il);
|
||||
} break;
|
||||
case FFN_RELU_SQR:
|
||||
{
|
||||
cur = ggml_relu(ctx0, cur);
|
||||
cur = ggml_sqr(ctx0, cur);
|
||||
cb(cur, "ffn_relu_sqr", il);
|
||||
} break;
|
||||
}
|
||||
|
||||
if (down) {
|
||||
@@ -810,6 +816,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
builder = std::make_unique<clip_graph_internvl>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_nemotron_v2_vl>(ctx, img);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
{
|
||||
builder = std::make_unique<clip_graph_llama4>(ctx, img);
|
||||
@@ -1110,6 +1120,7 @@ struct clip_model_loader {
|
||||
}
|
||||
} break;
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
{
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
|
||||
} break;
|
||||
@@ -1767,6 +1778,12 @@ struct clip_model_loader {
|
||||
model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
|
||||
model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
{
|
||||
model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
|
||||
model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
|
||||
model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
|
||||
} break;
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
{
|
||||
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
|
||||
@@ -3088,6 +3105,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
case PROJECTOR_TYPE_GLM_EDGE:
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
{
|
||||
clip_image_u8 resized_image;
|
||||
int sz = params.image_size;
|
||||
@@ -3397,6 +3415,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
|
||||
case PROJECTOR_TYPE_GEMMA3:
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
{
|
||||
// both X and Y are downscaled by the scale factor
|
||||
@@ -3805,6 +3824,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
case PROJECTOR_TYPE_GEMMA3NV:
|
||||
case PROJECTOR_TYPE_IDEFICS3:
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
case PROJECTOR_TYPE_QWEN2A:
|
||||
case PROJECTOR_TYPE_GLMA:
|
||||
case PROJECTOR_TYPE_ULTRAVOX:
|
||||
@@ -3968,6 +3988,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
case PROJECTOR_TYPE_MUSIC_FLAMINGO:
|
||||
return ctx->model.mm_2_w->ne[1];
|
||||
case PROJECTOR_TYPE_INTERNVL:
|
||||
case PROJECTOR_TYPE_NEMOTRON_V2_VL:
|
||||
return ctx->model.mm_3_w->ne[1];
|
||||
case PROJECTOR_TYPE_LLAMA4:
|
||||
return ctx->model.mm_model_proj->ne[1];
|
||||
|
||||
@@ -42,6 +42,11 @@ struct clip_graph_internvl : clip_graph {
|
||||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
struct clip_graph_nemotron_v2_vl : clip_graph {
|
||||
clip_graph_nemotron_v2_vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
};
|
||||
|
||||
struct clip_graph_llama4 : clip_graph {
|
||||
clip_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
|
||||
ggml_cgraph * build() override;
|
||||
|
||||
35
tools/mtmd/models/nemotron-v2-vl.cpp
Normal file
35
tools/mtmd/models/nemotron-v2-vl.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
#include "models.h"
|
||||
|
||||
ggml_cgraph * clip_graph_nemotron_v2_vl::build() {
|
||||
GGML_ASSERT(model.class_embedding != nullptr);
|
||||
GGML_ASSERT(model.position_embeddings != nullptr);
|
||||
|
||||
const int n_registers = model.class_embedding->ne[1];
|
||||
const int n_pos = n_patches + n_registers;
|
||||
|
||||
ggml_tensor * inp = build_inp();
|
||||
|
||||
// add position embeddings (pre-downsampled during GGUF conversion for fixed 512x512 input)
|
||||
inp = ggml_add(ctx0, inp, model.position_embeddings);
|
||||
cb(inp, "inp_pos", -1);
|
||||
|
||||
inp = ggml_concat(ctx0, model.class_embedding, inp, 1);
|
||||
|
||||
ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, hparams.ffn_op, nullptr, nullptr);
|
||||
|
||||
cur = ggml_view_2d(ctx0, cur,
|
||||
n_embd, n_patches,
|
||||
ggml_row_size(cur->type, n_embd),
|
||||
n_registers * ggml_row_size(cur->type, n_embd));
|
||||
|
||||
cur = build_patch_merge_permute(cur, model.hparams.n_merge);
|
||||
|
||||
{
|
||||
cur = build_norm(cur, model.mm_0_w, nullptr, NORM_TYPE_RMS, 1e-6, -1);
|
||||
cur = build_ffn(cur, model.mm_1_w, nullptr, nullptr, nullptr, model.mm_3_w, nullptr, FFN_RELU_SQR, -1);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
@@ -347,7 +347,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
||||
int count = 0;
|
||||
double nll = 0.0;
|
||||
|
||||
LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
|
||||
const int n_seq = std::max(1, n_batch / n_ctx);
|
||||
LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
|
||||
|
||||
for (int i = 0; i < n_chunk; ++i) {
|
||||
const int start = i * params.ppl_stride;
|
||||
@@ -1737,11 +1738,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
}
|
||||
|
||||
const int n_batch = params.n_batch;
|
||||
const int num_batches = (n_ctx + n_batch - 1)/n_batch;
|
||||
const int num_batches = (static_cast<int>(n_ctx) + n_batch - 1) / n_batch;
|
||||
// Calculate n_seq based on the logits file's n_ctx, but cap it at what the context supports
|
||||
const int n_seq_max = llama_n_seq_max(ctx);
|
||||
int n_seq = std::max(1, n_batch / static_cast<int>(n_ctx));
|
||||
if (n_seq > n_seq_max) {
|
||||
LOG_WRN("%s: calculated n_seq=%d exceeds context's n_seq_max=%d, capping at %d\n",
|
||||
__func__, n_seq, n_seq_max, n_seq_max);
|
||||
n_seq = n_seq_max;
|
||||
}
|
||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
|
||||
|
||||
llama_batch batch = llama_batch_init(std::min(n_batch, static_cast<int>(n_ctx)*n_seq), 0, 1);
|
||||
|
||||
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
|
||||
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
|
||||
std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
|
||||
@@ -1750,6 +1761,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
logits.reserve(size_t(n_ctx) * n_vocab);
|
||||
}
|
||||
|
||||
LOG_INF("%s: computing over %d chunks, n_ctx=%u, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
|
||||
|
||||
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
||||
|
||||
auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
|
||||
@@ -1774,107 +1787,122 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
||||
auto kld_ptr = kld_values.data();
|
||||
auto p_diff_ptr = p_diff_values.data();
|
||||
|
||||
for (int i = 0; i < n_chunk; ++i) {
|
||||
const int first = n_ctx/2;
|
||||
|
||||
for (int i = 0; i < n_chunk; i += n_seq) {
|
||||
const int start = i * n_ctx;
|
||||
const int end = start + n_ctx;
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
const int n_seq_batch = std::min(n_seq, n_chunk - i);
|
||||
|
||||
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
|
||||
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
|
||||
return;
|
||||
}
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
||||
// save original token and restore it after eval
|
||||
const auto token_org = tokens[batch_start];
|
||||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[batch_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
int n_outputs = 0;
|
||||
|
||||
common_batch_clear(batch);
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
int seq_start = batch_start + seq*n_ctx;
|
||||
|
||||
// save original token and restore it after eval
|
||||
const auto token_org = tokens[seq_start];
|
||||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[seq_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
for (int k = 0; k < batch_size; ++k) {
|
||||
const int pos = j*n_batch + k;
|
||||
const bool need_logits = pos >= first;
|
||||
common_batch_add(batch, tokens[seq_start + k], pos, { seq }, need_logits);
|
||||
n_outputs += need_logits;
|
||||
}
|
||||
|
||||
// restore the original token in case it was set to BOS
|
||||
tokens[seq_start] = token_org;
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
LOG_ERR("%s : failed to decode\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return;
|
||||
}
|
||||
|
||||
// restore the original token in case it was set to BOS
|
||||
tokens[batch_start] = token_org;
|
||||
|
||||
if (num_batches > 1) {
|
||||
if (num_batches > 1 && n_outputs > 0) {
|
||||
const auto * batch_logits = llama_get_logits(ctx);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
if (i == 0) {
|
||||
llama_synchronize(ctx);
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
|
||||
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
|
||||
int total_seconds = (int)(t_total * n_chunk);
|
||||
int total_seconds = (int)(t_total * n_chunk / n_seq);
|
||||
if (total_seconds >= 60*60) {
|
||||
LOG("%d hours ", total_seconds / (60*60));
|
||||
total_seconds = total_seconds % (60*60);
|
||||
}
|
||||
LOG("%.2f minutes\n", total_seconds / 60.0);
|
||||
LOG("\n");
|
||||
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
|
||||
}
|
||||
LOG("\n");
|
||||
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
|
||||
|
||||
const int first = n_ctx/2;
|
||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
||||
process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||
workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
|
||||
p_diff_ptr += n_ctx - 1 - first;
|
||||
kld_ptr += n_ctx - 1 - first;
|
||||
// Read log probs for each sequence in the batch
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
|
||||
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i + seq);
|
||||
llama_batch_free(batch);
|
||||
return;
|
||||
}
|
||||
|
||||
LOG("%4d", i+1);
|
||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
||||
|
||||
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
||||
const double ppl_val = exp(log_ppl.first);
|
||||
const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
|
||||
LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
|
||||
process_logits(n_vocab, all_logits, tokens.data() + start + seq*n_ctx + first, n_ctx - 1 - first,
|
||||
workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
|
||||
p_diff_ptr += n_ctx - 1 - first;
|
||||
kld_ptr += n_ctx - 1 - first;
|
||||
|
||||
auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
|
||||
const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
|
||||
const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
|
||||
const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
|
||||
LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
|
||||
LOG("%4d", i + seq + 1);
|
||||
|
||||
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||
LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
|
||||
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
||||
const double ppl_val = exp(log_ppl.first);
|
||||
const double ppl_unc = ppl_val * log_ppl.second;
|
||||
LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
|
||||
|
||||
auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
|
||||
const double p_diff_rms_val = sqrt(p_diff_mse.first);
|
||||
const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
|
||||
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
|
||||
auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
|
||||
const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
|
||||
const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
|
||||
const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
|
||||
LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
|
||||
|
||||
double p_top_val = 1.*kld.n_same_top/kld.count;
|
||||
double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
|
||||
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
|
||||
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||
LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
|
||||
|
||||
LOG("\n");
|
||||
auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
|
||||
const double p_diff_rms_val = sqrt(p_diff_mse.first);
|
||||
const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
|
||||
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
|
||||
|
||||
double p_top_val = 1.*kld.n_same_top/kld.count;
|
||||
double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
|
||||
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
|
||||
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
logits.clear();
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
LOG("\n");
|
||||
|
||||
if (kld.count < 100) return; // we do not wish to do statistics on so few values
|
||||
@@ -1996,7 +2024,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
|
||||
|
||||
if (ppl) {
|
||||
if (ppl || params.kl_divergence) {
|
||||
const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
|
||||
const int32_t n_kv = n_seq * n_ctx;
|
||||
|
||||
@@ -2006,12 +2034,8 @@ int main(int argc, char ** argv) {
|
||||
params.n_batch = std::min(params.n_batch, n_kv);
|
||||
} else {
|
||||
params.n_batch = std::min(params.n_batch, params.n_ctx);
|
||||
if (params.kl_divergence) {
|
||||
params.n_parallel = 1;
|
||||
} else {
|
||||
// ensure there's at least enough seq_ids for HellaSwag
|
||||
params.n_parallel = std::max(4, params.n_parallel);
|
||||
}
|
||||
// ensure there's at least enough seq_ids for HellaSwag
|
||||
params.n_parallel = std::max(4, params.n_parallel);
|
||||
}
|
||||
|
||||
if (params.ppl_stride > 0) {
|
||||
|
||||
@@ -132,7 +132,8 @@ static std::string fs_get_cache_directory() {
|
||||
if (getenv("LLAMA_CACHE")) {
|
||||
cache_directory = std::getenv("LLAMA_CACHE");
|
||||
} else {
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
|
||||
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \
|
||||
defined(__OpenBSD__) || defined(__NetBSD__)
|
||||
if (std::getenv("XDG_CACHE_HOME")) {
|
||||
cache_directory = std::getenv("XDG_CACHE_HOME");
|
||||
} else if (std::getenv("HOME")) {
|
||||
|
||||
@@ -28,10 +28,6 @@ target_link_libraries(${TARGET} PUBLIC common mtmd ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
set(TARGET llama-server)
|
||||
|
||||
if (NOT LLAMA_HTTPLIB)
|
||||
message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF")
|
||||
endif()
|
||||
|
||||
set(TARGET_SRCS
|
||||
server.cpp
|
||||
server-http.cpp
|
||||
|
||||
Binary file not shown.
@@ -1,17 +1,24 @@
|
||||
import type { StorybookConfig } from '@storybook/sveltekit';
|
||||
import { dirname, resolve } from 'path';
|
||||
import { fileURLToPath } from 'url';
|
||||
|
||||
const __dirname = dirname(fileURLToPath(import.meta.url));
|
||||
|
||||
const config: StorybookConfig = {
|
||||
stories: ['../tests/stories/**/*.mdx', '../tests/stories/**/*.stories.@(js|ts|svelte)'],
|
||||
addons: [
|
||||
'@storybook/addon-svelte-csf',
|
||||
'@chromatic-com/storybook',
|
||||
'@storybook/addon-docs',
|
||||
'@storybook/addon-vitest',
|
||||
'@storybook/addon-a11y',
|
||||
'@storybook/addon-vitest'
|
||||
'@storybook/addon-docs'
|
||||
],
|
||||
framework: {
|
||||
name: '@storybook/sveltekit',
|
||||
options: {}
|
||||
framework: '@storybook/sveltekit',
|
||||
viteFinal: async (config) => {
|
||||
config.server = config.server || {};
|
||||
config.server.fs = config.server.fs || {};
|
||||
config.server.fs.allow = [...(config.server.fs.allow || []), resolve(__dirname, '../tests')];
|
||||
return config;
|
||||
}
|
||||
};
|
||||
export default config;
|
||||
|
||||
@@ -13,7 +13,7 @@ const preview: Preview = {
|
||||
},
|
||||
|
||||
backgrounds: {
|
||||
disable: true
|
||||
disabled: true
|
||||
},
|
||||
|
||||
a11y: {
|
||||
|
||||
@@ -49,14 +49,20 @@ sequenceDiagram
|
||||
settingsStore->>serverStore: defaultParams
|
||||
serverStore-->>settingsStore: {temperature, top_p, top_k, ...}
|
||||
|
||||
settingsStore->>ParamSvc: extractServerDefaults(defaultParams)
|
||||
ParamSvc-->>settingsStore: Record<string, value>
|
||||
loop each SYNCABLE_PARAMETER
|
||||
alt key NOT in userOverrides
|
||||
settingsStore->>settingsStore: config[key] = serverDefault[key]
|
||||
Note right of settingsStore: Non-overridden params adopt server default
|
||||
else key in userOverrides
|
||||
Note right of settingsStore: Keep user value, skip server default
|
||||
end
|
||||
end
|
||||
|
||||
settingsStore->>ParamSvc: mergeWithServerDefaults(config, serverDefaults)
|
||||
Note right of ParamSvc: For each syncable parameter:<br/>- If NOT in userOverrides → use server default<br/>- If in userOverrides → keep user value
|
||||
ParamSvc-->>settingsStore: mergedConfig
|
||||
alt serverStore.props has webuiSettings
|
||||
settingsStore->>settingsStore: Apply webuiSettings from server
|
||||
Note right of settingsStore: Server-provided UI settings<br/>(e.g. showRawOutputSwitch)
|
||||
end
|
||||
|
||||
settingsStore->>settingsStore: config = mergedConfig
|
||||
settingsStore->>settingsStore: saveConfig()
|
||||
deactivate settingsStore
|
||||
|
||||
@@ -67,11 +73,18 @@ sequenceDiagram
|
||||
UI->>settingsStore: updateConfig(key, value)
|
||||
activate settingsStore
|
||||
settingsStore->>settingsStore: config[key] = value
|
||||
settingsStore->>settingsStore: userOverrides.add(key)
|
||||
Note right of settingsStore: Mark as user-modified (won't be overwritten by server)
|
||||
|
||||
alt value matches server default for key
|
||||
settingsStore->>settingsStore: userOverrides.delete(key)
|
||||
Note right of settingsStore: Matches server default, remove override
|
||||
else value differs from server default
|
||||
settingsStore->>settingsStore: userOverrides.add(key)
|
||||
Note right of settingsStore: Mark as user-modified (won't be overwritten)
|
||||
end
|
||||
|
||||
settingsStore->>settingsStore: saveConfig()
|
||||
settingsStore->>LS: set("llama-config", config)
|
||||
settingsStore->>LS: set("llama-userOverrides", [...userOverrides])
|
||||
settingsStore->>LS: set(CONFIG_LOCALSTORAGE_KEY, config)
|
||||
settingsStore->>LS: set(USER_OVERRIDES_LOCALSTORAGE_KEY, [...userOverrides])
|
||||
deactivate settingsStore
|
||||
|
||||
UI->>settingsStore: updateMultipleConfig({key1: val1, key2: val2})
|
||||
@@ -88,10 +101,9 @@ sequenceDiagram
|
||||
|
||||
UI->>settingsStore: resetConfig()
|
||||
activate settingsStore
|
||||
settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT
|
||||
settingsStore->>settingsStore: config = {...SETTING_CONFIG_DEFAULT}
|
||||
settingsStore->>settingsStore: userOverrides.clear()
|
||||
settingsStore->>settingsStore: syncWithServerDefaults()
|
||||
Note right of settingsStore: Apply server defaults for syncable params
|
||||
Note right of settingsStore: All params reset to defaults<br/>Next syncWithServerDefaults will adopt server values
|
||||
settingsStore->>settingsStore: saveConfig()
|
||||
deactivate settingsStore
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<script lang="ts">
|
||||
import { Eye } from '@lucide/svelte';
|
||||
import ActionIconCopyToClipboard from '$lib/components/app/actions/ActionIconCopyToClipboard.svelte';
|
||||
import { ActionIconCopyToClipboard } from '$lib/components/app';
|
||||
import { FileTypeText } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
|
||||
@@ -57,13 +57,13 @@
|
||||
let currentConfig = $derived(config());
|
||||
let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined);
|
||||
let isRecording = $state(false);
|
||||
let message = $state(initialMessage);
|
||||
let message = $derived(initialMessage);
|
||||
let pasteLongTextToFileLength = $derived.by(() => {
|
||||
const n = Number(currentConfig.pasteLongTextToFileLen);
|
||||
return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n;
|
||||
});
|
||||
let previousIsLoading = $state(isLoading);
|
||||
let previousInitialMessage = $state(initialMessage);
|
||||
let previousIsLoading = $derived(isLoading);
|
||||
let previousInitialMessage = $derived(initialMessage);
|
||||
let recordingSupported = $state(false);
|
||||
let textareaRef: ChatFormTextarea | undefined = $state(undefined);
|
||||
|
||||
@@ -289,7 +289,7 @@
|
||||
|
||||
<form
|
||||
onsubmit={handleSubmit}
|
||||
class="{INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {disabled
|
||||
class="relative {INPUT_CLASSES} border-radius-bottom-none mx-auto max-w-[48rem] overflow-hidden rounded-3xl backdrop-blur-md {disabled
|
||||
? 'cursor-not-allowed opacity-60'
|
||||
: ''} {className}"
|
||||
data-slot="chat-form"
|
||||
@@ -304,10 +304,11 @@
|
||||
/>
|
||||
|
||||
<div
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl px-5 py-3 shadow-sm transition-all focus-within:shadow-md"
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
|
||||
onpaste={handlePaste}
|
||||
>
|
||||
<ChatFormTextarea
|
||||
class="px-5 py-1.5 md:pt-0"
|
||||
bind:this={textareaRef}
|
||||
bind:value={message}
|
||||
onKeydown={handleKeydown}
|
||||
@@ -315,6 +316,7 @@
|
||||
/>
|
||||
|
||||
<ChatFormActions
|
||||
class="px-3"
|
||||
bind:this={chatFormActionsRef}
|
||||
canSend={message.trim().length > 0 || uploadedFiles.length > 0}
|
||||
hasText={message.trim().length > 0}
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
<script lang="ts">
|
||||
import { page } from '$app/state';
|
||||
import { MessageSquare, Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { FILE_TYPE_ICONS } from '$lib/constants/icons';
|
||||
import { TOOLTIP_DELAY_DURATION } from '$lib/constants/tooltip-config';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
hasAudioModality?: boolean;
|
||||
hasVisionModality?: boolean;
|
||||
onFileUpload?: () => void;
|
||||
onSystemPromptClick?: () => void;
|
||||
}
|
||||
|
||||
type AttachmentActionId = 'images' | 'audio' | 'text' | 'pdf' | 'system';
|
||||
|
||||
interface AttachmentAction {
|
||||
id: AttachmentActionId;
|
||||
label: string;
|
||||
disabled?: boolean;
|
||||
disabledReason?: string;
|
||||
tooltip?: string;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
hasAudioModality = false,
|
||||
hasVisionModality = false,
|
||||
onFileUpload,
|
||||
onSystemPromptClick
|
||||
}: Props = $props();
|
||||
|
||||
let isNewChat = $derived(!page.params.id);
|
||||
let systemMessageTooltip = $derived(
|
||||
isNewChat
|
||||
? 'Add custom system message for a new conversation'
|
||||
: 'Inject custom system message at the beginning of the conversation'
|
||||
);
|
||||
|
||||
let actions = $derived.by<AttachmentAction[]>(() => [
|
||||
{
|
||||
id: 'images',
|
||||
label: 'Images',
|
||||
disabled: !hasVisionModality,
|
||||
disabledReason: !hasVisionModality
|
||||
? 'Images require vision models to be processed'
|
||||
: undefined
|
||||
},
|
||||
{
|
||||
id: 'audio',
|
||||
label: 'Audio Files',
|
||||
disabled: !hasAudioModality,
|
||||
disabledReason: !hasAudioModality
|
||||
? 'Audio files require audio models to be processed'
|
||||
: undefined
|
||||
},
|
||||
{
|
||||
id: 'text',
|
||||
label: 'Text Files'
|
||||
},
|
||||
{
|
||||
id: 'pdf',
|
||||
label: 'PDF Files',
|
||||
tooltip: !hasVisionModality
|
||||
? 'PDFs will be converted to text. Image-based PDFs may not work properly.'
|
||||
: undefined
|
||||
},
|
||||
{
|
||||
id: 'system',
|
||||
label: 'System Message',
|
||||
tooltip: systemMessageTooltip
|
||||
}
|
||||
]);
|
||||
|
||||
function handleActionClick(id: AttachmentActionId) {
|
||||
if (id === 'system') {
|
||||
onSystemPromptClick?.();
|
||||
return;
|
||||
}
|
||||
|
||||
onFileUpload?.();
|
||||
}
|
||||
|
||||
const triggerTooltipText = 'Add files or system message';
|
||||
const itemClass = 'flex cursor-pointer items-center gap-2';
|
||||
</script>
|
||||
|
||||
<div class="flex items-center gap-1 {className}">
|
||||
<DropdownMenu.Root>
|
||||
<DropdownMenu.Trigger name="Attach files" {disabled}>
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
{disabled}
|
||||
variant="secondary"
|
||||
type="button"
|
||||
>
|
||||
<span class="sr-only">{triggerTooltipText}</span>
|
||||
|
||||
<Plus class="h-4 w-4" />
|
||||
</Button>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content>
|
||||
<p>{triggerTooltipText}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
</DropdownMenu.Trigger>
|
||||
|
||||
<DropdownMenu.Content align="start" class="w-56">
|
||||
{#each actions as item (item.id)}
|
||||
{@const hasDisabledTooltip = !!item.disabled && !!item.disabledReason}
|
||||
{@const hasEnabledTooltip = !item.disabled && !!item.tooltip}
|
||||
|
||||
{#if hasDisabledTooltip}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item class={itemClass} disabled>
|
||||
{#if item.id === 'images'}
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4" />
|
||||
{:else if item.id === 'audio'}
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
|
||||
{:else if item.id === 'text'}
|
||||
<FILE_TYPE_ICONS.text class="h-4 w-4" />
|
||||
{:else if item.id === 'pdf'}
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
|
||||
{:else}
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{item.disabledReason}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else if hasEnabledTooltip}
|
||||
<Tooltip.Root delayDuration={TOOLTIP_DELAY_DURATION}>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<DropdownMenu.Item class={itemClass} onclick={() => handleActionClick(item.id)}>
|
||||
{#if item.id === 'images'}
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4" />
|
||||
{:else if item.id === 'audio'}
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
|
||||
{:else if item.id === 'text'}
|
||||
<FILE_TYPE_ICONS.text class="h-4 w-4" />
|
||||
{:else if item.id === 'pdf'}
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
|
||||
{:else}
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side="right">
|
||||
<p>{item.tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
<DropdownMenu.Item class={itemClass} onclick={() => handleActionClick(item.id)}>
|
||||
{#if item.id === 'images'}
|
||||
<FILE_TYPE_ICONS.image class="h-4 w-4" />
|
||||
{:else if item.id === 'audio'}
|
||||
<FILE_TYPE_ICONS.audio class="h-4 w-4" />
|
||||
{:else if item.id === 'text'}
|
||||
<FILE_TYPE_ICONS.text class="h-4 w-4" />
|
||||
{:else if item.id === 'pdf'}
|
||||
<FILE_TYPE_ICONS.pdf class="h-4 w-4" />
|
||||
{:else}
|
||||
<MessageSquare class="h-4 w-4" />
|
||||
{/if}
|
||||
|
||||
<span>{item.label}</span>
|
||||
</DropdownMenu.Item>
|
||||
{/if}
|
||||
{/each}
|
||||
</DropdownMenu.Content>
|
||||
</DropdownMenu.Root>
|
||||
</div>
|
||||
@@ -2,7 +2,7 @@
|
||||
import { Square } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import {
|
||||
ChatFormActionFileAttachments,
|
||||
ChatFormActionAttachmentsDropdown,
|
||||
ChatFormActionRecord,
|
||||
ChatFormActionSubmit,
|
||||
ModelsSelector
|
||||
@@ -157,7 +157,7 @@
|
||||
|
||||
const { handleModelChange } = useModelChangeValidation({
|
||||
getRequiredModalities: () => usedModalities(),
|
||||
onValidationFailure: async (previousModelId) => {
|
||||
onValidationFailure: async (previousModelId: string | null) => {
|
||||
if (previousModelId) {
|
||||
await modelsStore.selectModelById(previousModelId);
|
||||
}
|
||||
@@ -166,32 +166,39 @@
|
||||
</script>
|
||||
|
||||
<div class="flex w-full items-center gap-3 {className}" style="container-type: inline-size">
|
||||
<ChatFormActionFileAttachments
|
||||
class="mr-auto"
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
/>
|
||||
<div class="mr-auto flex items-center gap-2">
|
||||
<ChatFormActionAttachmentsDropdown
|
||||
{disabled}
|
||||
{hasAudioModality}
|
||||
{hasVisionModality}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<ModelsSelector
|
||||
{disabled}
|
||||
bind:this={selectorModelRef}
|
||||
currentModel={conversationModel}
|
||||
forceForegroundText={true}
|
||||
useGlobalSelection={true}
|
||||
onModelChange={handleModelChange}
|
||||
/>
|
||||
<div class="ml-auto flex items-center gap-1.5">
|
||||
<ModelsSelector
|
||||
{disabled}
|
||||
bind:this={selectorModelRef}
|
||||
currentModel={conversationModel}
|
||||
forceForegroundText={true}
|
||||
useGlobalSelection={true}
|
||||
onModelChange={handleModelChange}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<Button
|
||||
type="button"
|
||||
variant="secondary"
|
||||
onclick={onStop}
|
||||
class="h-8 w-8 bg-transparent p-0 hover:bg-destructive/20"
|
||||
class="group h-8 w-8 rounded-full p-0 hover:bg-destructive/10!"
|
||||
>
|
||||
<span class="sr-only">Stop</span>
|
||||
<Square class="h-8 w-8 fill-destructive stroke-destructive" />
|
||||
|
||||
<Square
|
||||
class="h-8 w-8 fill-muted-foreground stroke-muted-foreground group-hover:fill-destructive group-hover:stroke-destructive hover:fill-destructive hover:stroke-destructive"
|
||||
/>
|
||||
</Button>
|
||||
{:else if shouldShowRecordButton}
|
||||
<ChatFormActionRecord {disabled} {hasAudioModality} {isLoading} {isRecording} {onMicClick} />
|
||||
|
||||
@@ -62,8 +62,8 @@
|
||||
assistantMessages: number;
|
||||
messageTypes: string[];
|
||||
} | null>(null);
|
||||
let editedContent = $state(message.content);
|
||||
let editedExtras = $state<DatabaseMessageExtra[]>(message.extra ? [...message.extra] : []);
|
||||
let editedContent = $derived(message.content);
|
||||
let editedExtras = $derived<DatabaseMessageExtra[]>(message.extra ? [...message.extra] : []);
|
||||
let editedUploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
let isEditing = $state(false);
|
||||
let showDeleteDialog = $state(false);
|
||||
|
||||
@@ -105,7 +105,7 @@
|
||||
|
||||
const { handleModelChange } = useModelChangeValidation({
|
||||
getRequiredModalities: () => conversationsStore.getModalitiesUpToMessage(message.id),
|
||||
onSuccess: (modelName) => onRegenerate(modelName)
|
||||
onSuccess: (modelName: string) => onRegenerate(modelName)
|
||||
});
|
||||
|
||||
function handleCopyModel() {
|
||||
|
||||
@@ -133,7 +133,7 @@
|
||||
|
||||
const { handleModelChange } = useModelChangeValidation({
|
||||
getRequiredModalities,
|
||||
onValidationFailure: async (previousModelId) => {
|
||||
onValidationFailure: async (previousModelId: string | null) => {
|
||||
if (previousModelId) {
|
||||
await modelsStore.selectModelById(previousModelId);
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
initialView = ChatMessageStatsView.GENERATION
|
||||
}: Props = $props();
|
||||
|
||||
let activeView: ChatMessageStatsView = $state(initialView);
|
||||
let activeView: ChatMessageStatsView = $derived(initialView);
|
||||
let hasAutoSwitchedToGeneration = $state(false);
|
||||
|
||||
// In live mode: auto-switch to GENERATION tab when prompt processing completes
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
|
||||
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import { onMount } from 'svelte';
|
||||
import { fade, fly, slide } from 'svelte/transition';
|
||||
import { Trash2, AlertTriangle, RefreshCw } from '@lucide/svelte';
|
||||
@@ -616,7 +617,7 @@
|
||||
contextInfo={activeErrorDialog?.contextInfo}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? 'server'}
|
||||
type={(activeErrorDialog?.type as ErrorDialogType) ?? ErrorDialogType.SERVER}
|
||||
/>
|
||||
|
||||
<style>
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
<script lang="ts">
|
||||
import ChatForm from '$lib/components/app/chat/ChatForm/ChatForm.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
initialMessage?: string;
|
||||
isLoading?: boolean;
|
||||
onFileRemove?: (fileId: string) => void;
|
||||
onFileUpload?: (files: File[]) => void;
|
||||
onSend?: (message: string, files?: ChatUploadedFile[]) => Promise<boolean>;
|
||||
onStop?: () => void;
|
||||
onSystemPromptAdd?: (draft: { message: string; files: ChatUploadedFile[] }) => void;
|
||||
showHelperText?: boolean;
|
||||
uploadedFiles?: ChatUploadedFile[];
|
||||
}
|
||||
|
||||
let {
|
||||
class: className,
|
||||
disabled = false,
|
||||
initialMessage = '',
|
||||
isLoading = false,
|
||||
onFileRemove,
|
||||
onFileUpload,
|
||||
onSend,
|
||||
onStop,
|
||||
onSystemPromptAdd,
|
||||
showHelperText = true,
|
||||
uploadedFiles = $bindable([])
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative mx-auto max-w-[48rem]">
|
||||
<ChatForm
|
||||
class={className}
|
||||
{disabled}
|
||||
{initialMessage}
|
||||
{isLoading}
|
||||
{onFileRemove}
|
||||
{onFileUpload}
|
||||
{onSend}
|
||||
{onStop}
|
||||
{onSystemPromptAdd}
|
||||
{showHelperText}
|
||||
bind:uploadedFiles
|
||||
/>
|
||||
</div>
|
||||
@@ -18,19 +18,24 @@
|
||||
} from '$lib/components/app';
|
||||
import { ScrollArea } from '$lib/components/ui/scroll-area';
|
||||
import { config, settingsStore } from '$lib/stores/settings.svelte';
|
||||
import {
|
||||
SETTINGS_SECTION_TITLES,
|
||||
type SettingsSectionTitle
|
||||
} from '$lib/constants/settings-sections';
|
||||
import { setMode } from 'mode-watcher';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
onSave?: () => void;
|
||||
initialSection?: SettingsSectionTitle;
|
||||
}
|
||||
|
||||
let { onSave }: Props = $props();
|
||||
let { onSave, initialSection }: Props = $props();
|
||||
|
||||
const settingSections: Array<{
|
||||
fields: SettingsFieldConfig[];
|
||||
icon: Component;
|
||||
title: string;
|
||||
title: SettingsSectionTitle;
|
||||
}> = [
|
||||
{
|
||||
title: 'General',
|
||||
@@ -285,7 +290,9 @@
|
||||
// }
|
||||
];
|
||||
|
||||
let activeSection = $state('General');
|
||||
let activeSection = $derived<SettingsSectionTitle>(
|
||||
initialSection ?? SETTINGS_SECTION_TITLES.GENERAL
|
||||
);
|
||||
let currentSection = $derived(
|
||||
settingSections.find((section) => section.title === activeSection) || settingSections[0]
|
||||
);
|
||||
@@ -295,6 +302,16 @@
|
||||
let canScrollRight = $state(false);
|
||||
let scrollContainer: HTMLDivElement | undefined = $state();
|
||||
|
||||
$effect(() => {
|
||||
if (!initialSection) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (settingSections.some((section) => section.title === initialSection)) {
|
||||
activeSection = initialSection;
|
||||
}
|
||||
});
|
||||
|
||||
function handleThemeChange(newTheme: string) {
|
||||
localConfig.theme = newTheme;
|
||||
|
||||
|
||||
@@ -142,7 +142,7 @@
|
||||
{
|
||||
icon: Download,
|
||||
label: 'Export',
|
||||
onclick: (e) => {
|
||||
onclick: (e: Event) => {
|
||||
e.stopPropagation();
|
||||
conversationsStore.downloadConversation(conversation.id);
|
||||
},
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import { rehypeRestoreTableHtml } from '$lib/markdown/table-html-restorer';
|
||||
import { rehypeEnhanceLinks } from '$lib/markdown/enhance-links';
|
||||
import { rehypeEnhanceCodeBlocks } from '$lib/markdown/enhance-code-blocks';
|
||||
import { rehypeResolveAttachmentImages } from '$lib/markdown/resolve-attachment-images';
|
||||
import { remarkLiteralHtml } from '$lib/markdown/literal-html';
|
||||
import { copyCodeToClipboard, preprocessLaTeX, getImageErrorFallbackHtml } from '$lib/utils';
|
||||
import {
|
||||
@@ -23,6 +24,7 @@
|
||||
DATA_ERROR_HANDLED_ATTR,
|
||||
BOOL_TRUE_STRING
|
||||
} from '$lib/constants/markdown';
|
||||
import { UrlPrefix } from '$lib/enums';
|
||||
import { FileTypeText } from '$lib/enums/files';
|
||||
import {
|
||||
highlightCode,
|
||||
@@ -33,8 +35,7 @@
|
||||
import githubDarkCss from 'highlight.js/styles/github-dark.css?inline';
|
||||
import githubLightCss from 'highlight.js/styles/github.css?inline';
|
||||
import { mode } from 'mode-watcher';
|
||||
import ActionIconsCodeBlock from '$lib/components/app/actions/ActionIconsCodeBlock.svelte';
|
||||
import DialogCodePreview from '$lib/components/app/misc/CodePreviewDialog.svelte';
|
||||
import { ActionIconsCodeBlock, DialogCodePreview } from '$lib/components/app';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import type { DatabaseMessageExtra } from '$lib/types/database';
|
||||
|
||||
@@ -100,6 +101,7 @@
|
||||
.use(rehypeRestoreTableHtml) // Restore limited HTML (e.g., <br>, <ul>) inside Markdown tables
|
||||
.use(rehypeEnhanceLinks) // Add target="_blank" to links
|
||||
.use(rehypeEnhanceCodeBlocks) // Wrap code blocks with header and actions
|
||||
.use(rehypeResolveAttachmentImages, { attachments })
|
||||
.use(rehypeStringify, { allowDangerousHtml: true }); // Convert to HTML string
|
||||
});
|
||||
|
||||
@@ -500,7 +502,10 @@
|
||||
if (!img || !img.src) return;
|
||||
|
||||
// Don't handle data URLs or already-handled images
|
||||
if (img.src.startsWith('data:') || img.dataset[DATA_ERROR_HANDLED_ATTR] === BOOL_TRUE_STRING)
|
||||
if (
|
||||
img.src.startsWith(UrlPrefix.DATA) ||
|
||||
img.dataset[DATA_ERROR_HANDLED_ATTR] === BOOL_TRUE_STRING
|
||||
)
|
||||
return;
|
||||
img.dataset[DATA_ERROR_HANDLED_ATTR] = BOOL_TRUE_STRING;
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
<script lang="ts">
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import { AlertTriangle, TimerOff } from '@lucide/svelte';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
type: 'timeout' | 'server';
|
||||
type: ErrorDialogType;
|
||||
message: string;
|
||||
contextInfo?: { n_prompt_tokens: number; n_ctx: number };
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
@@ -12,7 +13,7 @@
|
||||
|
||||
let { open = $bindable(), type, message, contextInfo, onOpenChange }: Props = $props();
|
||||
|
||||
const isTimeout = $derived(type === 'timeout');
|
||||
const isTimeout = $derived(type === ErrorDialogType.TIMEOUT);
|
||||
const title = $derived(isTimeout ? 'TCP Timeout' : 'Server Error');
|
||||
const description = $derived(
|
||||
isTimeout
|
||||
@@ -58,7 +59,12 @@
|
||||
<span class="font-medium">Prompt tokens:</span>
|
||||
{contextInfo.n_prompt_tokens.toLocaleString()}
|
||||
</p>
|
||||
<p><span class="font-medium">Context size:</span> {contextInfo.n_ctx.toLocaleString()}</p>
|
||||
{#if contextInfo.n_ctx}
|
||||
<p>
|
||||
<span class="font-medium">Context size:</span>
|
||||
{contextInfo.n_ctx.toLocaleString()}
|
||||
</p>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import { ChatSettings } from '$lib/components/app';
|
||||
import type { SettingsSectionTitle } from '$lib/constants/settings-sections';
|
||||
|
||||
interface Props {
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
open?: boolean;
|
||||
initialSection?: SettingsSectionTitle;
|
||||
}
|
||||
|
||||
let { onOpenChange, open = false }: Props = $props();
|
||||
let { onOpenChange, open = false, initialSection }: Props = $props();
|
||||
|
||||
let chatSettingsRef: ChatSettings | undefined = $state();
|
||||
|
||||
@@ -28,10 +30,9 @@
|
||||
|
||||
<Dialog.Root {open} onOpenChange={handleClose}>
|
||||
<Dialog.Content
|
||||
class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] flex-col gap-0 rounded-none p-0
|
||||
md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
|
||||
style="max-width: 48rem;"
|
||||
class="z-999999 flex h-[100dvh] max-h-[100dvh] min-h-[100dvh] max-w-4xl! flex-col gap-0 rounded-none
|
||||
p-0 md:h-[64vh] md:max-h-[64vh] md:min-h-0 md:rounded-lg"
|
||||
>
|
||||
<ChatSettings bind:this={chatSettingsRef} onSave={handleSave} />
|
||||
<ChatSettings bind:this={chatSettingsRef} onSave={handleSave} {initialSection} />
|
||||
</Dialog.Content>
|
||||
</Dialog.Root>
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
<iframe
|
||||
bind:this={iframeRef}
|
||||
title="Preview {language}"
|
||||
sandbox="allow-scripts"
|
||||
sandbox="allow-scripts allow-same-origin"
|
||||
class="code-preview-iframe"
|
||||
></iframe>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
<script lang="ts">
|
||||
import * as AlertDialog from '$lib/components/ui/alert-dialog';
|
||||
import type { Component } from 'svelte';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
|
||||
interface Props {
|
||||
open: boolean;
|
||||
@@ -29,7 +30,7 @@
|
||||
}: Props = $props();
|
||||
|
||||
function handleKeydown(event: KeyboardEvent) {
|
||||
if (event.key === 'Enter') {
|
||||
if (event.key === KeyboardKey.ENTER) {
|
||||
event.preventDefault();
|
||||
onConfirm();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<script lang="ts">
|
||||
import * as Dialog from '$lib/components/ui/dialog';
|
||||
import * as Table from '$lib/components/ui/table';
|
||||
import { BadgeModality, CopyToClipboardIcon } from '$lib/components/app';
|
||||
import { BadgeModality, ActionIconCopyToClipboard } from '$lib/components/app';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
import { modelsStore, modelOptions, modelsLoading } from '$lib/stores/models.svelte';
|
||||
import { formatFileSize, formatParameters, formatNumber } from '$lib/utils';
|
||||
@@ -47,6 +47,7 @@
|
||||
|
||||
<Dialog.Header>
|
||||
<Dialog.Title>Model Information</Dialog.Title>
|
||||
|
||||
<Dialog.Description>Current model details and capabilities</Dialog.Description>
|
||||
</Dialog.Header>
|
||||
|
||||
@@ -73,7 +74,7 @@
|
||||
{modelName}
|
||||
</span>
|
||||
|
||||
<CopyToClipboardIcon
|
||||
<ActionIconCopyToClipboard
|
||||
text={modelName || ''}
|
||||
canCopy={!!modelName}
|
||||
ariaLabel="Copy model name to clipboard"
|
||||
@@ -97,7 +98,7 @@
|
||||
{serverProps.model_path}
|
||||
</span>
|
||||
|
||||
<CopyToClipboardIcon
|
||||
<ActionIconCopyToClipboard
|
||||
text={serverProps.model_path}
|
||||
ariaLabel="Copy model path to clipboard"
|
||||
/>
|
||||
@@ -105,17 +106,29 @@
|
||||
</Table.Row>
|
||||
|
||||
<!-- Context Size -->
|
||||
<Table.Row>
|
||||
<Table.Cell class="h-10 align-middle font-medium">Context Size</Table.Cell>
|
||||
<Table.Cell
|
||||
>{formatNumber(serverProps.default_generation_settings.n_ctx)} tokens</Table.Cell
|
||||
>
|
||||
</Table.Row>
|
||||
{#if serverProps?.default_generation_settings?.n_ctx}
|
||||
<Table.Row>
|
||||
<Table.Cell class="h-10 align-middle font-medium">Context Size</Table.Cell>
|
||||
|
||||
<Table.Cell
|
||||
>{formatNumber(serverProps.default_generation_settings.n_ctx)} tokens</Table.Cell
|
||||
>
|
||||
</Table.Row>
|
||||
{:else}
|
||||
<Table.Row>
|
||||
<Table.Cell class="h-10 align-middle font-medium text-red-500"
|
||||
>Context Size</Table.Cell
|
||||
>
|
||||
|
||||
<Table.Cell class="text-red-500">Not available</Table.Cell>
|
||||
</Table.Row>
|
||||
{/if}
|
||||
|
||||
<!-- Training Context -->
|
||||
{#if modelMeta?.n_ctx_train}
|
||||
<Table.Row>
|
||||
<Table.Cell class="h-10 align-middle font-medium">Training Context</Table.Cell>
|
||||
|
||||
<Table.Cell>{formatNumber(modelMeta.n_ctx_train)} tokens</Table.Cell>
|
||||
</Table.Row>
|
||||
{/if}
|
||||
@@ -124,6 +137,7 @@
|
||||
{#if modelMeta?.size}
|
||||
<Table.Row>
|
||||
<Table.Cell class="h-10 align-middle font-medium">Model Size</Table.Cell>
|
||||
|
||||
<Table.Cell>{formatFileSize(modelMeta.size)}</Table.Cell>
|
||||
</Table.Row>
|
||||
{/if}
|
||||
@@ -132,6 +146,7 @@
|
||||
{#if modelMeta?.n_params}
|
||||
<Table.Row>
|
||||
<Table.Cell class="h-10 align-middle font-medium">Parameters</Table.Cell>
|
||||
|
||||
<Table.Cell>{formatParameters(modelMeta.n_params)}</Table.Cell>
|
||||
</Table.Row>
|
||||
{/if}
|
||||
@@ -140,6 +155,7 @@
|
||||
{#if modelMeta?.n_embd}
|
||||
<Table.Row>
|
||||
<Table.Cell class="align-middle font-medium">Embedding Size</Table.Cell>
|
||||
|
||||
<Table.Cell>{formatNumber(modelMeta.n_embd)}</Table.Cell>
|
||||
</Table.Row>
|
||||
{/if}
|
||||
@@ -148,6 +164,7 @@
|
||||
{#if modelMeta?.n_vocab}
|
||||
<Table.Row>
|
||||
<Table.Cell class="align-middle font-medium">Vocabulary Size</Table.Cell>
|
||||
|
||||
<Table.Cell>{formatNumber(modelMeta.n_vocab)} tokens</Table.Cell>
|
||||
</Table.Row>
|
||||
{/if}
|
||||
@@ -163,6 +180,7 @@
|
||||
<!-- Total Slots -->
|
||||
<Table.Row>
|
||||
<Table.Cell class="align-middle font-medium">Parallel Slots</Table.Cell>
|
||||
|
||||
<Table.Cell>{serverProps.total_slots}</Table.Cell>
|
||||
</Table.Row>
|
||||
|
||||
@@ -170,6 +188,7 @@
|
||||
{#if modalities.length > 0}
|
||||
<Table.Row>
|
||||
<Table.Cell class="align-middle font-medium">Modalities</Table.Cell>
|
||||
|
||||
<Table.Cell>
|
||||
<div class="flex flex-wrap gap-1">
|
||||
<BadgeModality {modalities} />
|
||||
@@ -181,6 +200,7 @@
|
||||
<!-- Build Info -->
|
||||
<Table.Row>
|
||||
<Table.Cell class="align-middle font-medium">Build Info</Table.Cell>
|
||||
|
||||
<Table.Cell class="align-middle font-mono text-xs"
|
||||
>{serverProps.build_info}</Table.Cell
|
||||
>
|
||||
@@ -190,6 +210,7 @@
|
||||
{#if serverProps.chat_template}
|
||||
<Table.Row>
|
||||
<Table.Cell class="align-middle font-medium">Chat Template</Table.Cell>
|
||||
|
||||
<Table.Cell class="py-10">
|
||||
<div class="max-h-120 overflow-y-auto rounded-md bg-muted p-4">
|
||||
<pre
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
<script lang="ts">
|
||||
import { Plus, Trash2 } from '@lucide/svelte';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import { autoResizeTextarea } from '$lib/utils';
|
||||
import type { KeyValuePair } from '$lib/types';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
pairs: KeyValuePair[];
|
||||
onPairsChange: (pairs: KeyValuePair[]) => void;
|
||||
keyPlaceholder?: string;
|
||||
valuePlaceholder?: string;
|
||||
addButtonLabel?: string;
|
||||
emptyMessage?: string;
|
||||
sectionLabel?: string;
|
||||
sectionLabelOptional?: boolean;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
pairs,
|
||||
onPairsChange,
|
||||
keyPlaceholder = 'Key',
|
||||
valuePlaceholder = 'Value',
|
||||
addButtonLabel = 'Add',
|
||||
emptyMessage = 'No items configured.',
|
||||
sectionLabel,
|
||||
sectionLabelOptional = true
|
||||
}: Props = $props();
|
||||
|
||||
function addPair() {
|
||||
onPairsChange([...pairs, { key: '', value: '' }]);
|
||||
}
|
||||
|
||||
function removePair(index: number) {
|
||||
onPairsChange(pairs.filter((_, i) => i !== index));
|
||||
}
|
||||
|
||||
function updatePairKey(index: number, key: string) {
|
||||
const newPairs = [...pairs];
|
||||
newPairs[index] = { ...newPairs[index], key };
|
||||
onPairsChange(newPairs);
|
||||
}
|
||||
|
||||
function updatePairValue(index: number, value: string) {
|
||||
const newPairs = [...pairs];
|
||||
newPairs[index] = { ...newPairs[index], value };
|
||||
onPairsChange(newPairs);
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class={className}>
|
||||
<div class="mb-2 flex items-center justify-between">
|
||||
{#if sectionLabel}
|
||||
<span class="text-xs font-medium">
|
||||
{sectionLabel}
|
||||
{#if sectionLabelOptional}
|
||||
<span class="text-muted-foreground">(optional)</span>
|
||||
{/if}
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="inline-flex cursor-pointer items-center gap-1 rounded-md px-1.5 py-1 text-xs text-muted-foreground hover:bg-muted hover:text-foreground"
|
||||
onclick={addPair}
|
||||
>
|
||||
<Plus class="h-3 w-3" />
|
||||
{addButtonLabel}
|
||||
</button>
|
||||
</div>
|
||||
{#if pairs.length > 0}
|
||||
<div class="space-y-3">
|
||||
{#each pairs as pair, index (index)}
|
||||
<div class="flex items-start gap-2">
|
||||
<Input
|
||||
type="text"
|
||||
placeholder={keyPlaceholder}
|
||||
value={pair.key}
|
||||
oninput={(e) => updatePairKey(index, e.currentTarget.value)}
|
||||
class="flex-1"
|
||||
/>
|
||||
|
||||
<textarea
|
||||
use:autoResizeTextarea
|
||||
placeholder={valuePlaceholder}
|
||||
value={pair.value}
|
||||
oninput={(e) => {
|
||||
updatePairValue(index, e.currentTarget.value);
|
||||
autoResizeTextarea(e.currentTarget);
|
||||
}}
|
||||
class="flex-1 resize-none rounded-md border border-input bg-transparent px-3 py-2 text-sm leading-5 placeholder:text-muted-foreground focus-visible:ring-1 focus-visible:ring-ring focus-visible:outline-none"
|
||||
rows="1"
|
||||
></textarea>
|
||||
|
||||
<button
|
||||
type="button"
|
||||
class="mt-1.5 shrink-0 cursor-pointer rounded-md p-1 text-muted-foreground hover:bg-destructive/10 hover:text-destructive"
|
||||
onclick={() => removePair(index)}
|
||||
aria-label="Remove item"
|
||||
>
|
||||
<Trash2 class="h-3.5 w-3.5" />
|
||||
</button>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{:else}
|
||||
<p class="text-xs text-muted-foreground">{emptyMessage}</p>
|
||||
{/if}
|
||||
</div>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user