mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-07 09:34:07 +00:00
Compare commits
75 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
defe2158dd | ||
|
|
7b50d589a8 | ||
|
|
3a9457df96 | ||
|
|
fa4a9f2a1c | ||
|
|
238005c2dc | ||
|
|
66aba7aca9 | ||
|
|
f1f5e82df6 | ||
|
|
af3373f1ad | ||
|
|
5d5c066de8 | ||
|
|
40bfa04c95 | ||
|
|
aa064b2eb7 | ||
|
|
aa0ef5c578 | ||
|
|
bb16041cae | ||
|
|
58cba76a9a | ||
|
|
67ae5312e2 | ||
|
|
692e3cdd0a | ||
|
|
b23fa0b3f4 | ||
|
|
06cbedfca1 | ||
|
|
b7147673f2 | ||
|
|
d860dd99a4 | ||
|
|
c959f462a0 | ||
|
|
22015b2092 | ||
|
|
dd6e6d0b6a | ||
|
|
8308f98c7f | ||
|
|
6369be0735 | ||
|
|
88fc854b4b | ||
|
|
e28c1b93fd | ||
|
|
d27b3ca175 | ||
|
|
9230dbe2c7 | ||
|
|
812939a9e9 | ||
|
|
4c9fdfbe15 | ||
|
|
9eaa51e7f0 | ||
|
|
8f71d0f3e8 | ||
|
|
381174bbda | ||
|
|
d67341dc18 | ||
|
|
456af35eb7 | ||
|
|
600e3e9b50 | ||
|
|
fffcce535e | ||
|
|
5fc7856815 | ||
|
|
faed5a5f5d | ||
|
|
10bb545c5b | ||
|
|
edc4a29eff | ||
|
|
ed3290ab34 | ||
|
|
8d94713654 | ||
|
|
50d2227953 | ||
|
|
6231c5cd6d | ||
|
|
ef035803eb | ||
|
|
413977de32 | ||
|
|
95402553a5 | ||
|
|
3865cff4f5 | ||
|
|
d03172cc79 | ||
|
|
dd8e59f443 | ||
|
|
bbe98d2784 | ||
|
|
c2056ed6d4 | ||
|
|
c46503014d | ||
|
|
860a9e4eef | ||
|
|
fe9d60e74a | ||
|
|
e434e69183 | ||
|
|
89fea80d29 | ||
|
|
6adc3c3ebc | ||
|
|
0dbcabde8c | ||
|
|
ad590be98c | ||
|
|
7d6d91babf | ||
|
|
d3e64b9f49 | ||
|
|
3ba0d843c6 | ||
|
|
0bf49eb668 | ||
|
|
4ad243677b | ||
|
|
c89c2d1ab9 | ||
|
|
3555b3004b | ||
|
|
d7da8dc83a | ||
|
|
cd355eda7d | ||
|
|
30e5b01de2 | ||
|
|
e54b394082 | ||
|
|
2c2caa4443 | ||
|
|
5fce5f948d |
7
.github/workflows/build.yml
vendored
7
.github/workflows/build.yml
vendored
@@ -683,7 +683,7 @@ jobs:
|
||||
env:
|
||||
OPENBLAS_VERSION: 0.3.23
|
||||
SDE_VERSION: 9.33.0-2024-01-07
|
||||
VULKAN_VERSION: 1.4.309.0
|
||||
VULKAN_VERSION: 1.4.313.2
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -693,7 +693,7 @@ jobs:
|
||||
- build: 'openblas-x64'
|
||||
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"'
|
||||
- build: 'vulkan-x64'
|
||||
defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON'
|
||||
defines: '-DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON'
|
||||
- build: 'llvm-arm64'
|
||||
defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON'
|
||||
- build: 'llvm-arm64-opencl-adreno'
|
||||
@@ -736,7 +736,7 @@ jobs:
|
||||
id: get_vulkan
|
||||
if: ${{ matrix.build == 'kompute-x64' || matrix.build == 'vulkan-x64' }}
|
||||
run: |
|
||||
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
|
||||
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
|
||||
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
|
||||
Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
|
||||
Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"
|
||||
@@ -778,6 +778,7 @@ jobs:
|
||||
cmake -S . -B build ${{ matrix.defines }} `
|
||||
-DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include"
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS}
|
||||
cp $env:CURL_PATH/bin/libcurl-*.dll build/bin/Release
|
||||
|
||||
- name: Add libopenblas.dll
|
||||
id: add_libopenblas_dll
|
||||
|
||||
@@ -39,7 +39,7 @@ sd=`dirname $0`
|
||||
cd $sd/../
|
||||
SRC=`pwd`
|
||||
|
||||
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=OFF"
|
||||
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
|
||||
|
||||
if [ ! -z ${GG_BUILD_METAL} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON"
|
||||
@@ -779,7 +779,7 @@ function gg_run_rerank_tiny {
|
||||
model_f16="${path_models}/ggml-model-f16.gguf"
|
||||
|
||||
# for this model, the SEP token is "</s>"
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s></s>hi\nwhat is panda?</s></s>it's a bear\nwhat is panda?</s></s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
||||
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
|
||||
|
||||
# sample output
|
||||
# rerank score 0: 0.029
|
||||
|
||||
@@ -988,10 +988,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
params.tensor_buft_overrides.push_back({nullptr, nullptr});
|
||||
}
|
||||
|
||||
if (params.reranking && params.embedding) {
|
||||
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
|
||||
}
|
||||
|
||||
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
|
||||
throw std::runtime_error(string_format(
|
||||
"error: the supplied chat template is not supported: %s%s\n",
|
||||
@@ -2710,6 +2706,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.embd_sep = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
|
||||
add_opt(common_arg(
|
||||
{"--cls-separator"}, "STRING",
|
||||
"separator of classification sequences (default \\t) for example \"<#seq#>\"",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.cls_sep = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
|
||||
add_opt(common_arg(
|
||||
{"--host"}, "HOST",
|
||||
string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),
|
||||
@@ -2747,9 +2750,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
|
||||
add_opt(common_arg(
|
||||
{"--reranking", "--rerank"},
|
||||
string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
|
||||
string_format("enable reranking endpoint on server (default: %s)", "disabled"),
|
||||
[](common_params & params) {
|
||||
params.reranking = true;
|
||||
params.embedding = true;
|
||||
params.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
|
||||
add_opt(common_arg(
|
||||
@@ -3213,6 +3217,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.speculative.model.path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
|
||||
string_format(
|
||||
"KV cache data type for K for the draft model\n"
|
||||
"allowed values: %s\n"
|
||||
"(default: %s)",
|
||||
get_all_kv_cache_types().c_str(),
|
||||
ggml_type_name(params.speculative.cache_type_k)
|
||||
),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.cache_type_k = kv_cache_type_from_str(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
{"-ctvd", "--cache-type-v-draft"}, "TYPE",
|
||||
string_format(
|
||||
"KV cache data type for V for the draft model\n"
|
||||
"allowed values: %s\n"
|
||||
"(default: %s)",
|
||||
get_all_kv_cache_types().c_str(),
|
||||
ggml_type_name(params.speculative.cache_type_v)
|
||||
),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.cache_type_v = kv_cache_type_from_str(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_CACHE_TYPE_V_DRAFT"));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"-mv", "--model-vocoder"}, "FNAME",
|
||||
|
||||
@@ -1838,7 +1838,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
if (res < 0) {
|
||||
// if the custom "tmpl" is not supported, we throw an error
|
||||
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||
throw std::runtime_error("this custom template is not supported");
|
||||
throw std::runtime_error("this custom template is not supported, try using --jinja");
|
||||
}
|
||||
|
||||
// if it turns out that our buffer is too small, we resize it
|
||||
|
||||
@@ -706,11 +706,17 @@ bool fs_validate_filename(const std::string & filename) {
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
filename_utf32 = converter.from_bytes(filename);
|
||||
@@ -767,6 +773,9 @@ bool fs_validate_filename(const std::string & filename) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#include <iostream>
|
||||
|
||||
|
||||
// returns true if successful, false otherwise
|
||||
bool fs_create_directory_with_parents(const std::string & path) {
|
||||
#ifdef _WIN32
|
||||
@@ -784,9 +793,16 @@ bool fs_create_directory_with_parents(const std::string & path) {
|
||||
// process path from front to back, procedurally creating directories
|
||||
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
|
||||
const std::wstring subpath = wpath.substr(0, pos_slash);
|
||||
const wchar_t * test = subpath.c_str();
|
||||
|
||||
const bool success = CreateDirectoryW(test, NULL);
|
||||
pos_slash += 1;
|
||||
|
||||
// skip the drive letter, in some systems it can return an access denied error
|
||||
if (subpath.length() == 2 && subpath[1] == ':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
const bool success = CreateDirectoryW(subpath.c_str(), NULL);
|
||||
|
||||
if (!success) {
|
||||
const DWORD error = GetLastError();
|
||||
|
||||
@@ -800,8 +816,6 @@ bool fs_create_directory_with_parents(const std::string & path) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
pos_slash += 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -897,34 +911,6 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.reranking) {
|
||||
bool ok = true;
|
||||
|
||||
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
|
||||
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
|
||||
|
||||
if (!has_eos && !has_sep) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
} else if (!has_eos) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||
} else if (!has_sep) {
|
||||
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
}
|
||||
}
|
||||
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
@@ -966,6 +952,35 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||
}
|
||||
}
|
||||
|
||||
if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
|
||||
bool ok = true;
|
||||
|
||||
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
|
||||
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
|
||||
|
||||
if (!has_eos && !has_sep) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
} else if (!has_eos) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||
} else if (!has_sep) {
|
||||
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||
ok = false;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
}
|
||||
}
|
||||
|
||||
// load and optionally apply lora adapters
|
||||
for (auto & la : params.lora_adapters) {
|
||||
llama_adapter_lora_ptr lora;
|
||||
@@ -1143,11 +1158,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
||||
cparams.op_offload = !params.no_op_offload;
|
||||
cparams.swa_full = params.swa_full;
|
||||
|
||||
if (params.reranking) {
|
||||
cparams.embeddings = true;
|
||||
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
||||
}
|
||||
|
||||
cparams.type_k = params.cache_type_k;
|
||||
cparams.type_v = params.cache_type_v;
|
||||
|
||||
@@ -1280,6 +1290,9 @@ std::vector<llama_token> common_tokenize(
|
||||
int n_tokens = text.length() + 2 * add_special;
|
||||
std::vector<llama_token> result(n_tokens);
|
||||
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
||||
if (n_tokens == std::numeric_limits<int32_t>::min()) {
|
||||
throw std::runtime_error("Tokenization failed: input text too large, tokenization result exceeds int32_t limit");
|
||||
}
|
||||
if (n_tokens < 0) {
|
||||
result.resize(-n_tokens);
|
||||
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
||||
|
||||
@@ -199,6 +199,9 @@ struct common_params_speculative {
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
|
||||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
|
||||
@@ -355,7 +358,7 @@ struct common_params {
|
||||
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
||||
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
||||
std::string embd_sep = "\n"; // separator of embeddings
|
||||
bool reranking = false; // enable reranking support on server
|
||||
std::string cls_sep = "\t"; // separator of classification sequences
|
||||
|
||||
// server params
|
||||
int32_t port = 8080; // server listens on this network port
|
||||
|
||||
@@ -41,49 +41,6 @@ static std::string build_repetition(const std::string & item_rule, int min_items
|
||||
return result;
|
||||
}
|
||||
|
||||
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
|
||||
class string_view {
|
||||
const std::string & _str;
|
||||
const size_t _start;
|
||||
const size_t _end;
|
||||
public:
|
||||
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
|
||||
|
||||
size_t size() const {
|
||||
return _end - _start;
|
||||
}
|
||||
|
||||
size_t length() const {
|
||||
return size();
|
||||
}
|
||||
|
||||
operator std::string() const {
|
||||
return str();
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
return _str.substr(_start, _end - _start);
|
||||
}
|
||||
|
||||
string_view substr(size_t pos, size_t len = std::string::npos) const {
|
||||
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
|
||||
}
|
||||
|
||||
char operator[](size_t pos) const {
|
||||
auto index = _start + pos;
|
||||
if (index >= _end) {
|
||||
throw std::out_of_range("string_view index out of range");
|
||||
}
|
||||
return _str[_start + pos];
|
||||
}
|
||||
|
||||
bool operator==(const string_view & other) const {
|
||||
std::string this_str = *this;
|
||||
std::string other_str = other;
|
||||
return this_str == other_str;
|
||||
}
|
||||
};
|
||||
|
||||
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
||||
auto has_min = min_value != std::numeric_limits<int>::min();
|
||||
auto has_max = max_value != std::numeric_limits<int>::max();
|
||||
@@ -112,14 +69,14 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
||||
}
|
||||
out << "}";
|
||||
};
|
||||
std::function<void(const string_view &, const string_view &)> uniform_range =
|
||||
[&](const string_view & from, const string_view & to) {
|
||||
std::function<void(const std::string_view &, const std::string_view &)> uniform_range =
|
||||
[&](const std::string_view & from, const std::string_view & to) {
|
||||
size_t i = 0;
|
||||
while (i < from.length() && i < to.length() && from[i] == to[i]) {
|
||||
i++;
|
||||
}
|
||||
if (i > 0) {
|
||||
out << "\"" << from.substr(0, i).str() << "\"";
|
||||
out << "\"" << from.substr(0, i) << "\"";
|
||||
}
|
||||
if (i < from.length() && i < to.length()) {
|
||||
if (i > 0) {
|
||||
|
||||
@@ -519,7 +519,7 @@ class TextModel(ModelBase):
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
|
||||
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions"], optional=True)) is not None:
|
||||
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None:
|
||||
self.gguf_writer.add_context_length(n_ctx)
|
||||
logger.info(f"gguf: context length = {n_ctx}")
|
||||
|
||||
@@ -556,11 +556,8 @@ class TextModel(ModelBase):
|
||||
logger.info(f"gguf: experts used count = {n_experts_used}")
|
||||
|
||||
if (head_dim := self.hparams.get("head_dim")) is not None:
|
||||
# Workaround for incorrect AutoConfig value for DeepSeekV3 (is set correctly in DeepSeekV2Model class)
|
||||
# https://github.com/huggingface/transformers/blob/19224c3642705c5b6988c9f5f4251f83323d05ae/src/transformers/models/deepseek_v3/configuration_deepseek_v3.py#L210
|
||||
if self.hparams.get("model_type") != "deepseek_v3":
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
self.gguf_writer.add_value_length(head_dim)
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
self.gguf_writer.add_value_length(head_dim)
|
||||
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
logger.info(f"gguf: file type = {self.ftype}")
|
||||
@@ -1901,9 +1898,7 @@ class LlamaModel(TextModel):
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
|
||||
if "head_dim" in hparams:
|
||||
rope_dim = hparams["head_dim"]
|
||||
else:
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
|
||||
@@ -1985,7 +1980,8 @@ class LlamaModel(TextModel):
|
||||
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
|
||||
if rope_scaling.get("rope_type", '').lower() == "llama3":
|
||||
base = self.hparams.get("rope_theta", 10000.0)
|
||||
dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
if (dim := self.hparams.get("head_dim")) is None:
|
||||
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
|
||||
factor = rope_scaling.get("factor", 8.0)
|
||||
@@ -2020,6 +2016,20 @@ class LlamaModel(TextModel):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("ArceeForCausalLM")
|
||||
class ArceeModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.ARCEE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self._try_set_pooling_type()
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"LlavaForConditionalGeneration", # pixtral
|
||||
"Mistral3ForConditionalGeneration", # mistral small 3.1
|
||||
@@ -2135,7 +2145,6 @@ class Llama4Model(LlamaModel):
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2()
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
@@ -2184,7 +2193,7 @@ class Llama4VisionModel(MmprojModel):
|
||||
name += ".weight"
|
||||
if "multi_modal_projector.linear_1" in name:
|
||||
# despite the name with number postfix, this is a single fully connected layer
|
||||
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)]
|
||||
return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)]
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
return []
|
||||
|
||||
@@ -2307,9 +2316,7 @@ class DeciModel(TextModel):
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
|
||||
if "head_dim" in hparams:
|
||||
rope_dim = hparams["head_dim"]
|
||||
else:
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
|
||||
@@ -2349,7 +2356,8 @@ class DeciModel(TextModel):
|
||||
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
|
||||
if rope_scaling.get("rope_type", '').lower() == "llama3":
|
||||
base = self.hparams.get("rope_theta", 10000.0)
|
||||
dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
if (dim := self.hparams.get("head_dim")) is None:
|
||||
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
|
||||
factor = rope_scaling.get("factor", 8.0)
|
||||
@@ -3667,9 +3675,7 @@ class InternLM3Model(TextModel):
|
||||
hparams = self.hparams
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
|
||||
if "head_dim" in hparams:
|
||||
rope_dim = hparams["head_dim"]
|
||||
else:
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
|
||||
@@ -3911,9 +3917,6 @@ class BertModel(TextModel):
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
|
||||
@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
|
||||
class DistilBertModel(BertModel):
|
||||
@@ -3955,8 +3958,6 @@ class RobertaModel(BertModel):
|
||||
bpe_tok_path = self.dir_model / "tokenizer.json"
|
||||
if bpe_tok_path.exists():
|
||||
self._set_vocab_gpt2()
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
# we need this to validate the size of the token_type embeddings
|
||||
# though currently we are passing all zeros to the token_type embeddings
|
||||
@@ -4062,6 +4063,34 @@ class NomicBertModel(BertModel):
|
||||
raise ValueError(f"unknown tokenizer: {toktyp}")
|
||||
|
||||
|
||||
@ModelBase.register("NeoBERT", "NeoBERTLMHead", "NeoBERTForSequenceClassification")
|
||||
class NeoBert(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.NEO_BERT
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
# NeoBERT uses 2/3 of the intermediate size as feed forward length
|
||||
self.gguf_writer.add_feed_forward_length(int(2 * self.hparams["intermediate_size"] / 3))
|
||||
self.gguf_writer.add_rope_freq_base(10000.0) # default value for NeoBERT
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
|
||||
f_rms_eps = self.hparams.get("norm_eps", 1e-6) # default value for NeoBERT
|
||||
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
|
||||
logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
|
||||
|
||||
self.gguf_writer.add_pooling_type(gguf.PoolingType.CLS) # https://huggingface.co/chandar-lab/NeoBERT#how-to-use
|
||||
|
||||
def modify_tensors(self, data_torch, name, bid):
|
||||
if name.startswith("decoder."):
|
||||
return []
|
||||
|
||||
if name.startswith("model."):
|
||||
name = name[6:]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
|
||||
class XLMRobertaModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
@@ -4813,8 +4842,6 @@ class JinaBertV2Model(BertModel):
|
||||
self.gguf_writer.add_token_type_count(2)
|
||||
else:
|
||||
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
|
||||
@ModelBase.register("OpenELMForCausalLM")
|
||||
@@ -5056,9 +5083,7 @@ class DeepseekModel(TextModel):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
if "head_dim" in hparams:
|
||||
rope_dim = hparams["head_dim"]
|
||||
else:
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
@@ -5418,9 +5443,6 @@ class T5Model(TextModel):
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
|
||||
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
|
||||
@@ -5558,9 +5580,6 @@ class T5EncoderModel(TextModel):
|
||||
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
|
||||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
self.gguf_writer.add_add_bos_token(False)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
|
||||
logger.warning("Couldn't find context length in config.json, assuming default value of 512")
|
||||
@@ -5948,7 +5967,8 @@ class ExaoneModel(TextModel):
|
||||
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
|
||||
if rope_scaling.get("rope_type", '').lower() == "llama3":
|
||||
base = self.hparams.get("rope_theta", 10000.0)
|
||||
dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
if (dim := self.hparams.get("head_dim")) is None:
|
||||
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
|
||||
|
||||
factor = rope_scaling.get("factor", 8.0)
|
||||
@@ -6060,7 +6080,8 @@ class BailingMoeModel(TextModel):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(rope_dim)
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
@@ -6092,7 +6113,8 @@ class BailingMoeModel(TextModel):
|
||||
n_head = self.hparams["num_attention_heads"]
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
n_embd = self.hparams["hidden_size"]
|
||||
head_dim = self.hparams.get("head_dim") or n_embd // n_head
|
||||
if (head_dim := self.hparams.get("head_dim")) is None:
|
||||
head_dim = n_embd // n_head
|
||||
|
||||
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
|
||||
|
||||
@@ -6353,8 +6375,8 @@ def parse_args() -> argparse.Namespace:
|
||||
help="model is executed on big endian machine",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model", type=Path,
|
||||
help="directory containing model file",
|
||||
"model", type=str,
|
||||
help="directory containing model file or huggingface repository ID (if --remote)",
|
||||
nargs="?",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -6457,18 +6479,20 @@ def main() -> None:
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
dir_model = args.model
|
||||
|
||||
if args.remote:
|
||||
hf_repo_id = args.model
|
||||
from huggingface_hub import snapshot_download
|
||||
local_dir = snapshot_download(
|
||||
repo_id=str(dir_model),
|
||||
repo_id=hf_repo_id,
|
||||
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
|
||||
dir_model = Path(local_dir)
|
||||
logger.info(f"Downloaded config and tokenizer to {local_dir}")
|
||||
else:
|
||||
hf_repo_id = None
|
||||
dir_model = Path(args.model)
|
||||
|
||||
if not dir_model.is_dir():
|
||||
logger.error(f'Error: {args.model} is not a directory')
|
||||
logger.error(f'Error: {dir_model} is not a directory')
|
||||
sys.exit(1)
|
||||
|
||||
ftype_map: dict[str, gguf.LlamaFileType] = {
|
||||
@@ -6488,9 +6512,9 @@ def main() -> None:
|
||||
|
||||
if args.outfile is not None:
|
||||
fname_out = args.outfile
|
||||
elif args.remote:
|
||||
elif hf_repo_id:
|
||||
# if remote, use the model ID as the output file name
|
||||
fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
|
||||
fname_out = Path("./" + hf_repo_id.replace("/", "-") + "-{ftype}.gguf")
|
||||
else:
|
||||
fname_out = dir_model
|
||||
|
||||
@@ -6519,7 +6543,7 @@ def main() -> None:
|
||||
split_max_tensors=args.split_max_tensors,
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
remote_hf_model_id=str(args.model) if args.remote else None)
|
||||
remote_hf_model_id=hf_repo_id)
|
||||
|
||||
if args.vocab_only:
|
||||
logger.info("Exporting model vocab...")
|
||||
|
||||
157
docs/build-s390x.md
Normal file
157
docs/build-s390x.md
Normal file
@@ -0,0 +1,157 @@
|
||||
> [!IMPORTANT]
|
||||
> This build documentation is specific only to IBM Z & LinuxONE mainframes (s390x). You can find the build documentation for other architectures: [build.md](build.md).
|
||||
|
||||
# Build llama.cpp locally (for s390x)
|
||||
|
||||
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).
|
||||
|
||||
The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.
|
||||
|
||||
**To get the code:**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
## CPU Build with BLAS
|
||||
|
||||
Building llama.cpp with BLAS support is highly recommended as it has shown to provide performance improvements.
|
||||
|
||||
```bash
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_BLAS=ON \
|
||||
-DGGML_BLAS_VENDOR=OpenBLAS
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
```
|
||||
|
||||
**Notes**:
|
||||
- For faster repeated compilation, install [ccache](https://ccache.dev/)
|
||||
- By default, VXE/VXE2 is enabled. To disable it (not recommended):
|
||||
|
||||
```bash
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_BLAS=ON \
|
||||
-DGGML_BLAS_VENDOR=OpenBLAS \
|
||||
-DGGML_VXE=OFF
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
```
|
||||
|
||||
- For debug builds:
|
||||
|
||||
```bash
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DGGML_BLAS=ON \
|
||||
-DGGML_BLAS_VENDOR=OpenBLAS
|
||||
|
||||
cmake --build build --config Debug -j $(nproc)
|
||||
```
|
||||
|
||||
- For static builds, add `-DBUILD_SHARED_LIBS=OFF`:
|
||||
|
||||
```bash
|
||||
cmake -S . -B build \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_BLAS=ON \
|
||||
-DGGML_BLAS_VENDOR=OpenBLAS \
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
|
||||
cmake --build build --config Release -j $(nproc)
|
||||
```
|
||||
|
||||
## Getting GGUF Models
|
||||
|
||||
All models need to be converted to Big-Endian. You can achieve this in three cases:
|
||||
|
||||
1. **Use pre-converted models verified for use on IBM Z & LinuxONE (easiest)**
|
||||
|
||||
You can find popular models pre-converted and verified at [s390x Ready Models](hf.co/collections/taronaeo/s390x-ready-models-672765393af438d0ccb72a08).
|
||||
|
||||
These models and their respective tokenizers are verified to run correctly on IBM Z & LinuxONE.
|
||||
|
||||
2. **Convert safetensors model to GGUF Big-Endian directly (recommended)**
|
||||
|
||||
```bash
|
||||
python3 convert_hf_to_gguf.py \
|
||||
--outfile model-name-be.f16.gguf \
|
||||
--outtype f16 \
|
||||
--bigendian \
|
||||
model-directory/
|
||||
```
|
||||
|
||||
For example,
|
||||
|
||||
```bash
|
||||
python3 convert_hf_to_gguf.py \
|
||||
--outfile granite-3.3-2b-instruct-be.f16.gguf \
|
||||
--outtype f16 \
|
||||
--bigendian \
|
||||
granite-3.3-2b-instruct/
|
||||
```
|
||||
|
||||
3. **Convert existing GGUF Little-Endian model to Big-Endian**
|
||||
|
||||
```bash
|
||||
python3 gguf-py/gguf/scripts/gguf_convert_endian.py model-name.f16.gguf BIG
|
||||
```
|
||||
|
||||
For example,
|
||||
```bash
|
||||
python3 gguf-py/gguf/scripts/gguf_convert_endian.py granite-3.3-2b-instruct-le.f16.gguf BIG
|
||||
mv granite-3.3-2b-instruct-le.f16.gguf granite-3.3-2b-instruct-be.f16.gguf
|
||||
```
|
||||
|
||||
**Notes:**
|
||||
- The GGUF endian conversion script may not support all data types at the moment and may fail for some models/quantizations. When that happens, please try manually converting the safetensors model to GGUF Big-Endian via Step 2.
|
||||
|
||||
## IBM Accelerators
|
||||
|
||||
### 1. SIMD Acceleration
|
||||
|
||||
Only available in IBM z15 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14 or EC13. In such systems, the APIs can still run but will use a scalar implementation.
|
||||
|
||||
### 2. zDNN Accelerator
|
||||
|
||||
*Only available in IBM z16 or later system. No direction at the moment.*
|
||||
|
||||
### 3. Spyre Accelerator
|
||||
|
||||
*No direction at the moment.*
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### 1. Virtualization Setup
|
||||
|
||||
It is strongly recommended to use only LPAR (Type-1) virtualization to get the most performance.
|
||||
|
||||
Note: Type-2 virtualization is not supported at the moment, while you can get it running, the performance will not be the best.
|
||||
|
||||
### 2. IFL (Core) Count
|
||||
|
||||
It is recommended to allocate a minimum of 8 shared IFLs assigned to the LPAR. Increasing the IFL count past 8 shared IFLs will only improve Prompt Processing performance but not Token Generation.
|
||||
|
||||
Note: IFL count does not equate to vCPU count.
|
||||
|
||||
### 3. SMT vs NOSMT (Simultaneous Multithreading)
|
||||
|
||||
It is strongly recommended to disable SMT via the kernel boot parameters as it negatively affects performance. Please refer to your Linux distribution's guide on disabling SMT via kernel boot parameters.
|
||||
|
||||
### 4. BLAS vs NOBLAS
|
||||
|
||||
IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongly recommended to use BLAS.
|
||||
|
||||
## Getting Help on IBM Z & LinuxONE
|
||||
|
||||
1. **Bugs, Feature Requests**
|
||||
|
||||
Please file an issue in llama.cpp and ensure that the title contains "s390x".
|
||||
|
||||
2. **Other Questions**
|
||||
|
||||
Please reach out directly to [aionz@us.ibm.com](mailto:aionz@us.ibm.com).
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Build llama.cpp locally
|
||||
|
||||
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h).
|
||||
The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h).
|
||||
|
||||
The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server.
|
||||
|
||||
|
||||
@@ -133,10 +133,36 @@ int main(int argc, char ** argv) {
|
||||
// max batch size
|
||||
const uint64_t n_batch = params.n_batch;
|
||||
|
||||
// get added sep and eos token, if any
|
||||
const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
|
||||
const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
|
||||
|
||||
// tokenize the prompts and trim
|
||||
std::vector<std::vector<int32_t>> inputs;
|
||||
for (const auto & prompt : prompts) {
|
||||
auto inp = common_tokenize(ctx, prompt, true, true);
|
||||
std::vector<llama_token> inp;
|
||||
|
||||
// split classification pairs and insert expected separator tokens
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
|
||||
std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
|
||||
std::string final_prompt;
|
||||
|
||||
for (size_t i = 0; i < pairs.size(); i++) {
|
||||
final_prompt += pairs[i];
|
||||
if (i != pairs.size() - 1) {
|
||||
if (!added_eos_token.empty()) {
|
||||
final_prompt += added_eos_token;
|
||||
}
|
||||
if (!added_sep_token.empty()) {
|
||||
final_prompt += added_sep_token;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inp = common_tokenize(ctx, final_prompt, true, true);
|
||||
} else {
|
||||
inp = common_tokenize(ctx, prompt, true, true);
|
||||
}
|
||||
if (inp.size() > n_batch) {
|
||||
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
|
||||
__func__, (long long int) inp.size(), (long long int) n_batch);
|
||||
@@ -145,11 +171,11 @@ int main(int argc, char ** argv) {
|
||||
inputs.push_back(inp);
|
||||
}
|
||||
|
||||
// check if the last token is SEP
|
||||
// check if the last token is SEP/EOS
|
||||
// it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
|
||||
for (auto & inp : inputs) {
|
||||
if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
|
||||
LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
|
||||
if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
|
||||
LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
|
||||
LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,12 +41,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
||||
|
||||
// add input to batch (this increments n_tokens)
|
||||
for (int32_t j = 0; j < n_toks; j++) {
|
||||
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
|
||||
common_batch_add(batch, inputs[j], j, { 0 }, true);
|
||||
}
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
llama_set_embeddings(ctx, true);
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
// run model
|
||||
@@ -103,7 +102,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
|
||||
llama_token eos_token = llama_vocab_eos(vocab);
|
||||
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
llama_set_embeddings(ctx, false);
|
||||
llama_set_causal_attn(ctx, true);
|
||||
|
||||
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
|
||||
@@ -166,6 +164,8 @@ int main(int argc, char * argv[]) {
|
||||
llama_model_params mparams = common_model_params_to_llama(params);
|
||||
llama_context_params cparams = common_context_params_to_llama(params);
|
||||
|
||||
cparams.embeddings = true;
|
||||
|
||||
llama_backend_init();
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
@@ -213,6 +213,8 @@ int main(int argc, char * argv[]) {
|
||||
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
|
||||
}
|
||||
|
||||
llama_set_embeddings(ctx, false);
|
||||
|
||||
// ### Generation ###
|
||||
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
|
||||
{
|
||||
|
||||
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
|
||||
auto generate = [&](const std::string & prompt) {
|
||||
std::string response;
|
||||
|
||||
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;
|
||||
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1;
|
||||
|
||||
// tokenize the prompt
|
||||
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
|
||||
|
||||
@@ -172,6 +172,7 @@ option(GGML_HIP "ggml: use HIP"
|
||||
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
|
||||
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
|
||||
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
|
||||
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
|
||||
option(GGML_VULKAN "ggml: use Vulkan" OFF)
|
||||
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
|
||||
option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF)
|
||||
@@ -367,6 +368,8 @@ if (MSVC)
|
||||
/wd4005 # Macro redefinition
|
||||
/wd4244 # Conversion from one type to another type, possible loss of data
|
||||
/wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data
|
||||
/wd4305 # Conversion from 'type1' to 'type2', possible loss of data
|
||||
/wd4566 # Conversion from 'char' to 'wchar_t', possible loss of data
|
||||
/wd4996 # Disable POSIX deprecation warnings
|
||||
/wd4702 # Unreachable code warnings
|
||||
)
|
||||
@@ -386,4 +389,46 @@ if (MSVC)
|
||||
disable_msvc_warnings(ggml-cpu-skylakex)
|
||||
disable_msvc_warnings(ggml-cpu-icelake)
|
||||
disable_msvc_warnings(ggml-cpu-alderlake)
|
||||
|
||||
if (GGML_BUILD_EXAMPLES)
|
||||
disable_msvc_warnings(common-ggml)
|
||||
disable_msvc_warnings(common)
|
||||
|
||||
disable_msvc_warnings(mnist-common)
|
||||
disable_msvc_warnings(mnist-eval)
|
||||
disable_msvc_warnings(mnist-train)
|
||||
|
||||
disable_msvc_warnings(gpt-2-ctx)
|
||||
disable_msvc_warnings(gpt-2-alloc)
|
||||
disable_msvc_warnings(gpt-2-backend)
|
||||
disable_msvc_warnings(gpt-2-sched)
|
||||
disable_msvc_warnings(gpt-2-quantize)
|
||||
disable_msvc_warnings(gpt-2-batched)
|
||||
|
||||
disable_msvc_warnings(gpt-j)
|
||||
disable_msvc_warnings(gpt-j-quantize)
|
||||
|
||||
disable_msvc_warnings(magika)
|
||||
disable_msvc_warnings(yolov3-tiny)
|
||||
disable_msvc_warnings(sam)
|
||||
|
||||
disable_msvc_warnings(simple-ctx)
|
||||
disable_msvc_warnings(simple-backend)
|
||||
endif()
|
||||
|
||||
if (GGML_BUILD_TESTS)
|
||||
disable_msvc_warnings(test-mul-mat)
|
||||
disable_msvc_warnings(test-arange)
|
||||
disable_msvc_warnings(test-backend-ops)
|
||||
disable_msvc_warnings(test-cont)
|
||||
disable_msvc_warnings(test-conv-transpose)
|
||||
disable_msvc_warnings(test-conv-transpose-1d)
|
||||
disable_msvc_warnings(test-conv1d)
|
||||
disable_msvc_warnings(test-conv2d)
|
||||
disable_msvc_warnings(test-conv2d-dw)
|
||||
disable_msvc_warnings(test-customop)
|
||||
disable_msvc_warnings(test-dup)
|
||||
disable_msvc_warnings(test-opt)
|
||||
disable_msvc_warnings(test-pool)
|
||||
endif ()
|
||||
endif()
|
||||
|
||||
@@ -36,8 +36,7 @@ function(ggml_get_system_arch)
|
||||
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
||||
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$"))
|
||||
set(GGML_SYSTEM_ARCH "x86" PARENT_SCOPE)
|
||||
elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR
|
||||
"${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ")
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc|power")
|
||||
set(GGML_SYSTEM_ARCH "PowerPC" PARENT_SCOPE)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
|
||||
set(GGML_SYSTEM_ARCH "loongarch64" PARENT_SCOPE)
|
||||
|
||||
@@ -489,6 +489,7 @@ extern "C" {
|
||||
GGML_OP_UPSCALE, // nearest interpolate
|
||||
GGML_OP_PAD,
|
||||
GGML_OP_PAD_REFLECT_1D,
|
||||
GGML_OP_ROLL,
|
||||
GGML_OP_ARANGE,
|
||||
GGML_OP_TIMESTEP_EMBEDDING,
|
||||
GGML_OP_ARGSORT,
|
||||
@@ -1801,6 +1802,17 @@ extern "C" {
|
||||
int p0,
|
||||
int p1);
|
||||
|
||||
// Move tensor elements by an offset given for each dimension. Elements that
|
||||
// are shifted beyond the last position are wrapped around to the beginning.
|
||||
GGML_API struct ggml_tensor * ggml_roll(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int shift0,
|
||||
int shift1,
|
||||
int shift2,
|
||||
int shift3);
|
||||
|
||||
|
||||
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
|
||||
// timesteps: [N,]
|
||||
// return: [N, dim]
|
||||
|
||||
@@ -286,6 +286,10 @@ function(ggml_add_cpu_backend_variant tag_name)
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
ggml_add_cpu_backend_variant_impl(${tag_name})
|
||||
@@ -311,18 +315,45 @@ if (GGML_CPU_ALL_VARIANTS)
|
||||
# MSVC doesn't support AMX
|
||||
ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
endif()
|
||||
elseif(GGML_SYSTEM_ARCH STREQUAL "ARM" AND CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
# Many of these features are optional so we build versions with popular
|
||||
# combinations and name the backends based on the version they were
|
||||
# first released with
|
||||
ggml_add_cpu_backend_variant(armv8.0_1)
|
||||
ggml_add_cpu_backend_variant(armv8.2_1 DOTPROD)
|
||||
ggml_add_cpu_backend_variant(armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
|
||||
ggml_add_cpu_backend_variant(armv8.2_3 DOTPROD FP16_VECTOR_ARITHMETIC SVE)
|
||||
ggml_add_cpu_backend_variant(armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8)
|
||||
ggml_add_cpu_backend_variant(armv8.6_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2)
|
||||
ggml_add_cpu_backend_variant(armv9.2_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SME)
|
||||
ggml_add_cpu_backend_variant(armv9.2_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2 SME)
|
||||
elseif(GGML_SYSTEM_ARCH STREQUAL "ARM")
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
# Many of these features are optional so we build versions with popular
|
||||
# combinations and name the backends based on the version they were
|
||||
# first released with
|
||||
ggml_add_cpu_backend_variant(armv8.0_1)
|
||||
ggml_add_cpu_backend_variant(armv8.2_1 DOTPROD)
|
||||
ggml_add_cpu_backend_variant(armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
|
||||
ggml_add_cpu_backend_variant(armv8.2_3 DOTPROD FP16_VECTOR_ARITHMETIC SVE)
|
||||
ggml_add_cpu_backend_variant(armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8)
|
||||
ggml_add_cpu_backend_variant(armv8.6_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2)
|
||||
ggml_add_cpu_backend_variant(armv9.2_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SME)
|
||||
ggml_add_cpu_backend_variant(armv9.2_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2 SME)
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES "Android")
|
||||
# Android-specific backends with SoC-compatible feature sets
|
||||
ggml_add_cpu_backend_variant(android_armv8.0_1)
|
||||
ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD)
|
||||
ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
|
||||
ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
|
||||
elseif (APPLE)
|
||||
ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
|
||||
ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)
|
||||
ggml_add_cpu_backend_variant(apple_m4 DOTPROD MATMUL_INT8 NOSVE SME)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC")
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
ggml_add_cpu_backend_variant(power0)
|
||||
ggml_add_cpu_backend_variant(power7_1 POWER7)
|
||||
ggml_add_cpu_backend_variant(power7_2 POWER7 VSX)
|
||||
ggml_add_cpu_backend_variant(power8_1 POWER8)
|
||||
ggml_add_cpu_backend_variant(power8_2 POWER8 VSX)
|
||||
ggml_add_cpu_backend_variant(power9 POWER9 VSX)
|
||||
ggml_add_cpu_backend_variant(power10 POWER10 VSX)
|
||||
ggml_add_cpu_backend_variant(power11 POWER11 VSX)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
||||
@@ -69,6 +69,9 @@
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
@@ -91,6 +94,8 @@ static std::string path_str(const fs::path & path) {
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
@@ -158,48 +158,48 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (GGML_CPU_ARM_ARCH)
|
||||
list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH})
|
||||
elseif(GGML_CPU_ALL_VARIANTS)
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
# Begin with the lowest baseline
|
||||
set(ARM_MCPU "armv8-a")
|
||||
set(ARCH_TAGS "")
|
||||
set(ARCH_DEFINITIONS "")
|
||||
# Begin with the lowest baseline
|
||||
set(ARM_MCPU "armv8-a")
|
||||
set(ARCH_TAGS "")
|
||||
set(ARCH_DEFINITIONS "")
|
||||
|
||||
# When a feature is selected, bump the MCPU to the first
|
||||
# version that supported it
|
||||
if (GGML_INTERNAL_DOTPROD)
|
||||
set(ARM_MCPU "armv8.2-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+dotprod")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_DOTPROD)
|
||||
endif()
|
||||
if (GGML_INTERNAL_FP16_VECTOR_ARITHMETIC)
|
||||
set(ARM_MCPU "armv8.2-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+fp16")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_FP16_VECTOR_ARITHMETIC)
|
||||
endif()
|
||||
if (GGML_INTERNAL_SVE)
|
||||
set(ARM_MCPU "armv8.2-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+sve")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_SVE)
|
||||
endif()
|
||||
if (GGML_INTERNAL_MATMUL_INT8)
|
||||
set(ARM_MCPU "armv8.6-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+i8mm")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_MATMUL_INT8)
|
||||
endif()
|
||||
if (GGML_INTERNAL_SVE2)
|
||||
set(ARM_MCPU "armv8.6-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+sve2")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_SVE2)
|
||||
endif()
|
||||
if (GGML_INTERNAL_SME)
|
||||
set(ARM_MCPU "armv9.2-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+sme")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_SME)
|
||||
endif()
|
||||
|
||||
list(APPEND ARCH_FLAGS "-march=${ARM_MCPU}${ARCH_TAGS}")
|
||||
ggml_add_cpu_backend_features(${GGML_CPU_NAME} arm ${ARCH_DEFINITIONS})
|
||||
# When a feature is selected, bump the MCPU to the first
|
||||
# version that supported it
|
||||
if (GGML_INTERNAL_DOTPROD)
|
||||
set(ARM_MCPU "armv8.2-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+dotprod")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_DOTPROD)
|
||||
endif()
|
||||
if (GGML_INTERNAL_FP16_VECTOR_ARITHMETIC)
|
||||
set(ARM_MCPU "armv8.2-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+fp16")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_FP16_VECTOR_ARITHMETIC)
|
||||
endif()
|
||||
if (GGML_INTERNAL_SVE)
|
||||
set(ARM_MCPU "armv8.2-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+sve")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_SVE)
|
||||
endif()
|
||||
if (GGML_INTERNAL_MATMUL_INT8)
|
||||
set(ARM_MCPU "armv8.6-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+i8mm")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_MATMUL_INT8)
|
||||
endif()
|
||||
if (GGML_INTERNAL_SVE2)
|
||||
set(ARM_MCPU "armv8.6-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+sve2")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_SVE2)
|
||||
endif()
|
||||
if (GGML_INTERNAL_NOSVE)
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+nosve")
|
||||
endif()
|
||||
if (GGML_INTERNAL_SME)
|
||||
set(ARM_MCPU "armv9.2-a")
|
||||
set(ARCH_TAGS "${ARCH_TAGS}+sme")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_SME)
|
||||
endif()
|
||||
list(APPEND ARCH_FLAGS "-march=${ARM_MCPU}${ARCH_TAGS}")
|
||||
ggml_add_cpu_backend_features(${GGML_CPU_NAME} arm ${ARCH_DEFINITIONS})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -388,6 +388,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
else()
|
||||
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native -mpowerpc64)
|
||||
endif()
|
||||
elseif(GGML_CPU_ALL_VARIANTS)
|
||||
# Begin with the lowest baseline
|
||||
set(ARCH_DEFINITIONS "")
|
||||
|
||||
# When a feature is selected, bump the MCPU to the first
|
||||
# version that supported it
|
||||
foreach(PVER RANGE 7 11)
|
||||
if(DEFINED GGML_INTERNAL_POWER${PVER})
|
||||
set(POWERPC_MCPU "power${PVER}")
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_POWER${PVER})
|
||||
endif()
|
||||
endforeach()
|
||||
if (GGML_INTERNAL_VSX)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_USE_VSX)
|
||||
list(APPEND ARCH_FLAGS -mvsx)
|
||||
endif()
|
||||
|
||||
if (DEFINED POWERPC_MCPU)
|
||||
list(APPEND ARCH_FLAGS -mcpu=${POWERPC_MCPU})
|
||||
endif()
|
||||
ggml_add_cpu_backend_features(${GGML_CPU_NAME} powerpc ${ARCH_DEFINITIONS})
|
||||
else()
|
||||
if (GGML_CPU_POWERPC_CPUTYPE)
|
||||
list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE})
|
||||
@@ -465,9 +486,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.6.0")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.9.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
|
||||
184
ggml/src/ggml-cpu/arch-fallback.h
Normal file
184
ggml/src/ggml-cpu/arch-fallback.h
Normal file
@@ -0,0 +1,184 @@
|
||||
#pragma once
|
||||
|
||||
// Rename `_generic` functions if no native implementation is available.
|
||||
// This effectively selects the generic implementation.
|
||||
|
||||
#if defined(GGML_CPU_GENERIC)
|
||||
// quants.c
|
||||
#define quantize_row_q8_0_generic quantize_row_q8_0
|
||||
#define quantize_row_q8_1_generic quantize_row_q8_1
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_q4_0_q8_0_generic ggml_vec_dot_q4_0_q8_0
|
||||
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
|
||||
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
|
||||
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
|
||||
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
#define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K
|
||||
#define ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K
|
||||
#define ggml_vec_dot_q5_K_q8_K_generic ggml_vec_dot_q5_K_q8_K
|
||||
#define ggml_vec_dot_q6_K_q8_K_generic ggml_vec_dot_q6_K_q8_K
|
||||
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
|
||||
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
|
||||
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
|
||||
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
|
||||
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
|
||||
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#elif defined(__POWERPC__) || defined(__powerpc__)
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#elif defined(__loongarch64)
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#elif defined(__riscv)
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
|
||||
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
|
||||
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
|
||||
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
|
||||
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
|
||||
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#elif defined(__s390x__)
|
||||
// quants.c
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
|
||||
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
|
||||
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
|
||||
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
|
||||
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
|
||||
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
|
||||
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#elif defined(__wasm__)
|
||||
// quants.c
|
||||
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
|
||||
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
|
||||
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
|
||||
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
|
||||
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
|
||||
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
82
ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp
Normal file
82
ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp
Normal file
@@ -0,0 +1,82 @@
|
||||
# include "ggml-backend-impl.h"
|
||||
|
||||
#if defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <sys/auxv.h>
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
|
||||
struct powerpc_features {
|
||||
std::string platform = "";
|
||||
int power_version = -1;
|
||||
|
||||
bool has_vsx = false;
|
||||
|
||||
powerpc_features() {
|
||||
#if defined(__linux__)
|
||||
unsigned long auxval = getauxval(AT_PLATFORM);
|
||||
if (auxval) {
|
||||
platform = std::string(reinterpret_cast<const char*>(auxval));
|
||||
// TBD: Do systems exist that return this in uppercase?
|
||||
if (platform.substr(0, 5) == "power") {
|
||||
// Extractt a numeric suffix, if one exists
|
||||
int vpos = -1;
|
||||
for (int i = platform.length() - 1; i >= 0; i--) {
|
||||
if (std::isdigit(platform[i])) {
|
||||
vpos = i;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (vpos > -1) {
|
||||
power_version = std::stoi(platform.substr(vpos));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (power_version >= 9) {
|
||||
has_vsx = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static int ggml_backend_cpu_powerpc_score() {
|
||||
int score = 1;
|
||||
powerpc_features pf;
|
||||
|
||||
// Platform scores
|
||||
#if defined(GGML_USE_POWER7)
|
||||
if (pf.power_version < 7) { return 0; }
|
||||
score += 1<<1;
|
||||
#endif
|
||||
#if defined(GGML_USE_POWER8)
|
||||
if (pf.power_version < 8) { return 0; }
|
||||
score += 1<<2;
|
||||
#endif
|
||||
#if defined(GGML_USE_POWER9)
|
||||
if (pf.power_version < 9) { return 0; }
|
||||
score += 1<<3;
|
||||
#endif
|
||||
#if defined(GGML_USE_POWER10)
|
||||
if (pf.power_version < 10) { return 0; }
|
||||
score += 1<<4;
|
||||
#endif
|
||||
#if defined(GGML_USE_POWER11)
|
||||
if (pf.power_version < 11) { return 0; }
|
||||
score += 1<<5;
|
||||
#endif
|
||||
|
||||
// Feature scores
|
||||
#if defined(GGML_USE_VSX)
|
||||
if (!pf.has_vsx) { return 0; }
|
||||
score += 1<<6;
|
||||
#endif
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_powerpc_score)
|
||||
|
||||
#endif // defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__)
|
||||
@@ -371,7 +371,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
|
||||
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
|
||||
#endif
|
||||
|
||||
typedef signed char char8x16_t __attribute__((vector_size(16)));
|
||||
typedef signed char char8x16_t __attribute__((vector_size(16)));
|
||||
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef int8_t int8x16_t __attribute__((vector_size(16)));
|
||||
@@ -382,10 +382,10 @@ typedef uint8_t uint8x16_t __attribute__((vector_size(16)));
|
||||
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
|
||||
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef float float32x4_t __attribute__((vector_size(16)));
|
||||
typedef double double64x2_t __attribute((vector_size(16)));
|
||||
typedef float float32x4_t __attribute__((vector_size(16)));
|
||||
typedef double double64x2_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef signed long long long64x2_t __attribute((vector_size(16)));
|
||||
typedef signed long long long64x2_t __attribute__((vector_size(16)));
|
||||
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef struct ggml_uint8x16x2_t {
|
||||
@@ -503,31 +503,9 @@ static __m256 __lasx_xvreplfr2vr_s(const float val) {
|
||||
// TODO: move to ggml-threading
|
||||
void ggml_barrier(struct ggml_threadpool * tp);
|
||||
|
||||
void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value);
|
||||
int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#define GGML_DO_PRAGMA_(x) _Pragma (#x)
|
||||
#define GGML_DO_PRAGMA(x) GGML_DO_PRAGMA_(x)
|
||||
#if defined(GGML_CPU_GENERIC) || defined(__HIPCC__)
|
||||
// Note for Apple targets:
|
||||
// - clang: aliases are not supported on darwin
|
||||
// - all native kernels need to be implemented in both x86 and arm files
|
||||
// - on iOS, tvOS, and visionOS, if cmake cannot determine the target architecture, all `_generic` names are replaced by defines
|
||||
# define GGML_WEAK_ALIAS(name, alias)
|
||||
#elif defined(__GNUC__)
|
||||
// GCC/Clang on *nix
|
||||
# define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(weak name = alias) // NOLINT
|
||||
#elif defined(_MSC_VER) && defined(_WIN64)
|
||||
// MSVC
|
||||
// Note: C name mangling varies across different calling conventions
|
||||
// see https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170
|
||||
# define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:" #name "=" #alias))
|
||||
#elif defined(_MSC_VER) && defined(WIN32)
|
||||
// ref: https://github.com/ggml-org/whisper.cpp/pull/3239#issuecomment-2958224591
|
||||
# define GGML_WEAK_ALIAS(name, alias) GGML_DO_PRAGMA(comment(linker, "/alternatename:_" #name "=_" #alias))
|
||||
#else
|
||||
# error "Unsupported compiler for GGML_WEAK_ALIAS"
|
||||
#endif
|
||||
|
||||
#define GGML_CPU_NATIVE_IMPL(name) GGML_WEAK_ALIAS(name, name ## _generic)
|
||||
|
||||
@@ -74,13 +74,8 @@
|
||||
|
||||
#if defined(__ARM_ARCH)
|
||||
struct ggml_arm_arch_features_type {
|
||||
int has_neon;
|
||||
int has_dotprod;
|
||||
int has_i8mm;
|
||||
int has_sve;
|
||||
int sve_cnt;
|
||||
int has_sme;
|
||||
} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
|
||||
} ggml_arm_arch_features = { 0 };
|
||||
#endif
|
||||
|
||||
|
||||
@@ -559,6 +554,14 @@ void ggml_barrier(struct ggml_threadpool * tp) {
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value) {
|
||||
atomic_store_explicit(&tp->current_chunk, value, memory_order_relaxed);
|
||||
}
|
||||
|
||||
int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value) {
|
||||
return atomic_fetch_add_explicit(&tp->current_chunk, value, memory_order_relaxed);
|
||||
}
|
||||
|
||||
#if defined(__gnu_linux__)
|
||||
static cpu_set_t ggml_get_numa_affinity(void) {
|
||||
cpu_set_t cpuset;
|
||||
@@ -670,87 +673,15 @@ bool ggml_is_numa(void) {
|
||||
|
||||
#if defined(__linux__) && defined(__aarch64__)
|
||||
#include <sys/auxv.h>
|
||||
#elif defined(__APPLE__)
|
||||
#include <sys/sysctl.h>
|
||||
#endif
|
||||
|
||||
#if !defined(HWCAP2_I8MM)
|
||||
#define HWCAP2_I8MM (1 << 13)
|
||||
#endif
|
||||
|
||||
#if !defined(HWCAP2_SME)
|
||||
#define HWCAP2_SME (1 << 23)
|
||||
#endif
|
||||
|
||||
static void ggml_init_arm_arch_features(void) {
|
||||
#if defined(__linux__) && defined(__aarch64__)
|
||||
uint32_t hwcap = getauxval(AT_HWCAP);
|
||||
uint32_t hwcap2 = getauxval(AT_HWCAP2);
|
||||
|
||||
ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
|
||||
ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
|
||||
ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
|
||||
ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
|
||||
ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
|
||||
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
|
||||
#endif
|
||||
#elif defined(__APPLE__)
|
||||
int oldp = 0;
|
||||
size_t size = sizeof(oldp);
|
||||
if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
|
||||
oldp = 0;
|
||||
}
|
||||
ggml_arm_arch_features.has_neon = oldp;
|
||||
|
||||
if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
|
||||
oldp = 0;
|
||||
}
|
||||
ggml_arm_arch_features.has_dotprod = oldp;
|
||||
|
||||
if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
|
||||
oldp = 0;
|
||||
}
|
||||
ggml_arm_arch_features.has_i8mm = oldp;
|
||||
|
||||
if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
|
||||
oldp = 0;
|
||||
}
|
||||
ggml_arm_arch_features.has_sme = oldp;
|
||||
|
||||
ggml_arm_arch_features.has_sve = 0;
|
||||
ggml_arm_arch_features.sve_cnt = 0;
|
||||
#else
|
||||
// Run-time CPU feature detection not implemented for this platform, fallback to compile time
|
||||
#if defined(__ARM_NEON)
|
||||
ggml_arm_arch_features.has_neon = 1;
|
||||
#else
|
||||
ggml_arm_arch_features.has_neon = 0;
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
ggml_arm_arch_features.has_i8mm = 1;
|
||||
#else
|
||||
ggml_arm_arch_features.has_i8mm = 0;
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
ggml_arm_arch_features.has_sve = 1;
|
||||
ggml_arm_arch_features.sve_cnt = 16;
|
||||
#else
|
||||
ggml_arm_arch_features.has_sve = 0;
|
||||
ggml_arm_arch_features.sve_cnt = 0;
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
|
||||
ggml_arm_arch_features.has_sme = 1;
|
||||
#else
|
||||
ggml_arm_arch_features.has_sme = 0;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // __ARM_ARCH
|
||||
|
||||
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
|
||||
GGML_ASSERT(!ggml_get_no_alloc(ctx));
|
||||
@@ -1959,6 +1890,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
{
|
||||
ggml_compute_forward_pad_reflect_1d(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_ROLL:
|
||||
{
|
||||
ggml_compute_forward_roll(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_ARANGE:
|
||||
{
|
||||
ggml_compute_forward_arange(params, tensor);
|
||||
@@ -2283,6 +2218,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_ARGSORT:
|
||||
@@ -3435,7 +3371,7 @@ int ggml_cpu_has_vxe(void) {
|
||||
|
||||
int ggml_cpu_has_neon(void) {
|
||||
#if defined(__ARM_ARCH) && defined(__ARM_NEON)
|
||||
return ggml_arm_arch_features.has_neon;
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
@@ -3443,7 +3379,7 @@ int ggml_cpu_has_neon(void) {
|
||||
|
||||
int ggml_cpu_has_dotprod(void) {
|
||||
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)
|
||||
return ggml_arm_arch_features.has_dotprod;
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
@@ -3451,7 +3387,7 @@ int ggml_cpu_has_dotprod(void) {
|
||||
|
||||
int ggml_cpu_has_sve(void) {
|
||||
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
|
||||
return ggml_arm_arch_features.has_sve;
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
@@ -3459,7 +3395,7 @@ int ggml_cpu_has_sve(void) {
|
||||
|
||||
int ggml_cpu_has_matmul_int8(void) {
|
||||
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
return ggml_arm_arch_features.has_i8mm;
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
@@ -3475,7 +3411,7 @@ int ggml_cpu_get_sve_cnt(void) {
|
||||
|
||||
int ggml_cpu_has_sme(void) {
|
||||
#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
|
||||
return ggml_arm_arch_features.has_sme;
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
|
||||
@@ -53,7 +53,6 @@
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -63,7 +62,7 @@
|
||||
#define NOINLINE __attribute__((__noinline__))
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_NEON) || defined(__AVX512F__)
|
||||
#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
|
||||
#define VECTOR_REGISTERS 32
|
||||
#else
|
||||
#define VECTOR_REGISTERS 16
|
||||
@@ -110,6 +109,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
|
||||
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
|
||||
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
|
||||
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
|
||||
#endif
|
||||
|
||||
#if defined(__MMA__)
|
||||
typedef vector unsigned char vec_t;
|
||||
typedef __vector_quad acc_t;
|
||||
@@ -163,6 +168,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
template <>
|
||||
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
|
||||
return vec_madd(a, b, c);
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// VECTORIZED HORIZONTAL SUM
|
||||
|
||||
@@ -179,6 +191,13 @@ inline float hsum(float16x8_t x) {
|
||||
}
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
inline float hsum(float32x4_t x) {
|
||||
float32x4_t tmp = x + vec_reve(x);
|
||||
return tmp[0] + tmp[1];
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
inline float hsum(__m128 x) {
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
@@ -228,6 +247,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
|
||||
#endif // _MSC_VER
|
||||
#endif // __ARM_NEON
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
template <> inline float32x4_t load(const ggml_fp16_t * p) {
|
||||
float tmp[4];
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tmp[i] = GGML_FP16_TO_FP32(p[i]);
|
||||
}
|
||||
|
||||
return vec_xl(0, (const float *)(tmp));
|
||||
}
|
||||
template <> inline float32x4_t load(const float * p) {
|
||||
return vec_xl(0, p);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
template <> inline __m128 load(const float *p) {
|
||||
return _mm_loadu_ps(p);
|
||||
@@ -394,8 +428,6 @@ class tinyBLAS {
|
||||
|
||||
template <int RM, int RN, int BM>
|
||||
NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
|
||||
static std::atomic<int64_t> current_chunk;
|
||||
|
||||
GGML_ASSERT(m % (RM * BM) == 0);
|
||||
const int64_t ytiles = m / (RM * BM);
|
||||
const int64_t xtiles = (n + RN -1) / RN;
|
||||
@@ -410,7 +442,7 @@ class tinyBLAS {
|
||||
if (params->ith == 0) {
|
||||
GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
std::atomic_store_explicit(¤t_chunk, (int64_t)params->nth, std::memory_order_relaxed);
|
||||
ggml_threadpool_chunk_set(params->threadpool, params->nth);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
@@ -439,8 +471,7 @@ class tinyBLAS {
|
||||
GGML_ASSERT(jj == jj2);
|
||||
}
|
||||
|
||||
// next step.
|
||||
job = std::atomic_fetch_add_explicit(¤t_chunk, (int64_t)1, std::memory_order_relaxed);
|
||||
job = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
@@ -3323,6 +3354,14 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
||||
(const float *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
return tb.matmul(m, n);
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
if (n < 4)
|
||||
return false;
|
||||
tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
|
||||
k, (const float *)A, lda,
|
||||
(const float *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
return tb.matmul(m, n);
|
||||
#elif defined(__MMA__)
|
||||
if (k % 8)
|
||||
return false;
|
||||
@@ -3414,6 +3453,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
||||
(float *)C, ldc};
|
||||
return tb.matmul(m, n);
|
||||
}
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
if (n < 4)
|
||||
return false;
|
||||
if (Btype == GGML_TYPE_F16) {
|
||||
tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
|
||||
k, (const ggml_fp16_t *)A, lda,
|
||||
(const ggml_fp16_t *)B, ldb,
|
||||
(float *)C, ldc};
|
||||
return tb.matmul(m, n);
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
#pragma once
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#if defined(__VXE__) || defined(__VXE2__)
|
||||
#include <vecintrin.h>
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
@@ -6793,6 +6793,73 @@ void ggml_compute_forward_pad_reflect_1d(
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_roll
|
||||
|
||||
static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
|
||||
if (i < 0) {
|
||||
return i + ne;
|
||||
} else if (i >= ne) {
|
||||
return i - ne;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_roll_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src_data = (const float *) src0->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
const int s0 = ggml_get_op_params_i32(dst, 0);
|
||||
const int s1 = ggml_get_op_params_i32(dst, 1);
|
||||
const int s2 = ggml_get_op_params_i32(dst, 2);
|
||||
const int s3 = ggml_get_op_params_i32(dst, 3);
|
||||
|
||||
const int64_t total = ne1 * ne2 * ne3;
|
||||
const int64_t per_thread = (total + params->nth) / params->nth;
|
||||
const int64_t start = params->ith * per_thread;
|
||||
const int64_t end = std::min(start + per_thread, total);
|
||||
|
||||
for (int64_t i = start; i < end; ++i) {
|
||||
const int64_t i1 = i % ne1;
|
||||
const int64_t i2 = (i / ne1) % ne2;
|
||||
const int64_t i3 = i / (ne2 * ne1);
|
||||
float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
|
||||
|
||||
const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
|
||||
const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
|
||||
const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
|
||||
const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
|
||||
|
||||
const int64_t s = ggml_wrap_index(-s0, ne00);
|
||||
const int64_t n = ne00 - s;
|
||||
ggml_vec_cpy_f32(n, dst_row, src_row + s);
|
||||
ggml_vec_cpy_f32(s, dst_row + n, src_row);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_roll(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_roll_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_arange
|
||||
|
||||
static void ggml_compute_forward_arange_f32(
|
||||
|
||||
@@ -72,6 +72,7 @@ void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params
|
||||
void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_roll(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
#include "ggml-quants.h"
|
||||
#include "quants.h"
|
||||
|
||||
#include "arch-fallback.h"
|
||||
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
#include <float.h>
|
||||
@@ -38,12 +40,10 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
|
||||
void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_q8_0_ref(x, y, k);
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(quantize_row_q8_0)
|
||||
|
||||
void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_q8_1_ref(x, y, k);
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(quantize_row_q8_1)
|
||||
|
||||
//
|
||||
// 2-6 bit quantization in super-blocks
|
||||
@@ -104,7 +104,6 @@ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy,
|
||||
void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
|
||||
quantize_row_q8_K_ref(x, y, k);
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(quantize_row_q8_K)
|
||||
|
||||
//===================================== Dot products =================================
|
||||
|
||||
@@ -143,7 +142,6 @@ void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q4_0_q8_0)
|
||||
|
||||
// TODO: add WASM SIMD
|
||||
void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
@@ -181,7 +179,6 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q4_1_q8_1)
|
||||
|
||||
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -225,7 +222,6 @@ void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q5_0_q8_0)
|
||||
|
||||
void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_1;
|
||||
@@ -269,7 +265,6 @@ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q5_1_q8_1)
|
||||
|
||||
void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -300,7 +295,6 @@ void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q8_0_q8_0)
|
||||
|
||||
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
@@ -353,7 +347,6 @@ void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_tq1_0_q8_K)
|
||||
|
||||
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
@@ -386,7 +379,6 @@ void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_tq2_0_q8_K)
|
||||
|
||||
void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
@@ -439,7 +431,6 @@ void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q2_K_q8_K)
|
||||
|
||||
void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -519,7 +510,6 @@ void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q3_K_q8_K)
|
||||
|
||||
void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -595,7 +585,6 @@ void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q4_K_q8_K)
|
||||
|
||||
void ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -676,7 +665,6 @@ void ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q5_K_q8_K)
|
||||
|
||||
void ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -732,7 +720,6 @@ void ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, c
|
||||
for (int l = 0; l < 8; ++l) sumf += sums[l];
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_q6_K_q8_K)
|
||||
|
||||
void ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -775,7 +762,6 @@ void ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq2_xxs_q8_K)
|
||||
|
||||
void ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -826,7 +812,6 @@ void ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq2_xs_q8_K)
|
||||
|
||||
void ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -879,7 +864,6 @@ void ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
|
||||
*s = 0.125f * sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq2_s_q8_K)
|
||||
|
||||
void ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -924,7 +908,6 @@ void ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
}
|
||||
*s = 0.25f * sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq3_xxs_q8_K)
|
||||
|
||||
void ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -981,7 +964,6 @@ void ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq3_s_q8_K)
|
||||
|
||||
void ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -1025,7 +1007,6 @@ void ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq1_s_q8_K)
|
||||
|
||||
void ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
@@ -1087,7 +1068,6 @@ void ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq1_m_q8_K)
|
||||
|
||||
void ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
@@ -1117,7 +1097,6 @@ void ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq4_nl_q8_0)
|
||||
|
||||
void ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
@@ -1164,7 +1143,6 @@ void ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_vec_dot_iq4_xs_q8_K)
|
||||
|
||||
// ============================ 4-bit non-linear quants
|
||||
|
||||
|
||||
@@ -84,33 +84,6 @@ void ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
|
||||
#if defined(GGML_CPU_GENERIC)
|
||||
#define quantize_row_q8_0_generic quantize_row_q8_0
|
||||
#define quantize_row_q8_1_generic quantize_row_q8_1
|
||||
#define quantize_row_q8_K_generic quantize_row_q8_K
|
||||
#define ggml_vec_dot_q4_0_q8_0_generic ggml_vec_dot_q4_0_q8_0
|
||||
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
|
||||
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
|
||||
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
|
||||
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
|
||||
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
|
||||
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
|
||||
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
|
||||
#define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K
|
||||
#define ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K
|
||||
#define ggml_vec_dot_q5_K_q8_K_generic ggml_vec_dot_q5_K_q8_K
|
||||
#define ggml_vec_dot_q6_K_q8_K_generic ggml_vec_dot_q6_K_q8_K
|
||||
#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K
|
||||
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
|
||||
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
|
||||
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
|
||||
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
|
||||
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
|
||||
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
|
||||
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
|
||||
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "traits.h"
|
||||
|
||||
#include "arch-fallback.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <cassert>
|
||||
@@ -83,7 +85,6 @@ void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GG
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_quantize_mat_q8_0_4x4)
|
||||
|
||||
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK8_0 == 32);
|
||||
@@ -122,7 +123,6 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_quantize_mat_q8_0_4x8)
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK_K == 256);
|
||||
@@ -174,7 +174,6 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_quantize_mat_q8_K_4x8)
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@@ -244,7 +243,6 @@ void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemv_q4_0_4x4_q8_0)
|
||||
|
||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -289,7 +287,6 @@ void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemv_q4_0_4x8_q8_0)
|
||||
|
||||
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -336,7 +333,6 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemv_q4_0_8x8_q8_0)
|
||||
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
@@ -415,7 +411,6 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemv_q4_K_8x8_q8_K)
|
||||
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -462,7 +457,6 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemv_iq4_nl_4x4_q8_0)
|
||||
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -519,7 +513,6 @@ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemm_q4_0_4x4_q8_0)
|
||||
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -574,7 +567,6 @@ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemm_q4_0_4x8_q8_0)
|
||||
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -629,7 +621,6 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemm_q4_0_8x8_q8_0)
|
||||
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
@@ -719,7 +710,6 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemm_q4_K_8x8_q8_K)
|
||||
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
@@ -776,7 +766,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
|
||||
}
|
||||
}
|
||||
}
|
||||
GGML_CPU_NATIVE_IMPL(ggml_gemm_iq4_nl_4x4_q8_0)
|
||||
|
||||
} // extern "C"
|
||||
|
||||
@@ -1174,13 +1163,24 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
// not realy a GGML_TYPE_Q8_0 but same size.
|
||||
switch (op->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
||||
return true;
|
||||
{
|
||||
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
||||
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
|
||||
size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
|
||||
return true;
|
||||
{
|
||||
size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
|
||||
size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
|
||||
|
||||
const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
|
||||
const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
|
||||
|
||||
const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
|
||||
|
||||
size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
// GGML_ABORT("fatal error");
|
||||
break;
|
||||
@@ -1316,14 +1316,17 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
|
||||
n_as * ne12 * sizeof(mmid_row_mapping)));
|
||||
GGML_ASSERT(params->wsize >=
|
||||
(GGML_PAD(nbw3, sizeof(int64_t)) +
|
||||
n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
|
||||
);
|
||||
|
||||
auto * wdata = (char *) params->wdata;
|
||||
auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
|
||||
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||
auto * wdata = (char *)params->wdata;
|
||||
auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
|
||||
|
||||
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
||||
// total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t)
|
||||
auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
|
||||
|
||||
// src1: float32 => param type
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
@@ -1408,44 +1411,45 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
||||
}
|
||||
};
|
||||
|
||||
// instance for Q4
|
||||
static const tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
|
||||
static const tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
|
||||
static const tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
|
||||
static const tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for IQ4
|
||||
static const tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
||||
|
||||
} // namespace ggml::cpu::repack
|
||||
|
||||
static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {
|
||||
|
||||
// instance for Q4
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for IQ4
|
||||
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
|
||||
|
||||
if (cur->type == GGML_TYPE_Q4_0) {
|
||||
if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &ggml::cpu::repack::q4_0_8x8_q8_0;
|
||||
return &q4_0_8x8_q8_0;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
return &ggml::cpu::repack::q4_0_4x8_q8_0;
|
||||
return &q4_0_4x8_q8_0;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
return &ggml::cpu::repack::q4_0_4x4_q8_0;
|
||||
return &q4_0_4x4_q8_0;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q4_K) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &ggml::cpu::repack::q4_K_8x8_q8_K;
|
||||
return &q4_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
return &ggml::cpu::repack::iq4_nl_4x4_q8_0;
|
||||
return &iq4_nl_4x4_q8_0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -64,10 +64,6 @@ static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wro
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Workaround for clang:
|
||||
// clang++ complains: ``error: call to 'ggml_gemm_q4_0_4x4_q8_0' is ambiguous''
|
||||
// repro: https://godbolt.org/z/oKdeWKonM (ICE), https://godbolt.org/z/1szq6P36v (ambiguous call)
|
||||
#if defined(GGML_CPU_CLANG_WORKAROUND) || !(defined(__GNUC__) && defined(__clang__)) || defined(__HIPCC__)
|
||||
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
@@ -81,7 +77,6 @@ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
#endif // !defined(__clang__)
|
||||
|
||||
// Native implementations
|
||||
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
@@ -98,22 +93,6 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
#if defined(GGML_CPU_GENERIC)
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
|
||||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#endif
|
||||
|
||||
#if defined(__cplusplus)
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
@@ -944,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
|
||||
for (int i = 0; i < offset; ++i) { \
|
||||
x[i] = vec_add(x[i], x[offset + i]); \
|
||||
} \
|
||||
res = vec_extract(x[0], 0) + \
|
||||
vec_extract(x[0], 1) + \
|
||||
vec_extract(x[0], 2) + \
|
||||
vec_extract(x[0], 3); \
|
||||
float32x4_t tmp = x[0] + vec_reve(x[0]); \
|
||||
res = tmp[0] + tmp[1]; \
|
||||
}
|
||||
|
||||
#define GGML_F32_VEC GGML_F32x4
|
||||
|
||||
@@ -19,10 +19,10 @@
|
||||
#endif
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -207,9 +207,9 @@ typedef float2 dfloat2;
|
||||
#define FP16_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
|
||||
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
|
||||
#define FP16_MMA_AVAILABLE
|
||||
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
|
||||
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
|
||||
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
#define NEW_MMA_AVAILABLE
|
||||
@@ -241,8 +241,18 @@ static bool fp16_mma_available(const int cc) {
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
return false;
|
||||
#else
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
|
||||
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
|
||||
return true;
|
||||
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
}
|
||||
|
||||
@@ -252,6 +262,10 @@ static bool fp16_mma_hardware_available(const int cc) {
|
||||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
|
||||
}
|
||||
|
||||
static bool bf16_mma_hardware_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE;
|
||||
}
|
||||
|
||||
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
|
||||
static bool new_mma_available(const int cc) {
|
||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
|
||||
@@ -262,11 +276,11 @@ static bool cp_async_available(const int cc) {
|
||||
}
|
||||
|
||||
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
return __AMDGCN_WAVEFRONT_SIZE;
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
|
||||
return 64;
|
||||
#else
|
||||
return 32;
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
@@ -362,6 +376,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
|
||||
// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
|
||||
template<bool norm>
|
||||
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i = col; i < ncols; i += blockDim.x) {
|
||||
sum += x[row * ncols + i];
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (col != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[row] = norm ? sum / ncols : sum;
|
||||
}
|
||||
|
||||
template<int width = WARP_SIZE>
|
||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
#pragma unroll
|
||||
@@ -767,21 +801,7 @@ struct ggml_backend_cuda_context {
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
}
|
||||
|
||||
~ggml_backend_cuda_context() {
|
||||
if (copy_event != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(copy_event));
|
||||
}
|
||||
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
|
||||
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
|
||||
if (streams[i][j] != nullptr) {
|
||||
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
|
||||
}
|
||||
}
|
||||
if (cublas_handles[i] != nullptr) {
|
||||
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
~ggml_backend_cuda_context();
|
||||
|
||||
cudaStream_t stream(int device, int stream) {
|
||||
if (streams[device][stream] == nullptr) {
|
||||
|
||||
161
ggml/src/ggml-cuda/conv2d-dw.cu
Normal file
161
ggml/src/ggml-cuda/conv2d-dw.cu
Normal file
@@ -0,0 +1,161 @@
|
||||
#include "conv2d-dw.cuh"
|
||||
|
||||
struct conv_params {
|
||||
int in_w, in_h;
|
||||
int out_w, out_h;
|
||||
int kernel_w, kernel_h;
|
||||
int stride_x, stride_y;
|
||||
int padding_x, padding_y;
|
||||
int dilation_x, dilation_y;
|
||||
int channels, batches;
|
||||
};
|
||||
|
||||
struct kernel_bounds {
|
||||
int y_min, y_max;
|
||||
int x_min, x_max;
|
||||
};
|
||||
|
||||
__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
|
||||
kernel_bounds bounds;
|
||||
bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
|
||||
bounds.y_max =
|
||||
min(params.kernel_h,
|
||||
(params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
|
||||
bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
|
||||
bounds.x_max =
|
||||
min(params.kernel_w,
|
||||
(params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
|
||||
return bounds;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
|
||||
return out_coord * stride + kern_coord * dilation - padding;
|
||||
}
|
||||
|
||||
struct whcn_layout {
|
||||
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
|
||||
return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
|
||||
}
|
||||
|
||||
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
|
||||
return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
|
||||
}
|
||||
|
||||
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
|
||||
return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
|
||||
y * params.out_w + x;
|
||||
}
|
||||
|
||||
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
|
||||
int & out_x) {
|
||||
out_x = global_idx % params.out_w;
|
||||
out_y = (global_idx / params.out_w) % params.out_h;
|
||||
c = (global_idx / (params.out_w * params.out_h)) % params.channels;
|
||||
n = global_idx / (params.out_w * params.out_h * params.channels);
|
||||
}
|
||||
};
|
||||
|
||||
struct cwhn_layout {
|
||||
__device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
|
||||
return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
|
||||
}
|
||||
|
||||
__device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
|
||||
return (ky * params.kernel_w + kx) * params.channels + c;
|
||||
}
|
||||
|
||||
__device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
|
||||
return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
|
||||
x * params.channels + c;
|
||||
}
|
||||
|
||||
__device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
|
||||
int & out_x) {
|
||||
c = global_idx % params.channels;
|
||||
out_x = (global_idx / params.channels) % params.out_w;
|
||||
out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
|
||||
n = global_idx / (params.channels * params.out_w * params.out_h);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename Layout>
|
||||
__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
|
||||
const int in_w, const int in_h, const int out_w, const int out_h,
|
||||
const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
|
||||
const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
|
||||
const int channels, const int batches) {
|
||||
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int total_elements = batches * channels * out_h * out_w;
|
||||
|
||||
if (global_idx >= total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
|
||||
stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
|
||||
|
||||
int batch_idx, channel_idx, out_y_idx, out_x_idx;
|
||||
Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
|
||||
|
||||
T accumulator = 0;
|
||||
kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
|
||||
|
||||
for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
|
||||
int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
|
||||
|
||||
for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
|
||||
int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
|
||||
|
||||
const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
|
||||
const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
|
||||
|
||||
accumulator += input_val * kernel_val;
|
||||
}
|
||||
}
|
||||
|
||||
output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
|
||||
}
|
||||
|
||||
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * kernel = dst->src[0];
|
||||
const ggml_tensor * input = dst->src[1];
|
||||
|
||||
GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
const float * w_d = (const float *) kernel->data;
|
||||
const float * x_d = (const float *) input->data;
|
||||
float * y_d = (float *) dst->data;
|
||||
|
||||
const int32_t * p = (const int32_t *) dst->op_params;
|
||||
const int stride_x = p[0];
|
||||
const int stride_y = p[1];
|
||||
const int padding_x = p[2];
|
||||
const int padding_y = p[3];
|
||||
const int dilation_x = p[4];
|
||||
const int dilation_y = p[5];
|
||||
|
||||
const int in_w = input->ne[0];
|
||||
const int in_h = input->ne[1];
|
||||
const int kernel_w = kernel->ne[0];
|
||||
const int kernel_h = kernel->ne[1];
|
||||
const int out_w = dst->ne[0];
|
||||
const int out_h = dst->ne[1];
|
||||
const int channels = dst->ne[2];
|
||||
const int batches = dst->ne[3];
|
||||
|
||||
cudaStream_t st = ctx.stream();
|
||||
|
||||
const int total = batches * channels * out_h * out_w;
|
||||
const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
|
||||
|
||||
if (ggml_is_contiguous(input)) {
|
||||
conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
|
||||
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
|
||||
dilation_x, dilation_y, channels, batches);
|
||||
} else if (ggml_is_contiguous_channels(input)) {
|
||||
conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
|
||||
x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
|
||||
dilation_x, dilation_y, channels, batches);
|
||||
} else {
|
||||
GGML_ABORT("Unsupported memory layout for conv_2d_dw");
|
||||
}
|
||||
}
|
||||
5
ggml/src/ggml-cuda/conv2d-dw.cuh
Normal file
5
ggml/src/ggml-cuda/conv2d-dw.cuh
Normal file
@@ -0,0 +1,5 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CONV2D_DW_BLOCK_SIZE 256
|
||||
void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
91
ggml/src/ggml-cuda/conv2d-transpose.cu
Normal file
91
ggml/src/ggml-cuda/conv2d-transpose.cu
Normal file
@@ -0,0 +1,91 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "conv2d-transpose.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
|
||||
float * __restrict__ output, const int in_w, const int in_h, const int out_w,
|
||||
const int out_h, const int kernel_w, const int kernel_h, const int stride,
|
||||
const int c_in, const int c_out, const int batches) {
|
||||
const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
const int total_elements = out_w * out_h * c_out * batches;
|
||||
|
||||
if (global_idx >= total_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int out_x_idx = global_idx % out_w;
|
||||
const int out_y_idx = (global_idx / out_w) % out_h;
|
||||
const int c_idx = (global_idx / (out_w * out_h)) % c_out;
|
||||
const int n_idx = global_idx / (out_w * out_h * c_out);
|
||||
|
||||
float accumulator = 0;
|
||||
// For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
|
||||
|
||||
for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
|
||||
for (int kh = 0; kh < kernel_h; ++kh) {
|
||||
int in_y = out_y_idx - kh;
|
||||
if (in_y < 0 || in_y % stride) continue;
|
||||
in_y /= stride;
|
||||
if (in_y >= in_h) continue;
|
||||
|
||||
for (int kw = 0; kw < kernel_w; ++kw) {
|
||||
int in_x = out_x_idx - kw;
|
||||
if (in_x < 0 || in_x % stride) continue;
|
||||
in_x /= stride;
|
||||
if (in_x >= in_w) continue;
|
||||
|
||||
const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x;
|
||||
const int kernel_idx =
|
||||
(kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw;
|
||||
|
||||
float input_val = input[input_idx];
|
||||
half kern_val = kernel[kernel_idx];
|
||||
|
||||
accumulator += input_val * (float) kern_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator;
|
||||
}
|
||||
|
||||
//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in)
|
||||
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * kernel = dst->src[0];
|
||||
const ggml_tensor * input = dst->src[1];
|
||||
|
||||
GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
|
||||
|
||||
const float * input_data = (const float *) input->data;
|
||||
float * output_data = (float *) dst->data;
|
||||
const half * kernel_data = (const half *) kernel->data;
|
||||
|
||||
const int input_w = input->ne[0];
|
||||
const int input_h = input->ne[1];
|
||||
const int output_w = dst->ne[0];
|
||||
const int output_h = dst->ne[1];
|
||||
const int channels_in = input->ne[2];
|
||||
const int channels_out = kernel->ne[2];
|
||||
const int kernel_w = kernel->ne[0];
|
||||
const int kernel_h = kernel->ne[1];
|
||||
const int stride = dst->op_params[0];
|
||||
const int batches = input->ne[3];
|
||||
|
||||
GGML_ASSERT(channels_in == kernel->ne[3]);
|
||||
GGML_ASSERT(stride > 0);
|
||||
|
||||
cudaStream_t st = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(input));
|
||||
GGML_ASSERT(ggml_is_contiguous(kernel));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
const int total = (output_w * output_h * channels_out * batches);
|
||||
const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
|
||||
|
||||
conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
|
||||
input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
|
||||
channels_in, channels_out, batches);
|
||||
}
|
||||
4
ggml/src/ggml-cuda/conv2d-transpose.cuh
Normal file
4
ggml/src/ggml-cuda/conv2d-transpose.cuh
Normal file
@@ -0,0 +1,4 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
|
||||
void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -11,6 +11,8 @@
|
||||
#include "ggml-cuda/clamp.cuh"
|
||||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/conv2d-dw.cuh"
|
||||
#include "ggml-cuda/conv2d-transpose.cuh"
|
||||
#include "ggml-cuda/convert.cuh"
|
||||
#include "ggml-cuda/count-equal.cuh"
|
||||
#include "ggml-cuda/cpy.cuh"
|
||||
@@ -35,6 +37,7 @@
|
||||
#include "ggml-cuda/ssm-scan.cuh"
|
||||
#include "ggml-cuda/sum.cuh"
|
||||
#include "ggml-cuda/sumrows.cuh"
|
||||
#include "ggml-cuda/mean.cuh"
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
@@ -47,6 +50,7 @@
|
||||
#include <atomic>
|
||||
#include <charconv>
|
||||
#include <cinttypes>
|
||||
#include <condition_variable>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <float.h>
|
||||
@@ -54,9 +58,8 @@
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdarg.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -97,8 +100,7 @@ int ggml_cuda_get_device() {
|
||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
cudaError_t err;
|
||||
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
|
||||
{
|
||||
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
|
||||
err = cudaMallocManaged(ptr, size);
|
||||
#if defined(GGML_USE_HIP)
|
||||
if (err == hipSuccess) {
|
||||
@@ -116,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
|
||||
err = cudaMalloc(ptr, size);
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
else
|
||||
{
|
||||
} else {
|
||||
err = cudaMalloc(ptr, size);
|
||||
}
|
||||
return err;
|
||||
@@ -514,6 +514,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
||||
}
|
||||
|
||||
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
|
||||
// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
|
||||
|
||||
static std::mutex ggml_cuda_lock;
|
||||
static std::condition_variable ggml_cuda_lock_cv;
|
||||
static std::atomic<int> ggml_cuda_lock_counter;
|
||||
|
||||
ggml_backend_cuda_context::~ggml_backend_cuda_context() {
|
||||
std::unique_lock<std::mutex> lock(ggml_cuda_lock);
|
||||
ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
|
||||
|
||||
if (copy_event != nullptr) {
|
||||
CUDA_CHECK(cudaEventDestroy(copy_event));
|
||||
}
|
||||
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
|
||||
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
|
||||
if (streams[i][j] != nullptr) {
|
||||
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
|
||||
}
|
||||
}
|
||||
if (cublas_handles[i] != nullptr) {
|
||||
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// cuda buffer
|
||||
|
||||
struct ggml_backend_cuda_buffer_context {
|
||||
@@ -1916,16 +1943,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
&& ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
|
||||
|
||||
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
bool any_gpus_with_slow_fp16 = false;
|
||||
bool any_gpus_without_fp16_mma = false;
|
||||
bool any_gpus_with_slow_fp16 = false;
|
||||
|
||||
if (split) {
|
||||
ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
|
||||
@@ -1936,16 +1961,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
continue;
|
||||
}
|
||||
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
}
|
||||
} else {
|
||||
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
|
||||
const int cc = ggml_cuda_info().devices[ctx.device].cc;
|
||||
use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
|
||||
use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
|
||||
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
|
||||
}
|
||||
|
||||
// debug helpers
|
||||
@@ -1956,7 +1981,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
|
||||
if (!split && use_mul_mat_vec) {
|
||||
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
|
||||
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
|
||||
ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
|
||||
@@ -2310,6 +2335,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_IM2COL:
|
||||
ggml_cuda_op_im2col(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
ggml_cuda_op_conv2d_dw(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
ggml_cuda_conv_2d_transpose_p0(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
ggml_cuda_op_conv_transpose_1d(ctx,dst);
|
||||
break;
|
||||
@@ -2322,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_cuda_op_sum_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
ggml_cuda_op_mean(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_cuda_op_ssm_conv(ctx, dst);
|
||||
break;
|
||||
@@ -2664,7 +2698,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
GGML_UNUSED(integrated);
|
||||
#endif // NDEBUG
|
||||
|
||||
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
||||
if (!ok) {
|
||||
@@ -2683,6 +2719,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
||||
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
||||
|
||||
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
|
||||
if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
|
||||
ggml_cuda_lock_cv.notify_all();
|
||||
}
|
||||
} else {
|
||||
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
|
||||
}
|
||||
@@ -2758,7 +2799,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||
}
|
||||
}
|
||||
|
||||
if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
|
||||
if (use_cuda_graph && cuda_graph_update_required) {
|
||||
// Start CUDA graph capture
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
|
||||
ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
||||
}
|
||||
|
||||
@@ -3207,9 +3254,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
|
||||
}
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
|
||||
19
ggml/src/ggml-cuda/mean.cu
Normal file
19
ggml/src/ggml-cuda/mean.cu
Normal file
@@ -0,0 +1,19 @@
|
||||
#include "mean.cuh"
|
||||
|
||||
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
}
|
||||
3
ggml/src/ggml-cuda/mean.cuh
Normal file
3
ggml/src/ggml-cuda/mean.cuh
Normal file
@@ -0,0 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
@@ -2,25 +2,26 @@
|
||||
#include "common.cuh"
|
||||
#include "mmv.cuh"
|
||||
|
||||
template <typename T, typename type_acc, int block_size>
|
||||
template <typename T, typename type_acc, int ncols_dst, int block_size>
|
||||
static __global__ void mul_mat_vec(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
|
||||
const int64_t row = blockIdx.x;
|
||||
const int64_t channel_dst = blockIdx.y;
|
||||
const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
||||
const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
||||
const int64_t sample_dst = blockIdx.z;
|
||||
const int64_t sample_x = sample_dst / sample_ratio;
|
||||
const int64_t sample_y = sample_dst;
|
||||
const int tid = threadIdx.x;
|
||||
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
const int row = blockIdx.x;
|
||||
const int channel_dst = blockIdx.y;
|
||||
const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
|
||||
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
|
||||
const int sample_dst = blockIdx.z;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_y = sample_dst;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
|
||||
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||||
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
|
||||
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
|
||||
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||||
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
||||
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
@@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float sumf = 0.0f;
|
||||
float sumf[ncols_dst] = {0.0f};
|
||||
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
const float2 * x2 = (const float2 *) x;
|
||||
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = x2[col2];
|
||||
const float2 tmpy = y2[col2];
|
||||
sumf += tmpx.x*tmpy.x;
|
||||
sumf += tmpx.y*tmpy.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += tmpx.x*tmpy.x;
|
||||
sumf[j] += tmpx.y*tmpy.y;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same<T, half>::value) {
|
||||
const half2 * x2 = (const half2 *) x;
|
||||
|
||||
if (std::is_same<type_acc, float>::value) {
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = __half22float2(x2[col2]);
|
||||
const float2 tmpy = y2[col2];
|
||||
sumf += tmpx.x * tmpy.x;
|
||||
sumf += tmpx.y * tmpy.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += tmpx.x * tmpy.x;
|
||||
sumf[j] += tmpx.y * tmpy.y;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#ifdef FP16_AVAILABLE
|
||||
half2 sumh2 = make_half2(0.0f, 0.0f);
|
||||
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
|
||||
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmp = y2[col2];
|
||||
sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const half2 tmpx = x2[col2];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
|
||||
}
|
||||
}
|
||||
|
||||
sumf = __low2float(sumh2) + __high2float(sumh2);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
}
|
||||
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
||||
const int * x2 = (const int *) x;
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const int tmpx = x2[col2];
|
||||
const float2 tmpy = y2[col2];
|
||||
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
||||
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const int tmpx = x2[col2];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
|
||||
sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same<T, void>::value, "unsupported type");
|
||||
}
|
||||
|
||||
sumf = warp_reduce_sum<warp_size>(sumf);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||
|
||||
if (block_size > warp_size) {
|
||||
buf_iw[tid/warp_size] = sumf;
|
||||
__syncthreads();
|
||||
if (tid >= warp_size) {
|
||||
return;
|
||||
if (block_size > warp_size) {
|
||||
buf_iw[tid/warp_size] = sumf[j];
|
||||
__syncthreads();
|
||||
if (tid < warp_size) {
|
||||
sumf[j] = buf_iw[tid];
|
||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||
}
|
||||
if (j < ncols_dst) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
sumf = buf_iw[tid];
|
||||
sumf = warp_reduce_sum<warp_size>(sumf);
|
||||
}
|
||||
|
||||
if (tid != 0) {
|
||||
if (tid >= ncols_dst) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[row] = sumf;
|
||||
dst[tid*stride_col_dst + row] = sumf[tid];
|
||||
}
|
||||
|
||||
template <typename T, typename type_acc>
|
||||
template <typename T, typename type_acc, int ncols_dst>
|
||||
static void launch_mul_mat_vec_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t ncols, const int64_t nrows,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
GGML_ASSERT(stride_row % 2 == 0);
|
||||
GGML_ASSERT(ncols % 2 == 0);
|
||||
GGML_ASSERT(stride_row % 2 == 0);
|
||||
GGML_ASSERT(stride_col_y % 2 == 0);
|
||||
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
|
||||
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
||||
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
||||
@@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
|
||||
const dim3 block_dims(block_size_best, 1, 1);
|
||||
switch (block_size_best) {
|
||||
case 32: {
|
||||
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 64: {
|
||||
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 96: {
|
||||
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 128: {
|
||||
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 160: {
|
||||
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 192: {
|
||||
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 224: {
|
||||
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 256: {
|
||||
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename type_acc>
|
||||
static void mul_mat_vec_cuda_switch_ncols_dst(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
launch_mul_mat_vec_cuda<T, type_acc, 1>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 2:
|
||||
launch_mul_mat_vec_cuda<T, type_acc, 2>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 3:
|
||||
launch_mul_mat_vec_cuda<T, type_acc, 3>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mul_mat_vec_cuda<T, type_acc, 4>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 5:
|
||||
launch_mul_mat_vec_cuda<T, type_acc, 5>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mul_mat_vec_cuda<T, type_acc, 6>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 7:
|
||||
launch_mul_mat_vec_cuda<T, type_acc, 7>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mul_mat_vec_cuda<T, type_acc, 8>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void mul_mat_vec_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
enum ggml_prec prec, cudaStream_t stream) {
|
||||
if constexpr(std::is_same<T, half>::value) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
launch_mul_mat_vec_cuda<T, half>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
mul_mat_vec_cuda_switch_ncols_dst<T, half>
|
||||
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
return;
|
||||
}
|
||||
}
|
||||
launch_mul_mat_vec_cuda<T, float>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
mul_mat_vec_cuda_switch_ncols_dst<T, float>
|
||||
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
}
|
||||
|
||||
@@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
||||
const int64_t stride_channel_dst = ids ? s1 : s2;
|
||||
const int64_t stride_channel_y = ids ? s11 : s12;
|
||||
|
||||
GGML_ASSERT(ncols_dst == 1);
|
||||
GGML_ASSERT(!ids || ncols_dst == 1);
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0->data;
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
|
||||
mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
@@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t row_diff = row_high - row_low;
|
||||
|
||||
GGML_ASSERT(src1_ncols == 1);
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||||
|
||||
|
||||
// ggml_cuda_op provides single, contiguous matrices
|
||||
const int64_t stride_row = ne00;
|
||||
const int64_t stride_col_y = ne10;
|
||||
const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
|
||||
const int64_t nchannels_x = 1;
|
||||
const int64_t nchannels_y = 1;
|
||||
const int64_t nchannels_dst = 1;
|
||||
@@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0_dd_i;
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0_dd_i;
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
|
||||
mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
@@ -334,3 +441,48 @@ void ggml_cuda_op_mul_mat_vec(
|
||||
GGML_UNUSED(src1_ncols);
|
||||
GGML_UNUSED(src1_padded_row_size);
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
|
||||
if (src0_ne[0] % 2 != 0) {
|
||||
return false;
|
||||
}
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
return ne11 <= 8;
|
||||
}
|
||||
if (cc >= GGML_CUDA_CC_TURING) {
|
||||
return ne11 <= 4;
|
||||
}
|
||||
return ne11 <= 3;
|
||||
}
|
||||
return ne11 <= 8;
|
||||
case GGML_TYPE_F16:
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
return src0_small && ne11 <= 4;
|
||||
}
|
||||
if (fp16_mma_hardware_available(cc)) {
|
||||
return src0_small && ne11 <= 3;
|
||||
}
|
||||
return ne11 <= 8;
|
||||
}
|
||||
return ne11 <= 8;
|
||||
case GGML_TYPE_BF16:
|
||||
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
||||
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
return src0_small && ne11 <= 4;
|
||||
}
|
||||
if (bf16_mma_hardware_available(cc)) {
|
||||
return src0_small && ne11 <= 3;
|
||||
}
|
||||
return ne11 <= 8;
|
||||
}
|
||||
return ne11 <= 8;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
|
||||
#define MMV_MAX_ROWS 512
|
||||
|
||||
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec(
|
||||
@@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||
const int64_t src1_padded_row_size, cudaStream_t stream);
|
||||
|
||||
bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);
|
||||
|
||||
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
|
||||
float * __restrict__ dst, const int64_t L) {
|
||||
GGML_UNUSED(src1_nb0);
|
||||
GGML_UNUSED(src2_nb0);
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int bidx = blockIdx.x; // split along B
|
||||
const int bidy = blockIdx.y; // split along D
|
||||
const int tid = threadIdx.x;
|
||||
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
|
||||
if (N == 16) {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < splitD / 4; i += 2) {
|
||||
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
|
||||
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
|
||||
// todo: bank conflict
|
||||
// I am always confused with how to use the swizzling method to solve
|
||||
// bank conflit. Hoping somebody can tell me.
|
||||
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < splitD / 4; i += 2) {
|
||||
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
|
||||
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
|
||||
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,25 +1,9 @@
|
||||
#include "sumrows.cuh"
|
||||
|
||||
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i = col; i < ncols; i += blockDim.x) {
|
||||
sum += x[row * ncols + i];
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (col == 0) {
|
||||
dst[row] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
|
||||
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN)
|
||||
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
|
||||
endif()
|
||||
|
||||
if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
|
||||
add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
|
||||
endif()
|
||||
|
||||
if (NOT GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_NO_FA)
|
||||
endif()
|
||||
|
||||
@@ -48,22 +48,28 @@ static struct ggml_backend_metal_device_context {
|
||||
int mtl_device_ref_count;
|
||||
id<MTLLibrary> mtl_library;
|
||||
|
||||
NSLock * mtl_lock;
|
||||
|
||||
bool has_simdgroup_reduction;
|
||||
bool has_simdgroup_mm;
|
||||
bool has_residency_sets;
|
||||
bool has_bfloat;
|
||||
bool use_bfloat;
|
||||
|
||||
size_t max_size;
|
||||
|
||||
char name[128];
|
||||
} g_ggml_ctx_dev_main = {
|
||||
/*.mtl_device =*/ nil,
|
||||
/*.mtl_device_ref_count =*/ 0,
|
||||
/*.mtl_library =*/ nil,
|
||||
/*.mtl_lock =*/ nil,
|
||||
/*.has_simdgroup_reduction =*/ false,
|
||||
/*.has_simdgroup_mm =*/ false,
|
||||
/*.has_residency_sets =*/ false,
|
||||
/*.has_bfloat =*/ false,
|
||||
/*.use_bfloat =*/ false,
|
||||
/*.max_size =*/ 0,
|
||||
/*.name =*/ "",
|
||||
};
|
||||
|
||||
@@ -71,6 +77,10 @@ static struct ggml_backend_metal_device_context {
|
||||
static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
|
||||
assert(ctx != NULL);
|
||||
|
||||
if (ctx->mtl_lock == nil) {
|
||||
ctx->mtl_lock = [[NSLock alloc] init];
|
||||
}
|
||||
|
||||
if (ctx->mtl_device == nil) {
|
||||
ctx->mtl_device = MTLCreateSystemDefaultDevice();
|
||||
}
|
||||
@@ -94,6 +104,8 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
||||
ctx->use_bfloat = false;
|
||||
#endif
|
||||
|
||||
ctx->max_size = ctx->mtl_device.maxBufferLength;
|
||||
|
||||
strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
|
||||
}
|
||||
|
||||
@@ -110,6 +122,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
||||
ctx->mtl_device_ref_count--;
|
||||
|
||||
if (ctx->mtl_device_ref_count == 0) {
|
||||
if (ctx->mtl_lock) {
|
||||
[ctx->mtl_lock release];
|
||||
ctx->mtl_lock = nil;
|
||||
}
|
||||
|
||||
if (ctx->mtl_library) {
|
||||
[ctx->mtl_library release];
|
||||
ctx->mtl_library = nil;
|
||||
@@ -498,6 +515,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_MEAN,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||
@@ -976,7 +994,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
||||
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
||||
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
id<MTLDevice> device = ctx_dev->mtl_device;
|
||||
|
||||
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
||||
|
||||
@@ -990,9 +1008,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
||||
|
||||
// load library
|
||||
if (ctx_dev->mtl_library == nil) {
|
||||
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
||||
{
|
||||
[ctx_dev->mtl_lock lock];
|
||||
|
||||
if (ctx_dev->mtl_library == nil) {
|
||||
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
|
||||
}
|
||||
|
||||
[ctx_dev->mtl_lock unlock];
|
||||
}
|
||||
|
||||
id<MTLLibrary> metal_library = ctx_dev->mtl_library;
|
||||
if (metal_library == nil) {
|
||||
GGML_LOG_ERROR("%s: error: metal library is nil\n", __func__);
|
||||
@@ -1454,6 +1479,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
@@ -1653,6 +1679,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_LOG:
|
||||
return false; // TODO: implement
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
@@ -2400,11 +2427,30 @@ static bool ggml_metal_encode_node(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, ne00);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -2434,11 +2480,12 @@ static bool ggml_metal_encode_node(
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
@@ -5261,7 +5308,6 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
|
||||
}
|
||||
|
||||
ggml_backend_metal_buffer_rset_free(ctx);
|
||||
ggml_backend_metal_device_rel(buffer->buft->device->context);
|
||||
|
||||
if (ctx->owned) {
|
||||
#if TARGET_OS_OSX
|
||||
@@ -5370,7 +5416,10 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
||||
}
|
||||
|
||||
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
|
||||
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
||||
|
||||
id<MTLDevice> device = ctx_dev->mtl_device;
|
||||
|
||||
ctx->all_data = ggml_metal_host_malloc(size_aligned);
|
||||
ctx->all_size = size_aligned;
|
||||
@@ -5393,14 +5442,12 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
||||
if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
|
||||
GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
|
||||
free(ctx);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
||||
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
||||
free(ctx);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@@ -5411,17 +5458,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 32;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
|
||||
const size_t max_size = device.maxBufferLength;
|
||||
ggml_backend_metal_device_rel(buft->device->context);
|
||||
const size_t max_size = ((struct ggml_backend_metal_device_context *)buft->device->context)->max_size;
|
||||
|
||||
return max_size;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||
@@ -5494,7 +5538,10 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
||||
}
|
||||
|
||||
struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
|
||||
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
||||
|
||||
id<MTLDevice> device = ctx_dev->mtl_device;
|
||||
|
||||
// the buffer fits into the max buffer size allowed by the device
|
||||
if (size_aligned <= device.maxBufferLength) {
|
||||
@@ -5550,7 +5597,6 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
|
||||
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
||||
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
||||
free(ctx);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@@ -5566,10 +5612,8 @@ static const char * ggml_backend_metal_name(ggml_backend_t backend) {
|
||||
}
|
||||
|
||||
static void ggml_backend_metal_free(ggml_backend_t backend) {
|
||||
struct ggml_backend_metal_context * ctx = backend->context;
|
||||
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
||||
struct ggml_backend_metal_context * ctx = backend->context;
|
||||
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
ggml_metal_free(ctx);
|
||||
|
||||
free(backend);
|
||||
@@ -5709,6 +5753,8 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
||||
|
||||
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
||||
|
||||
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
||||
|
||||
return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
|
||||
}
|
||||
|
||||
@@ -5728,10 +5774,7 @@ static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
|
||||
}
|
||||
|
||||
static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
|
||||
// acq/rel just to populate ctx->name in case it hasn't been done yet
|
||||
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
||||
ggml_backend_metal_device_acq(ctx_dev);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
|
||||
return ctx_dev->name;
|
||||
}
|
||||
@@ -5739,12 +5782,10 @@ static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t
|
||||
static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
if (@available(macOS 10.12, iOS 16.0, *)) {
|
||||
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
id<MTLDevice> device = ctx_dev->mtl_device;
|
||||
|
||||
*total = device.recommendedMaxWorkingSetSize;
|
||||
*free = *total - device.currentAllocatedSize;
|
||||
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
} else {
|
||||
*free = 1;
|
||||
*total = 1;
|
||||
@@ -5822,7 +5863,10 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
||||
}
|
||||
|
||||
struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
|
||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||
|
||||
GGML_ASSERT(ctx_dev->mtl_device != nil);
|
||||
|
||||
id<MTLDevice> device = ctx_dev->mtl_device;
|
||||
|
||||
// the buffer fits into the max buffer size allowed by the device
|
||||
if (size_aligned <= device.maxBufferLength) {
|
||||
@@ -5878,7 +5922,6 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
|
||||
if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
|
||||
GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
|
||||
free(ctx);
|
||||
ggml_backend_metal_device_rel(ctx_dev);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@@ -5892,8 +5935,9 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const
|
||||
}
|
||||
|
||||
static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
||||
return
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
|
||||
buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
@@ -5978,8 +6022,19 @@ static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
|
||||
/* .get_proc_address = */ ggml_backend_metal_get_proc_address,
|
||||
};
|
||||
|
||||
// called upon program exit
|
||||
static void ggml_metal_cleanup(void) {
|
||||
ggml_backend_metal_device_rel(&g_ggml_ctx_dev_main);
|
||||
}
|
||||
|
||||
// TODO: make thread-safe
|
||||
ggml_backend_reg_t ggml_backend_metal_reg(void) {
|
||||
// TODO: make this thread-safe somehow?
|
||||
ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
|
||||
|
||||
// register cleanup callback
|
||||
// TODO: not ideal, but not sure if there is a better way to do this in Objective-C
|
||||
atexit(ggml_metal_cleanup);
|
||||
|
||||
{
|
||||
g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
|
||||
/* .api_version = */ GGML_BACKEND_API_VERSION,
|
||||
|
||||
@@ -993,31 +993,61 @@ kernel void kernel_neg(
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
int64_t i3 = tpig.z;
|
||||
int64_t i2 = tpig.y;
|
||||
int64_t i1 = tpig.x;
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float row_sum = 0;
|
||||
float sumf = 0;
|
||||
|
||||
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
}
|
||||
|
||||
dst_row[0] = row_sum;
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
|
||||
@@ -225,9 +225,9 @@ struct bin_bcast_sycl {
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
|
||||
sycl::range<3>(1, 1, block_size),
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
|
||||
sycl::range<3>(1, 1, block_size)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast_unravel<bin_op>(
|
||||
@@ -246,9 +246,8 @@ struct bin_bcast_sycl {
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
|
||||
ne2, ne3, ne10, ne11, ne12, ne13,
|
||||
s1, s2, s3, s01, s02, s03, s11, s12, s13,
|
||||
|
||||
@@ -89,33 +89,24 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
switch (dim) {
|
||||
case 0:
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
|
||||
});
|
||||
break;
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
|
||||
break;
|
||||
case 1:
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
|
||||
});
|
||||
break;
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
|
||||
break;
|
||||
// dim >=2 will be dispatched to the default path
|
||||
default:
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
|
||||
});
|
||||
break;
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,33 +120,29 @@ static void concat_f32_sycl_non_cont(
|
||||
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
|
||||
uint64_t nb3, int32_t dim) {
|
||||
sycl::range<3> gridDim(ne3, ne2, ne1);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
int64_t i3 = item_ct1.get_group(0);
|
||||
int64_t i2 = item_ct1.get_group(1);
|
||||
int64_t i1 = item_ct1.get_group(2);
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
|
||||
int64_t i3 = item_ct1.get_group(0);
|
||||
int64_t i2 = item_ct1.get_group(1);
|
||||
int64_t i1 = item_ct1.get_group(2);
|
||||
|
||||
int64_t o[4] = {0, 0, 0, 0};
|
||||
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
||||
int64_t o[4] = { 0, 0, 0, 0 };
|
||||
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
||||
|
||||
const float *x;
|
||||
const float * x;
|
||||
|
||||
for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
|
||||
i0 += item_ct1.get_local_range(2)) {
|
||||
for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
|
||||
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||
x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
|
||||
(i0)*nb00);
|
||||
x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
|
||||
} else {
|
||||
x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
|
||||
(i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
|
||||
x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
|
||||
(i0 - o[0]) * nb10);
|
||||
}
|
||||
|
||||
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
|
||||
|
||||
*y = *x;
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
@@ -59,16 +59,10 @@ static void conv_transpose_1d_f32_f32_sycl(
|
||||
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
|
||||
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
|
||||
const sycl::range<3> block_nums(1, 1, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
conv_transpose_1d_kernel(
|
||||
s0, output_size,
|
||||
src0_ne0, src0_ne1, src0_ne2,
|
||||
src1_ne0, dst_ne0,
|
||||
src0, src1, dst, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
|
||||
item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
|
||||
@@ -33,14 +33,11 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1); });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,24 +50,18 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 64),
|
||||
sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q2_K(vx, y, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
|
||||
}
|
||||
#else
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q2_K(vx, y, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -85,24 +76,18 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 64),
|
||||
sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q3_K(vx, y, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
|
||||
}
|
||||
#else
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q3_K(vx, y, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -116,12 +101,9 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q4_0(vx, y, nb32, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_0(vx, y, nb32, item_ct1); });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,13 +117,12 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
|
||||
int constexpr WARP_K = WARP_SIZE * QK4_0;
|
||||
const int n_warp = (k + WARP_K - 1) / WARP_K;
|
||||
GGML_ASSERT(k % 2 == 0);
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
|
||||
sycl::range<3>(1, 1, WARP_SIZE),
|
||||
sycl::range<3>(1, 1, WARP_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
|
||||
});
|
||||
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl::range<3>(1, 1, WARP_SIZE),
|
||||
sycl::range<3>(1, 1, WARP_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
@@ -153,12 +134,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q4_1(vx, y, nb32, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_1(vx, y, nb32, item_ct1); });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,14 +149,13 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -191,13 +168,13 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
|
||||
});
|
||||
sycl_parallel_for<1>(cgh, sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -210,24 +187,18 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 64),
|
||||
sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q5_K(vx, y, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
|
||||
}
|
||||
#else
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q5_K(vx, y, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -242,24 +213,18 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 64),
|
||||
sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q6_K(vx, y, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
|
||||
}
|
||||
#else
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q6_K(vx, y, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -271,9 +236,9 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i
|
||||
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
|
||||
}
|
||||
|
||||
template <typename dst_t>
|
||||
@@ -284,15 +249,10 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq1_s(
|
||||
vx, y, item_ct1, iq1s_grid_gpu
|
||||
);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_gpu); });
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -305,15 +265,10 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq1_m(
|
||||
vx, y, item_ct1, iq1s_grid_gpu
|
||||
);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_m(vx, y, item_ct1, iq1s_grid_gpu); });
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -326,15 +281,12 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq2_xxs(
|
||||
vx, y, item_ct1, iq2xxs_grid,
|
||||
ksigns_iq2xs, kmask_iq2xs);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq2_xxs(vx, y, item_ct1, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -347,15 +299,12 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq2_xs(
|
||||
vx, y, item_ct1, iq2xs_grid,
|
||||
ksigns_iq2xs, kmask_iq2xs);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq2_xs(vx, y, item_ct1, iq2xs_grid, ksigns_iq2xs, kmask_iq2xs);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -368,13 +317,10 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq2_s(vx, y, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq2_s(vx, y, item_ct1); });
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -388,15 +334,12 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq3_xxs(
|
||||
vx, y, item_ct1, iq3xxs_grid,
|
||||
ksigns_iq2xs, kmask_iq2xs);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq3_xxs(vx, y, item_ct1, iq3xxs_grid, ksigns_iq2xs, kmask_iq2xs);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -409,14 +352,10 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq3_s(
|
||||
vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq3_s(vx, y, item_ct1, kmask_iq2xs, iq3s_grid); });
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -432,14 +371,11 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq4_xs(vx, y, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_xs(vx, y, item_ct1); });
|
||||
});
|
||||
}
|
||||
#endif
|
||||
@@ -453,14 +389,11 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq4_nl(vx, y, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(
|
||||
cgh,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_nl(vx, y, item_ct1); });
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -413,7 +413,8 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->parallel_for(
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
@@ -431,7 +432,8 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->parallel_for(
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
@@ -449,7 +451,8 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->parallel_for(
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
@@ -465,11 +468,11 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
GGML_ASSERT(ne % QK8_0 == 0);
|
||||
const int num_blocks = ne / QK8_0;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
@@ -477,11 +480,11 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ne;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
@@ -490,11 +493,11 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
GGML_ASSERT(ne % QK4_0 == 0);
|
||||
const int num_blocks = ne / QK4_0;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
@@ -502,8 +505,9 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ne;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
||||
item_ct1);
|
||||
@@ -516,11 +520,11 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
GGML_ASSERT(ne % QK4_1 == 0);
|
||||
const int num_blocks = ne / QK4_1;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
@@ -528,8 +532,9 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ne;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
||||
item_ct1);
|
||||
@@ -542,11 +547,11 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
GGML_ASSERT(ne % QK5_0 == 0);
|
||||
const int num_blocks = ne / QK5_0;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
@@ -554,8 +559,9 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ne;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
||||
item_ct1);
|
||||
@@ -568,11 +574,11 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
GGML_ASSERT(ne % QK5_1 == 0);
|
||||
const int num_blocks = ne / QK5_1;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
@@ -580,8 +586,9 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ne;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
|
||||
nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
|
||||
item_ct1);
|
||||
@@ -594,11 +601,11 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
GGML_ASSERT(ne % QK4_NL == 0);
|
||||
const int num_blocks = ne / QK4_NL;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
||||
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, const int ne00, const int ne01,
|
||||
@@ -609,7 +616,8 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
stream->parallel_for(
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
@@ -628,7 +636,8 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co
|
||||
// dpct::has_capability_or_fail(stream->get_device(),
|
||||
// {sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
@@ -647,7 +656,8 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
|
||||
// dpct::has_capability_or_fail(stream->get_device(),
|
||||
// {sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
@@ -662,11 +672,13 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
||||
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -675,11 +687,13 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
||||
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -689,11 +703,13 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
||||
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -702,10 +718,13 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
|
||||
const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
||||
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -715,10 +734,13 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
|
||||
const int nb12, const int nb13, queue_ptr stream) {
|
||||
|
||||
const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
|
||||
ne12, nb10, nb11, nb12, nb13, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try {
|
||||
|
||||
@@ -208,12 +208,10 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -877,12 +875,11 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(vx, y, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -900,12 +897,10 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -921,12 +916,10 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -942,12 +935,10 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -963,12 +954,10 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -984,12 +973,10 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
|
||||
vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1002,11 +989,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
|
||||
@@ -1018,11 +1004,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
|
||||
@@ -1034,11 +1019,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
|
||||
@@ -1047,11 +1031,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
|
||||
dpct::queue_ptr stream) {
|
||||
GGML_ASSERT(ncols % QK_K == 0);
|
||||
const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
|
||||
@@ -1063,11 +1046,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
|
||||
const int block_num_y = (nrows + ny - 1) / ny;
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
|
||||
dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_dequantize_mul_mat_vec(
|
||||
|
||||
@@ -13,10 +13,10 @@
|
||||
#ifndef GGML_SYCL_DPCT_HELPER_HPP
|
||||
#define GGML_SYCL_DPCT_HELPER_HPP
|
||||
|
||||
#include <map>
|
||||
#include <sycl/sycl.hpp>
|
||||
#include <sycl/half_type.hpp>
|
||||
#include <syclcompat/math.hpp>
|
||||
#include <map>
|
||||
|
||||
#ifdef GGML_SYCL_USE_INTEL_ONEMKL
|
||||
#include <oneapi/mkl.hpp>
|
||||
@@ -118,6 +118,36 @@ inline auto get_onemath_backend(sycl::queue& queue)
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
|
||||
namespace syclex = sycl::ext::oneapi::experimental;
|
||||
#endif
|
||||
|
||||
template <int NR, typename Func>
|
||||
__dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<NR> nd_range, Func && func) {
|
||||
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
|
||||
syclex::nd_launch(cgh, nd_range, func);
|
||||
#else
|
||||
cgh.parallel_for(nd_range, func);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int NR, typename Func>
|
||||
__dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range<NR> nd_range, Func && func) {
|
||||
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
|
||||
syclex::nd_launch(*q, nd_range, func);
|
||||
#else
|
||||
q->parallel_for(nd_range, func);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename Func> __dpct_inline__ void sycl_launch(sycl::queue * stream, Func && func) {
|
||||
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
|
||||
syclex::submit(*stream, func);
|
||||
#else
|
||||
stream->submit(func);
|
||||
#endif
|
||||
}
|
||||
|
||||
namespace dpct
|
||||
{
|
||||
typedef sycl::queue *queue_ptr;
|
||||
|
||||
@@ -329,60 +329,51 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int ne12, const int nb1, const int nb2,
|
||||
const int offset, queue_ptr stream) {
|
||||
int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
|
||||
item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void gelu_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
gelu(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { gelu(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void silu_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
silu(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { silu(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
|
||||
// hard code for now
|
||||
const int num_blocks = ceil_div(k, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
|
||||
sgn(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)),
|
||||
[=](sycl::nd_item<3> item_ct1) { sgn(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
|
||||
// hard code for now
|
||||
const int num_blocks = ceil_div(k, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
|
||||
abs_op(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
|
||||
[=](sycl::nd_item<3> item_ct1) { abs_op(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
|
||||
@@ -390,23 +381,20 @@ template<typename T>
|
||||
static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
|
||||
// hard code for now
|
||||
const int num_blocks = ceil_div(k, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) {
|
||||
elu_op(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
|
||||
[=](sycl::nd_item<3> item_ct1) { elu_op(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void gelu_quick_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
gelu_quick(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { gelu_quick(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
|
||||
@@ -414,169 +402,133 @@ template<typename T>
|
||||
static void gelu_erf_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
gelu_erf(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { gelu_erf(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void tanh_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
tanh(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { tanh(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void relu_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
relu(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { relu(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void hardsigmoid_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
hardsigmoid(x, dst, k, item_ct1);
|
||||
});
|
||||
[=](sycl::nd_item<3> item_ct1) { hardsigmoid(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void hardswish_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
hardswish(x, dst, k, item_ct1);
|
||||
});
|
||||
[=](sycl::nd_item<3> item_ct1) { hardswish(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void exp_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
exp(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { exp(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void log_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
log(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { log(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void neg_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
neg(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { neg(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void step_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
step(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { step(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void sigmoid_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
|
||||
sycl_parallel_for(
|
||||
stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sigmoid(x, dst, k, item_ct1);
|
||||
});
|
||||
[=](sycl::nd_item<3> item_ct1) { sigmoid(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void sqrt_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sqrt(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { sqrt(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void sin_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sin(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { sin(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void cos_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
cos(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { cos(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -584,26 +536,20 @@ static void leaky_relu_sycl(const T *x, T *dst, const int k,
|
||||
const float negative_slope,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
leaky_relu(x, dst, k, negative_slope, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { leaky_relu(x, dst, k, negative_slope, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void sqr_sycl(const T *x, T *dst, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sqr(x, dst, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { sqr(x, dst, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -614,9 +560,8 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
||||
int dst_size = ne10 * ne11 * ne12 * ne13;
|
||||
int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
|
||||
sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
sycl_parallel_for<1>(
|
||||
stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
|
||||
});
|
||||
}
|
||||
@@ -627,12 +572,10 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
|
||||
const int ne1, const int ne2, queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
pad(x, dst, ne0, ne00, ne01, ne02, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
@@ -640,13 +583,10 @@ static void clamp_sycl(const T *x, T *dst, const float min,
|
||||
const float max, const int k,
|
||||
queue_ptr stream) {
|
||||
const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
|
||||
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
clamp(x, dst, min, max, k, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream,
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -60,54 +60,6 @@ static void k_get_rows(
|
||||
dst_row[iybs + iqs + y_offset] = v.y();
|
||||
}
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t_reorder dequantize_kernel_recorder, typename dst_t>
|
||||
static void k_get_rows_reorder(
|
||||
const void * src0, const void *src0_dq, const int32_t * src1, dst_t * dst,
|
||||
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
|
||||
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
|
||||
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
|
||||
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
|
||||
size_t s10, size_t s11, size_t s12,
|
||||
const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
|
||||
|
||||
const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
|
||||
item_ct1.get_local_id(2)) *
|
||||
2;
|
||||
const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
|
||||
item_ct1.get_local_id(0)) /
|
||||
ne12;
|
||||
const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
|
||||
item_ct1.get_local_id(0)) %
|
||||
ne12;
|
||||
|
||||
if (i00 >= ne00) {
|
||||
return;
|
||||
}
|
||||
auto ncols = ne00;
|
||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||
|
||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||
|
||||
const int src0_off = i01 * ncols + i00;
|
||||
const int ib = src0_off / QK4_0; // block index
|
||||
const int iqs = (i00%qk)/qr; // x quant index
|
||||
const int iybs = i00 - i00%qk; // dst block start index
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
// dequantize
|
||||
dfloat2 v;
|
||||
dequantize_kernel_recorder((const void *)src0_dq, ib, (const void *)src0, src0_off/2, v);
|
||||
|
||||
dst_row[iybs + iqs + 0] = v.x();
|
||||
dst_row[iybs + iqs + y_offset] = v.y();
|
||||
|
||||
GGML_UNUSED(nb01);
|
||||
GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03);
|
||||
}
|
||||
|
||||
template<typename src0_t, typename dst_t>
|
||||
static void k_get_rows_float(
|
||||
const src0_t * src0, const int32_t * src1, dst_t * dst,
|
||||
@@ -166,58 +118,15 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
|
||||
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_get_rows<qk, qr, dq>(
|
||||
src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
|
||||
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
|
||||
item_ct1);
|
||||
});
|
||||
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t_reorder dq_reorder>
|
||||
static void get_rows_sycl_reorder(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const void *src0_dd,
|
||||
const int32_t *src1_dd, float *dst_dd,
|
||||
queue_ptr stream) {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
|
||||
const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
|
||||
const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
|
||||
|
||||
// strides in elements
|
||||
//const size_t s0 = nb0 / ggml_element_size(dst);
|
||||
const size_t s1 = nb1 / ggml_element_size(dst);
|
||||
const size_t s2 = nb2 / ggml_element_size(dst);
|
||||
const size_t s3 = nb3 / ggml_element_size(dst);
|
||||
|
||||
const size_t s10 = nb10 / ggml_element_size(src1);
|
||||
const size_t s11 = nb11 / ggml_element_size(src1);
|
||||
const size_t s12 = nb12 / ggml_element_size(src1);
|
||||
//const size_t s13 = nb13 / ggml_element_size(src1);
|
||||
|
||||
GGML_ASSERT(ne00 % 2 == 0);
|
||||
|
||||
const uint8_t* src0_q = (const uint8_t*)src0_dd;
|
||||
const size_t ncols = ne00;
|
||||
const size_t nrows = ne01;
|
||||
const sycl::half* src0_dq = (const sycl::half*)(src0_q + nrows * ncols / 2);
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
|
||||
k_get_rows_reorder<qk, qr, dq_reorder>(
|
||||
src0_dd, src0_dq, src1_dd, dst_dd, ne00, ne12, s1, s2,
|
||||
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
|
||||
});
|
||||
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
|
||||
template <typename src0_t>
|
||||
static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
@@ -245,9 +154,8 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
|
||||
s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
|
||||
});
|
||||
@@ -277,13 +185,8 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
if (ctx.opt_feature.reorder && dst->op == GGML_OP_MUL_MAT) {
|
||||
get_rows_sycl_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
} else {
|
||||
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
}
|
||||
get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
src1_i32, (float *)dst->data, ctx.stream());
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data,
|
||||
|
||||
@@ -1887,13 +1887,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
||||
sycl::range<1>(shared_mem), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
|
||||
x, dst, ncols, ncols_pad, item_ct1,
|
||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
||||
@@ -1901,13 +1900,12 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
});
|
||||
});
|
||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
|
||||
sycl::range<1>(shared_mem), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
|
||||
x, dst, ncols, ncols_pad, item_ct1,
|
||||
dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
|
||||
@@ -1925,50 +1923,47 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
|
||||
const sycl::range<3> block_nums(1, nrows, 1);
|
||||
const size_t shared_mem = 256 * sizeof(float);
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> shared_data(
|
||||
sycl::range<1>(shared_mem/sizeof(float)), cgh);
|
||||
sycl::local_accessor<int, 1> shared_indices(
|
||||
sycl::range<1>(shared_mem/sizeof(float)), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int row = item_ct1.get_global_id(1);
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int row = item_ct1.get_global_id(1);
|
||||
|
||||
float max_val = -INFINITY;
|
||||
int max_idx = -1;
|
||||
float max_val = -INFINITY;
|
||||
int max_idx = -1;
|
||||
|
||||
for (int col = tid; col < ncols; col += 256) {
|
||||
float val = x[row * ncols + col];
|
||||
if (val > max_val) {
|
||||
max_val = val;
|
||||
max_idx = col;
|
||||
for (int col = tid; col < ncols; col += 256) {
|
||||
float val = x[row * ncols + col];
|
||||
if (val > max_val) {
|
||||
max_val = val;
|
||||
max_idx = col;
|
||||
}
|
||||
}
|
||||
|
||||
shared_data[tid] = max_val;
|
||||
shared_indices[tid] = max_idx;
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
for (int stride = 256 / 2; stride > 0; stride >>= 1) {
|
||||
if (tid < stride) {
|
||||
float val1 = shared_data[tid];
|
||||
float val2 = shared_data[tid + stride];
|
||||
if (val2 > val1) {
|
||||
shared_data[tid] = val2;
|
||||
shared_indices[tid] = shared_indices[tid + stride];
|
||||
}
|
||||
}
|
||||
|
||||
shared_data[tid] = max_val;
|
||||
shared_indices[tid] = max_idx;
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
|
||||
for (int stride = 256/2; stride > 0; stride >>= 1) {
|
||||
if (tid < stride) {
|
||||
float val1 = shared_data[tid];
|
||||
float val2 = shared_data[tid + stride];
|
||||
if (val2 > val1) {
|
||||
shared_data[tid] = val2;
|
||||
shared_indices[tid] = shared_indices[tid + stride];
|
||||
}
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
}
|
||||
|
||||
|
||||
if (tid == 0) {
|
||||
dst[row] = shared_indices[0];
|
||||
}
|
||||
});
|
||||
if (tid == 0) {
|
||||
dst[row] = shared_indices[0];
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
||||
@@ -2952,7 +2947,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
||||
void ** ptrs_dst_get = ptrs_dst.get();
|
||||
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
||||
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
||||
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
||||
});
|
||||
@@ -3456,7 +3451,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
||||
{
|
||||
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
|
||||
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 0> src1_row_acc(cgh);
|
||||
|
||||
char *__restrict src1_contiguous_get =
|
||||
@@ -3468,9 +3463,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
||||
size_t ids_nb_ct6 = ids->nb[1];
|
||||
size_t ids_nb_ct7 = ids->nb[0];
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_copy_src1_to_contiguous(
|
||||
src1_original, src1_contiguous_get,
|
||||
dev_cur_src1_row_get,
|
||||
@@ -3501,15 +3495,14 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
|
||||
{
|
||||
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
|
||||
sycl::range<3> grid_dims(1, 1, num_src1_rows);
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
const char *__restrict dst_contiguous_get =
|
||||
dst_contiguous.get();
|
||||
const mmid_row_mapping *__restrict dev_row_mapping_get =
|
||||
dev_row_mapping.get();
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
k_copy_dst_from_contiguous(dst_original,
|
||||
dst_contiguous_get,
|
||||
dev_row_mapping_get,
|
||||
|
||||
@@ -11,13 +11,13 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B,
|
||||
const u_int n_seq_tokens = T / B;
|
||||
sycl::range<1> block_dims((C / H));
|
||||
sycl::range<1> grid_dims((B * H));
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
/* local memory accessors*/
|
||||
auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
|
||||
auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
|
||||
auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
|
||||
|
||||
cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
|
||||
sycl_parallel_for<1>(cgh, sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
|
||||
u_int tid = item.get_local_id(0);
|
||||
u_int bid = item.get_group(0);
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I
|
||||
|
||||
const int64_t CHW = IC * KH * KW;
|
||||
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
|
||||
im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
|
||||
p0, p1, d0, d1, item_ct1);
|
||||
});
|
||||
|
||||
@@ -1818,7 +1818,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
|
||||
@@ -1829,9 +1829,8 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q4_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -1853,7 +1852,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
|
||||
@@ -1864,9 +1863,8 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q4_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -1933,7 +1931,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
|
||||
@@ -1944,9 +1942,8 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q4_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -1968,7 +1965,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
|
||||
@@ -1979,9 +1976,8 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q4_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2048,7 +2044,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
|
||||
@@ -2059,9 +2055,8 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q5_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2083,7 +2078,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
|
||||
@@ -2094,9 +2089,8 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q5_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2163,7 +2157,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
|
||||
@@ -2174,9 +2168,8 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q5_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2198,7 +2191,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
|
||||
@@ -2209,9 +2202,8 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q5_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2278,7 +2270,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
|
||||
@@ -2289,9 +2281,8 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q8_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2313,7 +2304,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
|
||||
@@ -2324,9 +2315,8 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q8_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2393,7 +2383,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
|
||||
@@ -2406,9 +2396,8 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q2_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2431,7 +2420,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
|
||||
@@ -2444,9 +2433,8 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q2_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2516,7 +2504,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
|
||||
@@ -2531,9 +2519,8 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q3_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2557,7 +2544,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
|
||||
@@ -2572,9 +2559,8 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q3_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2644,7 +2630,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
|
||||
@@ -2657,9 +2643,8 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q4_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2682,7 +2667,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
|
||||
@@ -2695,9 +2680,8 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q4_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2765,7 +2749,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
|
||||
@@ -2778,9 +2762,8 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q5_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2803,7 +2786,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
|
||||
@@ -2816,9 +2799,8 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q5_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2886,7 +2868,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
|
||||
@@ -2899,9 +2881,8 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q6_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
@@ -2924,7 +2905,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
|
||||
sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
|
||||
sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
|
||||
@@ -2937,9 +2918,8 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
|
||||
sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
mul_mat_q6_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
|
||||
@@ -544,12 +544,12 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -561,12 +561,12 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
|
||||
{
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -580,17 +580,12 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
|
||||
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -604,17 +599,12 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
|
||||
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -628,17 +618,12 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
|
||||
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -652,17 +637,12 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
|
||||
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -676,17 +656,12 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
|
||||
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -700,17 +675,12 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
|
||||
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -724,17 +694,12 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
|
||||
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -750,12 +715,12 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
|
||||
nrows, nd_item);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -769,17 +734,12 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
|
||||
VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -794,12 +754,12 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
|
||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
|
||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
|
||||
nd_item);
|
||||
});
|
||||
});
|
||||
}
|
||||
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
@@ -811,17 +771,12 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
|
||||
VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -836,14 +791,12 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS / 2, block_iq2_xxs, 1>(vx, vy, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -857,14 +810,12 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
stream->submit([&](sycl::handler & cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS / 2, block_iq2_xs, 1>(vx, vy, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -878,15 +829,12 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S / 2, block_iq2_s, 1>(vx, vy, dst, ncols, nrows,
|
||||
item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -900,15 +848,12 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS / 2, block_iq3_xxs, 1>(vx, vy, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -922,15 +867,12 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S / 2, block_iq3_s, 1>(vx, vy, dst, ncols, nrows,
|
||||
item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -944,15 +886,12 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(vx, vy, dst, ncols, nrows,
|
||||
item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -966,14 +905,12 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(vx, vy, dst, ncols, nrows,
|
||||
item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -987,15 +924,12 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(vx, vy, dst, ncols, nrows,
|
||||
item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1009,15 +943,12 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
|
||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1);
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS / 4, block_iq4_xs, 1>(vx, vy, dst, ncols,
|
||||
nrows, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -254,14 +254,13 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
@@ -272,16 +271,15 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
|
||||
sycl::range<1>(work_group_size / WARP_SIZE), cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,18 +288,14 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
const int ne_elements, queue_ptr stream, int device) {
|
||||
if (group_size < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
const float eps_ct4 = eps;
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(
|
||||
x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
|
||||
WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
@@ -313,22 +307,18 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||
cgh);
|
||||
|
||||
const float eps_ct4 = eps;
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(x, dst, group_size, ne_elements,
|
||||
eps_ct4, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -340,14 +330,13 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
@@ -358,16 +347,15 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||
cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -378,16 +366,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
|
||||
});
|
||||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
@@ -398,18 +382,15 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||
cgh);
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||
block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
|
||||
work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -235,20 +235,22 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
} else {
|
||||
/*
|
||||
DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
|
||||
the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||
*/
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,15 +269,17 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
||||
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
|
||||
theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
|
||||
attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -298,12 +302,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
||||
}
|
||||
// launch kernel
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
});
|
||||
@@ -333,12 +337,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
||||
}
|
||||
// launch kernel
|
||||
if (freq_factors == nullptr) {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
});
|
||||
} else {
|
||||
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||
rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||
});
|
||||
|
||||
@@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
||||
const size_t n_local_scratch, queue_ptr stream) {
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
||||
nrows_y, scale, max_bias, m0,
|
||||
|
||||
@@ -45,14 +45,9 @@ static void timestep_embedding_f32_sycl(
|
||||
int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
|
||||
sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
|
||||
sycl::range<3> gridDim(1, ne00, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(
|
||||
gridDim * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
timestep_embedding_f32(
|
||||
x, dst, nb1, dim, max_period, item_ct1
|
||||
);
|
||||
});
|
||||
sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
@@ -207,12 +207,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
// Submit kernel
|
||||
if (C / H == WKV_BLOCK_SIZE) {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
@@ -220,12 +219,11 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
@@ -264,12 +262,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
|
||||
// Submit kernel
|
||||
if (C / H == WKV_BLOCK_SIZE) {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
|
||||
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
@@ -277,12 +274,11 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
});
|
||||
});
|
||||
} else {
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
sycl_parallel_for(
|
||||
cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
|
||||
B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
|
||||
item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
|
||||
|
||||
@@ -49,15 +49,7 @@ if (Vulkan_FOUND)
|
||||
../../include/ggml-vulkan.h
|
||||
)
|
||||
|
||||
set(VULKAN_SHADER_GEN_CMAKE_ARGS
|
||||
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
|
||||
-DCMAKE_RUNTIME_OUTPUT_DIRECTORY=${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
|
||||
)
|
||||
|
||||
set(VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS "")
|
||||
if (CMAKE_BUILD_TYPE AND CMAKE_BUILD_TYPE MATCHES "Debug|Release|MinSizeRel|RelWithDebInfo")
|
||||
list(APPEND VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS --config=${CMAKE_BUILD_TYPE})
|
||||
endif()
|
||||
set(VULKAN_SHADER_GEN_CMAKE_ARGS "")
|
||||
|
||||
# Test all shader extensions
|
||||
test_shader_extension_support(
|
||||
@@ -136,42 +128,45 @@ if (Vulkan_FOUND)
|
||||
set(HOST_CMAKE_TOOLCHAIN_FILE "")
|
||||
endif()
|
||||
|
||||
# Always use ExternalProject_Add approach
|
||||
include(ExternalProject)
|
||||
|
||||
# Add toolchain file if cross-compiling
|
||||
if (CMAKE_CROSSCOMPILING)
|
||||
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
|
||||
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
|
||||
endif()
|
||||
|
||||
# Native build through ExternalProject_Add
|
||||
ExternalProject_Add(
|
||||
vulkan-shaders-gen
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
|
||||
CMAKE_ARGS ${VULKAN_SHADER_GEN_CMAKE_ARGS}
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build . ${VULKAN_SHADER_GEN_CMAKE_BUILD_ARGS}
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
|
||||
INSTALL_DIR ${CMAKE_BINARY_DIR}
|
||||
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$<CONFIG>
|
||||
-DCMAKE_INSTALL_BINDIR=.
|
||||
-DCMAKE_BUILD_TYPE=$<CONFIG>
|
||||
${VULKAN_SHADER_GEN_CMAKE_ARGS}
|
||||
|
||||
BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $<CONFIG>
|
||||
|
||||
# NOTE: When DESTDIR is set using Makefile generators and
|
||||
# "make install" triggers the build step, vulkan-shaders-gen
|
||||
# would be installed into the DESTDIR prefix, so it is unset
|
||||
# to ensure that does not happen.
|
||||
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR
|
||||
${CMAKE_COMMAND} --install . --config $<CONFIG>
|
||||
)
|
||||
ExternalProject_Add_StepTargets(vulkan-shaders-gen build install)
|
||||
|
||||
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
|
||||
set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix})
|
||||
set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
|
||||
set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
|
||||
set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
|
||||
set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
|
||||
set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
|
||||
set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
|
||||
set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
|
||||
set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp")
|
||||
set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
|
||||
set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
|
||||
|
||||
file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
|
||||
set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen)
|
||||
|
||||
# Add build and install dependencies for all builds
|
||||
set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install)
|
||||
file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${_ggml_vk_header}
|
||||
${_ggml_vk_source}
|
||||
${_ggml_vk_source}
|
||||
|
||||
COMMAND ${_ggml_vk_genshaders_cmd}
|
||||
--glslc ${Vulkan_GLSLC_EXECUTABLE}
|
||||
@@ -181,7 +176,9 @@ if (Vulkan_FOUND)
|
||||
--target-cpp ${_ggml_vk_source}
|
||||
--no-clean
|
||||
|
||||
DEPENDS ${_ggml_vk_shader_deps}
|
||||
DEPENDS ${_ggml_vk_shader_files}
|
||||
vulkan-shaders-gen
|
||||
|
||||
COMMENT "Generate vulkan shaders"
|
||||
)
|
||||
|
||||
|
||||
@@ -168,6 +168,11 @@ struct vk_command_pool {
|
||||
vk_queue *q;
|
||||
};
|
||||
|
||||
// Prevent simultaneous submissions to the same queue.
|
||||
// This could be per vk_queue if we stopped having two vk_queue structures
|
||||
// sharing the same vk::Queue.
|
||||
static std::mutex queue_mutex;
|
||||
|
||||
struct vk_queue {
|
||||
uint32_t queue_family_index;
|
||||
vk::Queue queue;
|
||||
@@ -1036,6 +1041,14 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
|
||||
struct vk_instance_t {
|
||||
vk::Instance instance;
|
||||
|
||||
bool debug_utils_support = false; // VK_EXT_debug_utils enabled
|
||||
PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
|
||||
PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
|
||||
PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
|
||||
PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
|
||||
PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
|
||||
PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
|
||||
|
||||
std::vector<size_t> device_indices;
|
||||
vk_device devices[GGML_VK_MAX_DEVICES];
|
||||
};
|
||||
@@ -1175,6 +1188,14 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
||||
}
|
||||
pipeline->compiled = true;
|
||||
|
||||
if (vk_instance.debug_utils_support) {
|
||||
vk::DebugUtilsObjectNameInfoEXT duoni;
|
||||
duoni.objectType = vk::ObjectType::ePipeline;
|
||||
duoni.pObjectName = pipeline->name.c_str();
|
||||
duoni.objectHandle = reinterpret_cast<uint64_t>(static_cast<VkPipeline_T*>(pipeline->pipeline));
|
||||
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(device->mutex);
|
||||
device->pipelines.insert({ pipeline->name, pipeline });
|
||||
@@ -1266,6 +1287,7 @@ static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command
|
||||
static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
|
||||
if (ctx->seqs.empty()) {
|
||||
if (fence) {
|
||||
std::lock_guard<std::mutex> guard(queue_mutex);
|
||||
ctx->p->q->queue.submit({}, fence);
|
||||
}
|
||||
return;
|
||||
@@ -1335,6 +1357,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
|
||||
}
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> guard(queue_mutex);
|
||||
ctx->p->q->queue.submit(submit_infos, fence);
|
||||
|
||||
ctx->seqs.clear();
|
||||
@@ -3554,6 +3577,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
||||
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
|
||||
|
||||
static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
|
||||
|
||||
static void ggml_vk_instance_init() {
|
||||
if (vk_instance_initialized) {
|
||||
return;
|
||||
@@ -3574,7 +3599,7 @@ static void ggml_vk_instance_init() {
|
||||
#ifdef __APPLE__
|
||||
const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
|
||||
#endif
|
||||
|
||||
const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
|
||||
std::vector<const char*> layers;
|
||||
|
||||
if (validation_ext) {
|
||||
@@ -3589,6 +3614,9 @@ static void ggml_vk_instance_init() {
|
||||
extensions.push_back("VK_KHR_portability_enumeration");
|
||||
}
|
||||
#endif
|
||||
if (debug_utils_ext) {
|
||||
extensions.push_back("VK_EXT_debug_utils");
|
||||
}
|
||||
vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions);
|
||||
#ifdef __APPLE__
|
||||
if (portability_enumeration_ext) {
|
||||
@@ -3612,6 +3640,18 @@ static void ggml_vk_instance_init() {
|
||||
vk_instance.instance = vk::createInstance(instance_create_info);
|
||||
vk_instance_initialized = true;
|
||||
|
||||
if (debug_utils_ext) {
|
||||
vk_instance.debug_utils_support = true;
|
||||
vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
|
||||
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
|
||||
vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
|
||||
vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
|
||||
vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
|
||||
vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
|
||||
|
||||
}
|
||||
|
||||
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
|
||||
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
|
||||
|
||||
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
|
||||
@@ -9488,6 +9528,12 @@ static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer
|
||||
UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
return vk_instance.devices[0]->suballocation_block_size;
|
||||
|
||||
UNUSED(buft);
|
||||
}
|
||||
|
||||
// Should be changed to return device-specific host buffer type
|
||||
// but that probably requires changes in llama.cpp
|
||||
ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
|
||||
@@ -9496,7 +9542,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
|
||||
/* .get_name = */ ggml_backend_vk_host_buffer_type_name,
|
||||
/* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||
/* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
|
||||
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
|
||||
},
|
||||
@@ -9643,6 +9689,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
||||
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
|
||||
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
||||
|
||||
if (vk_instance.debug_utils_support) {
|
||||
vk::DebugUtilsLabelEXT dul = {};
|
||||
dul.pLabelName = "ggml_backend_vk_graph_compute";
|
||||
dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
|
||||
vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
|
||||
}
|
||||
|
||||
uint64_t total_mat_mul_bytes = 0;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
|
||||
@@ -10332,6 +10385,22 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
|
||||
UNUSED(instance_extensions);
|
||||
}
|
||||
|
||||
// Extension availability
|
||||
static bool ggml_vk_instance_debug_utils_ext_available(
|
||||
const std::vector<vk::ExtensionProperties> & instance_extensions) {
|
||||
// Check for portability enumeration extension for MoltenVK support
|
||||
for (const auto & properties : instance_extensions) {
|
||||
if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
|
||||
return false;
|
||||
|
||||
UNUSED(instance_extensions);
|
||||
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
|
||||
switch (props.vendorID) {
|
||||
case VK_VENDOR_ID_INTEL:
|
||||
|
||||
@@ -25,15 +25,3 @@ add_executable(${TARGET} vulkan-shaders-gen.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
|
||||
|
||||
# Configure output directories for MSVC builds
|
||||
if(MSVC)
|
||||
# Get the main project's runtime output directory if possible
|
||||
if(DEFINED CMAKE_RUNTIME_OUTPUT_DIRECTORY)
|
||||
foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES})
|
||||
string(TOUPPER ${CONFIG} CONFIG)
|
||||
set_target_properties(${TARGET} PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -888,12 +888,6 @@ struct ggml_context {
|
||||
struct ggml_object * objects_end;
|
||||
};
|
||||
|
||||
struct ggml_context_container {
|
||||
bool used;
|
||||
|
||||
struct ggml_context context;
|
||||
};
|
||||
|
||||
//
|
||||
// data types
|
||||
//
|
||||
@@ -961,6 +955,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"UPSCALE",
|
||||
"PAD",
|
||||
"PAD_REFLECT_1D",
|
||||
"ROLL",
|
||||
"ARANGE",
|
||||
"TIMESTEP_EMBEDDING",
|
||||
"ARGSORT",
|
||||
@@ -991,7 +986,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1056,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"upscale(x)",
|
||||
"pad(x)",
|
||||
"pad_reflect_1d(x)",
|
||||
"roll(x)",
|
||||
"arange(start, stop, step)",
|
||||
"timestep_embedding(timesteps, dim, max_period)",
|
||||
"argsort(x)",
|
||||
@@ -1086,7 +1082,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -4347,6 +4343,34 @@ struct ggml_tensor * ggml_pad_reflect_1d(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_roll
|
||||
|
||||
struct ggml_tensor * ggml_roll(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int shift0,
|
||||
int shift1,
|
||||
int shift2,
|
||||
int shift3) {
|
||||
GGML_ASSERT(a->nb[0] == ggml_type_size(a->type));
|
||||
GGML_ASSERT(abs(shift0) < a->ne[0]);
|
||||
GGML_ASSERT(abs(shift1) < a->ne[1]);
|
||||
GGML_ASSERT(abs(shift2) < a->ne[2]);
|
||||
GGML_ASSERT(abs(shift3) < a->ne[3]);
|
||||
|
||||
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, shift0);
|
||||
ggml_set_op_params_i32(result, 1, shift1);
|
||||
ggml_set_op_params_i32(result, 2, shift2);
|
||||
ggml_set_op_params_i32(result, 3, shift3);
|
||||
|
||||
result->op = GGML_OP_ROLL;
|
||||
result->src[0] = a;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_arange
|
||||
|
||||
struct ggml_tensor * ggml_arange(
|
||||
|
||||
@@ -198,6 +198,7 @@ class Keys:
|
||||
MASK_ID = "tokenizer.ggml.mask_token_id"
|
||||
ADD_BOS = "tokenizer.ggml.add_bos_token"
|
||||
ADD_EOS = "tokenizer.ggml.add_eos_token"
|
||||
ADD_SEP = "tokenizer.ggml.add_sep_token"
|
||||
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
|
||||
REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
|
||||
PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"
|
||||
@@ -291,6 +292,7 @@ class MODEL_ARCH(IntEnum):
|
||||
BERT = auto()
|
||||
NOMIC_BERT = auto()
|
||||
NOMIC_BERT_MOE = auto()
|
||||
NEO_BERT = auto()
|
||||
JINA_BERT_V2 = auto()
|
||||
BLOOM = auto()
|
||||
STABLELM = auto()
|
||||
@@ -344,6 +346,7 @@ class MODEL_ARCH(IntEnum):
|
||||
PLM = auto()
|
||||
BAILINGMOE = auto()
|
||||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
@@ -572,6 +575,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.BERT: "bert",
|
||||
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
|
||||
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
|
||||
MODEL_ARCH.NEO_BERT: "neo-bert",
|
||||
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
|
||||
MODEL_ARCH.BLOOM: "bloom",
|
||||
MODEL_ARCH.STABLELM: "stablelm",
|
||||
@@ -624,7 +628,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
|
||||
MODEL_ARCH.PLM: "plm",
|
||||
MODEL_ARCH.BAILINGMOE: "bailingmoe",
|
||||
MODEL_ARCH.DOTS1: "dots1"
|
||||
MODEL_ARCH.DOTS1: "dots1",
|
||||
MODEL_ARCH.ARCEE: "arcee",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
@@ -1079,6 +1084,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.NEO_BERT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM,
|
||||
MODEL_TENSOR.CLS,
|
||||
MODEL_TENSOR.CLS_OUT,
|
||||
],
|
||||
MODEL_ARCH.JINA_BERT_V2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
@@ -2070,6 +2087,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.ARCEE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
||||
@@ -271,7 +271,7 @@ class GGUFWriter:
|
||||
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
|
||||
if any(key in kv_data for kv_data in self.kv_data):
|
||||
raise ValueError(f'Duplicated key name {key!r}')
|
||||
logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}')
|
||||
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
|
||||
|
||||
@@ -891,6 +891,9 @@ class GGUFWriter:
|
||||
def add_add_eos_token(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.ADD_EOS, value)
|
||||
|
||||
def add_add_sep_token(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.ADD_SEP, value)
|
||||
|
||||
def add_add_space_prefix(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ class TensorNameMap:
|
||||
"model.embeddings", # rwkv7
|
||||
"model.word_embeddings", # bailingmoe
|
||||
"language_model.model.embed_tokens", # llama4
|
||||
"encoder", # neobert
|
||||
),
|
||||
|
||||
# Token type embeddings
|
||||
@@ -134,6 +135,7 @@ class TensorNameMap:
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv6
|
||||
"model.layers.{bid}.ln1", # rwkv7
|
||||
"model.layers.{bid}.input_layernorm", # llama4
|
||||
"transformer_encoder.{bid}.attention_norm", # neobert
|
||||
),
|
||||
|
||||
# Attention norm 2
|
||||
@@ -161,6 +163,7 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.self_attn.qkv_proj", # phi3
|
||||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
"transformer_encoder.{bid}.qkv", # neobert
|
||||
),
|
||||
|
||||
# Attention query
|
||||
@@ -236,6 +239,7 @@ class TensorNameMap:
|
||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||
"transformer.h.{bid}.attn.attention.out_proj", # exaone
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama4
|
||||
"transformer_encoder.{bid}.wo", # neobert
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
@@ -276,6 +280,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama4
|
||||
"transformer_encoder.{bid}.ffn_norm", # neobert
|
||||
),
|
||||
|
||||
# Post feed-forward norm
|
||||
@@ -340,6 +345,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
"model.layers.{bid}.feed_forward.up_proj", # llama4
|
||||
"transformer_encoder.{bid}.ffn.w12", # neobert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
@@ -422,6 +428,7 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
"model.layers.{bid}.feed_forward.down_proj", # llama4
|
||||
"transformer_encoder.{bid}.ffn.w3", # neobert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
@@ -832,12 +839,14 @@ class TensorNameMap:
|
||||
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM: (
|
||||
"encoder.final_layer_norm", # t5
|
||||
"layer_norm", # neobert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CLS: (
|
||||
"classifier", # jina
|
||||
"classifier.dense", # roberta
|
||||
"pre_classifier", # distillbert
|
||||
"dense", # neobert
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CLS_OUT: (
|
||||
|
||||
@@ -7,7 +7,10 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable
|
||||
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
try:
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
except ImportError:
|
||||
SentencePieceProcessor = None
|
||||
|
||||
import gguf
|
||||
|
||||
@@ -116,6 +119,7 @@ class SpecialVocab:
|
||||
logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
|
||||
|
||||
def _try_load_from_tokenizer_json(self, path: Path) -> bool:
|
||||
tokenizer = None
|
||||
tokenizer_file = path / 'tokenizer.json'
|
||||
if tokenizer_file.is_file():
|
||||
with open(tokenizer_file, encoding = 'utf-8') as f:
|
||||
@@ -149,11 +153,97 @@ class SpecialVocab:
|
||||
added_tokens = tokenizer.get('added_tokens', {})
|
||||
else:
|
||||
added_tokens = {}
|
||||
tokenizer_config = None
|
||||
tokenizer_config_file = path / 'tokenizer_config.json'
|
||||
if not tokenizer_config_file.is_file():
|
||||
if tokenizer_config_file.is_file():
|
||||
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
||||
tokenizer_config = json.load(f)
|
||||
if tokenizer:
|
||||
special_bos = (tokenizer_config or {}).get('bos_token')
|
||||
special_cls = (tokenizer_config or {}).get('cls_token')
|
||||
special_eos = (tokenizer_config or {}).get('eos_token')
|
||||
special_sep = (tokenizer_config or {}).get('sep_token')
|
||||
if not special_bos and special_cls and tokenizer_config:
|
||||
tokenizer_config['bos_token'] = special_bos = special_cls
|
||||
if not special_eos and special_sep and tokenizer_config:
|
||||
tokenizer_config['eos_token'] = special_eos = special_sep
|
||||
if post_processor := tokenizer.get('post_processor'):
|
||||
for processor in post_processor.get('processors', [post_processor]):
|
||||
if processor.get('type') == 'RobertaProcessing':
|
||||
self.add_special_token['bos'] = True
|
||||
self.add_special_token['eos'] = True
|
||||
self.add_special_token['sep'] = True
|
||||
if not special_cls and tokenizer_config:
|
||||
special_cls = processor.get('cls', [special_bos])[0]
|
||||
tokenizer_config['cls_token'] = special_cls
|
||||
if not special_sep and tokenizer_config:
|
||||
special_sep = processor.get('sep', [special_eos])[0]
|
||||
tokenizer_config['sep_token'] = special_sep
|
||||
continue
|
||||
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
|
||||
# Only works with simple templates, **will** get it wrong on unusual sequences
|
||||
if processor.get('type') == 'TemplateProcessing':
|
||||
tmpl_single = processor.get('single', [])
|
||||
tmpl_pair = processor.get('pair', [])
|
||||
special_first = None
|
||||
special_last = None
|
||||
if len(tmpl_single) > 1:
|
||||
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
|
||||
if not tokenizer_config:
|
||||
special_bos = special_first
|
||||
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
|
||||
if special_first not in (special_bos, special_cls):
|
||||
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
|
||||
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
|
||||
if not tokenizer_config:
|
||||
special_eos = special_last
|
||||
elif special_last != special_eos:
|
||||
if 'eot' not in self.special_token_types:
|
||||
self.special_token_types = tuple(self.special_token_types) + ('eot', )
|
||||
tokenizer_config['eot_token'] = special_eos
|
||||
elif 'eom' not in self.special_token_types:
|
||||
self.special_token_types = tuple(self.special_token_types) + ('eom', )
|
||||
tokenizer_config['eom_token'] = special_eos
|
||||
else:
|
||||
logger.warning(f'Overriding EOS token {special_eos!r} with {special_last!r} without EOT/EOM fallback!')
|
||||
tokenizer_config['eos_token'] = special_eos = special_last
|
||||
self.add_special_token['eos'] = True if special_last == special_eos else False
|
||||
if special_last != special_eos:
|
||||
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
|
||||
if tmpl_pair:
|
||||
seq_start = 1 if special_first and tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
|
||||
seq_stop = -1 if special_last and tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
|
||||
if (special_first and seq_start == 0) or (special_last and seq_stop is None):
|
||||
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
|
||||
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
|
||||
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
|
||||
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
|
||||
if tmpl_a != 'A' or tmpl_b != 'B':
|
||||
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
|
||||
# A [sep] [eos] B
|
||||
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
|
||||
add_sep = False
|
||||
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
|
||||
if special_entry in (special_sep, special_eos) and not special_last:
|
||||
add_sep = True
|
||||
if special_entry not in (special_sep, special_eos):
|
||||
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
|
||||
else:
|
||||
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
|
||||
if len(tmpl_pair) == 2:
|
||||
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
|
||||
if special_entry in (special_sep, special_eos):
|
||||
add_sep = True
|
||||
if special_entry not in (special_sep, special_eos):
|
||||
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
|
||||
else:
|
||||
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
|
||||
self.add_special_token['sep'] = add_sep
|
||||
if add_sep and not special_sep and tokenizer_config:
|
||||
tokenizer_config['sep_token'] = special_eos
|
||||
continue
|
||||
if not tokenizer_config:
|
||||
return True
|
||||
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
||||
tokenizer_config = json.load(f)
|
||||
chat_template_alt = None
|
||||
chat_template_file = path / 'chat_template.json'
|
||||
if chat_template_file.is_file():
|
||||
@@ -302,6 +392,9 @@ class SentencePieceVocab(Vocab):
|
||||
name = "spm"
|
||||
|
||||
def __init__(self, base_path: Path):
|
||||
if SentencePieceProcessor is None:
|
||||
raise RuntimeError("sentencepiece is not installed")
|
||||
|
||||
added_tokens: dict[str, int] = {}
|
||||
if (fname_tokenizer := base_path / 'tokenizer.model').exists():
|
||||
# normal location
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "gguf"
|
||||
version = "0.17.0"
|
||||
version = "0.17.1"
|
||||
description = "Read and write ML models in GGUF for GGML"
|
||||
authors = ["GGML <ggml@ggml.ai>"]
|
||||
packages = [
|
||||
@@ -22,7 +22,7 @@ python = ">=3.8"
|
||||
numpy = ">=1.17"
|
||||
tqdm = ">=4.27"
|
||||
pyyaml = ">=5.1"
|
||||
sentencepiece = ">=0.1.98,<=0.2.0"
|
||||
sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true }
|
||||
PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true }
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
|
||||
@@ -254,7 +254,10 @@ extern "C" {
|
||||
// - seq_id : the sequence to which the respective token belongs
|
||||
// (if set to NULL, the sequence ID will be assumed to be 0)
|
||||
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||
// (if set to NULL, only the logits for last token will be returned)
|
||||
// (if set to NULL:
|
||||
// - if embeddings: all tokens are output
|
||||
// - if not: only the last token is output
|
||||
// )
|
||||
//
|
||||
typedef struct llama_batch {
|
||||
int32_t n_tokens;
|
||||
@@ -262,8 +265,8 @@ extern "C" {
|
||||
llama_token * token;
|
||||
float * embd;
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
|
||||
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
} llama_batch;
|
||||
|
||||
@@ -387,6 +390,7 @@ extern "C" {
|
||||
void * imatrix; // pointer to importance matrix data
|
||||
void * kv_overrides; // pointer to vector containing overrides
|
||||
void * tensor_types; // pointer to vector containing tensor types
|
||||
void * prune_layers; // pointer to vector containing layer indices to prune
|
||||
} llama_model_quantize_params;
|
||||
|
||||
typedef struct llama_logit_bias {
|
||||
@@ -940,12 +944,14 @@ extern "C" {
|
||||
// Requires the context to have a memory.
|
||||
// For encode-decoder contexts, processes the batch using the decoder.
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// Upon non-zero return values, the memory state is restored to the state before this call
|
||||
// Upon fatal-error or abort, the ubatches that managed to be been processed will remain in the memory state of the context
|
||||
// To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
|
||||
// Upon other return values, the memory state is restored to the state before this call
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// 2 - aborted
|
||||
// 2 - aborted (processed ubatches will remain in the context's memory)
|
||||
// -1 - invalid input batch
|
||||
// < -1 - error
|
||||
// < -1 - fatal error (processed ubatches will remain in the context's memory)
|
||||
LLAMA_API int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch);
|
||||
@@ -961,8 +967,8 @@ extern "C" {
|
||||
// Get the number of threads used for prompt and batch processing (multiple token).
|
||||
LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
|
||||
|
||||
// Set whether the model is in embeddings mode or not
|
||||
// If true, embeddings will be returned but logits will not
|
||||
// Set whether the context outputs embeddings or not
|
||||
// TODO: rename to avoid confusion with llama_get_embeddings()
|
||||
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
|
||||
|
||||
// Set whether to use causal attention or not
|
||||
@@ -1041,6 +1047,7 @@ extern "C" {
|
||||
|
||||
LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
|
||||
LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
|
||||
LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
|
||||
|
||||
LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
|
||||
LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
|
||||
@@ -1084,6 +1091,7 @@ extern "C" {
|
||||
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
|
||||
/// @return Returns the number of tokens on success, no more than n_tokens_max
|
||||
/// @return Returns a negative number on failure - the number of tokens that would have been returned
|
||||
/// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
|
||||
/// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
|
||||
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
|
||||
/// as plaintext. Does not insert a leading space.
|
||||
|
||||
@@ -1 +1 @@
|
||||
6a7d170c04789f6ebcf320ed03c1b16973f93bd7
|
||||
9e4bee1c5afc2d677a5b32ecb90cbdb483e81fff
|
||||
|
||||
@@ -22,8 +22,9 @@ add_library(llama
|
||||
llama-io.cpp
|
||||
llama-kv-cache-unified.cpp
|
||||
llama-kv-cache-unified-iswa.cpp
|
||||
llama-kv-cache-recurrent.cpp
|
||||
llama-memory.cpp
|
||||
llama-memory-hybrid.cpp
|
||||
llama-memory-recurrent.cpp
|
||||
llama-mmap.cpp
|
||||
llama-model-loader.cpp
|
||||
llama-model-saver.cpp
|
||||
|
||||
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_BERT, "bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
|
||||
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
|
||||
{ LLM_ARCH_NEO_BERT, "neo-bert" },
|
||||
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
|
||||
{ LLM_ARCH_BLOOM, "bloom" },
|
||||
{ LLM_ARCH_STABLELM, "stablelm" },
|
||||
@@ -73,6 +74,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_PLM, "plm" },
|
||||
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
||||
{ LLM_ARCH_DOTS1, "dots1" },
|
||||
{ LLM_ARCH_ARCEE, "arcee" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@@ -145,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
{ LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
|
||||
@@ -195,6 +198,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
|
||||
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
|
||||
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
|
||||
{ LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
|
||||
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
|
||||
@@ -244,6 +248,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ARCEE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LLAMA4,
|
||||
{
|
||||
@@ -495,6 +517,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_NEO_BERT,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
|
||||
{ LLM_TENSOR_CLS, "cls" },
|
||||
{ LLM_TENSOR_CLS_OUT, "cls.output" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
{
|
||||
@@ -1781,3 +1818,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
|
||||
return LLM_TENSOR_INFOS.at(tensor);
|
||||
}
|
||||
|
||||
bool llm_arch_is_recurrent(const llm_arch & arch) {
|
||||
switch (arch) {
|
||||
case LLM_ARCH_MAMBA:
|
||||
case LLM_ARCH_RWKV6:
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_arch_is_hybrid(const llm_arch & arch) {
|
||||
// TODO: There are currently no hybrid models! Once there are, this will be
|
||||
// the place to identify them
|
||||
switch (arch) {
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ enum llm_arch {
|
||||
LLM_ARCH_BERT,
|
||||
LLM_ARCH_NOMIC_BERT,
|
||||
LLM_ARCH_NOMIC_BERT_MOE,
|
||||
LLM_ARCH_NEO_BERT,
|
||||
LLM_ARCH_JINA_BERT_V2,
|
||||
LLM_ARCH_BLOOM,
|
||||
LLM_ARCH_STABLELM,
|
||||
@@ -77,6 +78,7 @@ enum llm_arch {
|
||||
LLM_ARCH_PLM,
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
LLM_ARCH_DOTS1,
|
||||
LLM_ARCH_ARCEE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@@ -149,6 +151,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_LAYER_INDICES,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_DIMENSION_SECTIONS,
|
||||
@@ -191,6 +194,7 @@ enum llm_kv {
|
||||
LLM_KV_TOKENIZER_MASK_ID,
|
||||
LLM_KV_TOKENIZER_ADD_BOS,
|
||||
LLM_KV_TOKENIZER_ADD_EOS,
|
||||
LLM_KV_TOKENIZER_ADD_SEP,
|
||||
LLM_KV_TOKENIZER_ADD_PREFIX,
|
||||
LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
|
||||
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
|
||||
@@ -437,3 +441,6 @@ const char * llm_arch_name(llm_arch arch);
|
||||
llm_arch llm_arch_from_string(const std::string & name);
|
||||
|
||||
const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
|
||||
|
||||
bool llm_arch_is_recurrent(const llm_arch & arch);
|
||||
bool llm_arch_is_hybrid (const llm_arch & arch);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,106 +2,102 @@
|
||||
|
||||
#include "llama.h"
|
||||
|
||||
#include "llama-cparams.h"
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <bitset>
|
||||
#include <unordered_map>
|
||||
|
||||
// very similar to llama_batch,
|
||||
// but has more metadata about sequences
|
||||
// keep this struct lightweight
|
||||
// it points to data in `llama_batch_allocr`
|
||||
struct llama_ubatch {
|
||||
bool equal_seqs;
|
||||
// TODO: whole_seqs for embeddings?
|
||||
|
||||
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
||||
uint32_t n_seq_tokens; // tokens per sequence
|
||||
uint32_t n_seqs;
|
||||
uint32_t n_seq_tokens; // tokens per sequence set
|
||||
uint32_t n_seqs; // sequence sets in the ubatch
|
||||
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
|
||||
|
||||
llama_token * token; // [n_tokens]
|
||||
float * embd; // [n_embd, n_tokens]
|
||||
llama_pos * pos; // [n_tokens]
|
||||
int32_t * n_seq_id; // [n_seqs]
|
||||
llama_seq_id ** seq_id; // [n_seqs]
|
||||
int8_t * output; // [n_tokens]
|
||||
// seq_id_unq: unique sequence ids in the ubatch
|
||||
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
|
||||
// used for extracting sequence pooled embeddings
|
||||
|
||||
// // size | idx | val
|
||||
llama_token * token; // [n_tokens] | i | id, token
|
||||
float * embd; // [n_embd, n_tokens] | i | embd
|
||||
llama_pos * pos; // [n_tokens] | i | pos
|
||||
int32_t * n_seq_id; // [n_tokens] | i | -
|
||||
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
|
||||
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
|
||||
int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
|
||||
int8_t * output; // [n_tokens] | i | -
|
||||
};
|
||||
|
||||
struct llama_sbatch_seq {
|
||||
int32_t n_seq_id;
|
||||
|
||||
llama_seq_id * seq_id;
|
||||
|
||||
size_t offset;
|
||||
size_t length;
|
||||
};
|
||||
|
||||
// sequence-length-aware batch splitting
|
||||
struct llama_sbatch {
|
||||
// tokens left in this batch
|
||||
size_t n_tokens;
|
||||
|
||||
size_t n_embd;
|
||||
|
||||
// sorted indices into the batch
|
||||
std::vector<int64_t> ids;
|
||||
// batch indices of the output
|
||||
std::vector<int64_t> out_ids;
|
||||
std::vector<llama_sbatch_seq> seq;
|
||||
|
||||
const llama_batch * batch = nullptr;
|
||||
|
||||
// buffers for the ubatches
|
||||
// TODO: very hacky, this needs a complete rework
|
||||
struct ubatch_data {
|
||||
std::vector<llama_token> token;
|
||||
std::vector<float> embd;
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<int8_t> output;
|
||||
};
|
||||
|
||||
std::vector<ubatch_data> udatas;
|
||||
|
||||
llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
|
||||
|
||||
void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
|
||||
|
||||
// simple split, unknown number of sequences of unequal lengths
|
||||
llama_ubatch split_simple(size_t n_ubatch);
|
||||
|
||||
// make batches of equal-length sequences
|
||||
llama_ubatch split_equal(size_t n_ubatch);
|
||||
|
||||
// sequence-wise split
|
||||
llama_ubatch split_seq(size_t n_ubatch);
|
||||
|
||||
llama_sbatch() = default;
|
||||
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
|
||||
};
|
||||
|
||||
// a helper for sanitizing and fulfilling a batch
|
||||
// a helper for sanitizing, fulfilling and splitting a batch
|
||||
class llama_batch_allocr {
|
||||
public:
|
||||
llama_batch_allocr();
|
||||
llama_batch_allocr(uint32_t n_pos_per_embd);
|
||||
|
||||
// sanitize and auto-gen missing data in the input batch
|
||||
// memory is optional. if provided will be used to check for sequence continuity and to determine the positions
|
||||
bool init(
|
||||
const llama_batch & batch_inp,
|
||||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory);
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
bool output_all);
|
||||
|
||||
const llama_batch & get_batch() const;
|
||||
|
||||
uint32_t get_n_tokens() const;
|
||||
uint32_t get_n_outputs() const;
|
||||
|
||||
// the array of output indices in the order they were encountered during the ubatch splitting
|
||||
std::vector<int32_t> & get_out_ids();
|
||||
|
||||
// min/max positions of each sequence in the current ubatch
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const;
|
||||
|
||||
// call once before splitting the batch to reset the internal state
|
||||
void split_reset();
|
||||
|
||||
// simple split, unknown number of sequence sets of unequal lengths
|
||||
llama_ubatch split_simple(uint32_t n_ubatch);
|
||||
|
||||
// make ubatches of equal-length sequences sets
|
||||
llama_ubatch split_equal(uint32_t n_ubatch);
|
||||
|
||||
// sequence-set-wise split - each ubatch contains a single sequence-set
|
||||
llama_ubatch split_seq(uint32_t n_ubatch);
|
||||
|
||||
// a helper method for creating a well-defined ubatch of tokens
|
||||
// TODO: support embeddings if needed in the future
|
||||
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
|
||||
|
||||
private:
|
||||
void clear();
|
||||
|
||||
// create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
|
||||
// return llama_ubatch.n_tokens == 0 if the entire batch was consumed
|
||||
llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
|
||||
|
||||
// for debugging, start with LLAMA_BATCH_DEBUG=2
|
||||
void ubatch_print(const llama_ubatch & ubatch, int debug);
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
// only for debugging purposes
|
||||
const llama_vocab * vocab;
|
||||
|
||||
// TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
|
||||
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
|
||||
const uint32_t n_pos_per_embd;
|
||||
|
||||
uint32_t n_embd;
|
||||
uint32_t n_outputs;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
@@ -109,10 +105,43 @@ private:
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<llama_seq_id> seq_id_unq;
|
||||
std::vector<int32_t> seq_idx;
|
||||
std::vector<int8_t> output;
|
||||
|
||||
std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
||||
std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
||||
using pos_set_t = std::set<llama_pos>;
|
||||
using seq_cpl_t = std::vector<bool>;
|
||||
|
||||
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
|
||||
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
|
||||
|
||||
using idx_vec_t = std::vector<int32_t>;
|
||||
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
|
||||
|
||||
std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
|
||||
|
||||
std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
|
||||
|
||||
// batch indices of the output
|
||||
std::vector<int32_t> out_ids;
|
||||
|
||||
// used[i] indicates if token i has already been used in a previous ubatch
|
||||
std::vector<bool> used;
|
||||
|
||||
// llama_ubatch points to this data:
|
||||
struct ubatch {
|
||||
std::vector<llama_token> token;
|
||||
std::vector<float> embd;
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id;
|
||||
std::vector<llama_seq_id> seq_id_unq;
|
||||
std::vector<int32_t> seq_idx;
|
||||
std::vector<int8_t> output;
|
||||
};
|
||||
|
||||
// current splitting state:
|
||||
std::vector<ubatch> ubatches;
|
||||
|
||||
int debug;
|
||||
};
|
||||
|
||||
@@ -333,7 +333,7 @@ int32_t llm_chat_apply_template(
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
// there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
|
||||
system_prompt = trim(message->content);
|
||||
system_prompt += trim(message->content);
|
||||
continue;
|
||||
}
|
||||
// in gemma, "assistant" is "model"
|
||||
@@ -355,7 +355,7 @@ int32_t llm_chat_apply_template(
|
||||
std::string role(message->role);
|
||||
if (role == "system") {
|
||||
// there is no system message support, we will merge it with user prompt
|
||||
system_prompt = message->content;
|
||||
system_prompt += message->content;
|
||||
continue;
|
||||
} else if (role == "user") {
|
||||
ss << "Human: ";
|
||||
|
||||
@@ -20,7 +20,7 @@ llama_context::llama_context(
|
||||
const llama_model & model,
|
||||
llama_context_params params) :
|
||||
model(model),
|
||||
batch_allocr(std::make_unique<llama_batch_allocr>()) {
|
||||
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
||||
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
|
||||
|
||||
t_start_us = model.t_start_us;
|
||||
@@ -280,8 +280,8 @@ llama_context::llama_context(
|
||||
|
||||
// simulate full KV cache
|
||||
|
||||
const auto mstate = memory->init_full();
|
||||
if (!mstate) {
|
||||
const auto mctx = memory->init_full();
|
||||
if (!mctx) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
}
|
||||
|
||||
@@ -289,7 +289,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -300,7 +300,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
auto * gf = graph_reserve(1, 1, 1, mstate.get());
|
||||
auto * gf = graph_reserve(1, 1, 1, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
@@ -311,7 +311,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||
optimize |= memory_force_optimize;
|
||||
memory_force_optimize = false;
|
||||
|
||||
const auto mstate = memory->init_update(this, optimize);
|
||||
switch (mstate->get_status()) {
|
||||
const auto mctx = memory->init_update(this, optimize);
|
||||
switch (mctx->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
// noop
|
||||
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!mstate->apply()) {
|
||||
if (!mctx->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// if the memory module did any computation, we have to reserve a new worst-case graph
|
||||
{
|
||||
const auto mstate = memory->init_full();
|
||||
if (!mstate) {
|
||||
throw std::runtime_error("failed to initialize memory state");
|
||||
const auto mctx = memory->init_full();
|
||||
if (!mctx) {
|
||||
throw std::runtime_error("failed to initialize memory context");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
||||
}
|
||||
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
|
||||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
|
||||
if (mstate && !mstate->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
|
||||
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||
if (mctx && !mctx->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
@@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
|
||||
}
|
||||
|
||||
int llama_context::encode(const llama_batch & batch_inp) {
|
||||
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
||||
|
||||
if (batch_inp.n_tokens == 0) {
|
||||
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
// note: during encode, we always pass the full sequence starting from pos = 0
|
||||
if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
|
||||
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
const llama_batch & batch = batch_allocr->get_batch();
|
||||
const uint32_t n_tokens = balloc->get_n_tokens();
|
||||
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
|
||||
|
||||
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
||||
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
|
||||
@@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
|
||||
n_queued_tokens += n_tokens;
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
|
||||
|
||||
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
||||
|
||||
// reserve output buffer
|
||||
if (output_reserve(n_tokens) < n_tokens) {
|
||||
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
|
||||
@@ -817,34 +813,28 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
{
|
||||
// extract sequence embeddings
|
||||
auto & embd_seq_out = embd_seq;
|
||||
embd_seq_out.clear();
|
||||
|
||||
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract the rerank score - n_cls_out floats per sequence
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
const uint32_t n_cls_out = hparams.n_cls_out;
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_cls_out);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
@@ -869,12 +859,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
cross.v_embd.resize(cross.n_embd*cross.n_enc);
|
||||
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
|
||||
|
||||
const auto & batch = balloc->get_batch();
|
||||
|
||||
// remember the sequence ids used during the encoding - needed for cross attention later
|
||||
cross.seq_ids_enc.resize(n_tokens);
|
||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||
cross.seq_ids_enc[i].clear();
|
||||
|
||||
for (int s = 0; s < batch.n_seq_id[i]; s++) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][s];
|
||||
const llama_seq_id seq_id = batch.seq_id[i][s];
|
||||
|
||||
cross.seq_ids_enc[i].insert(seq_id);
|
||||
}
|
||||
}
|
||||
@@ -884,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
}
|
||||
|
||||
int llama_context::decode(const llama_batch & batch_inp) {
|
||||
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
||||
|
||||
if (!memory) {
|
||||
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
||||
return encode(batch_inp);
|
||||
@@ -894,29 +890,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
const llama_batch & batch = batch_allocr->get_batch();
|
||||
|
||||
const auto & vocab = model.vocab;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int32_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
const uint32_t n_tokens_all = batch.n_tokens;
|
||||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// this indicates we are doing pooled embedding
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
const uint32_t n_tokens_all = balloc->get_n_tokens();
|
||||
const uint32_t n_outputs_all = balloc->get_n_outputs();
|
||||
|
||||
const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
|
||||
|
||||
if (embd_pooled) {
|
||||
if (output_all) {
|
||||
// require that all tokens are output
|
||||
if (n_outputs_all != n_tokens_all) {
|
||||
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
|
||||
@@ -942,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update(false);
|
||||
|
||||
llama_memory_state_ptr mstate;
|
||||
llama_memory_context_ptr mctx;
|
||||
|
||||
while (true) {
|
||||
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
|
||||
if (!mstate) {
|
||||
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
||||
if (!mctx) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
switch (mstate->get_status()) {
|
||||
switch (mctx->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||
{
|
||||
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
|
||||
LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
|
||||
|
||||
return -2;
|
||||
}
|
||||
@@ -966,19 +957,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
did_optimize = true;
|
||||
|
||||
if (kv_self_update(true)) {
|
||||
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
|
||||
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
|
||||
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
|
||||
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
|
||||
|
||||
return 1;
|
||||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||
{
|
||||
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
|
||||
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
|
||||
|
||||
return -2;
|
||||
}
|
||||
@@ -996,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
int64_t n_outputs_prev = 0;
|
||||
|
||||
do {
|
||||
const auto & ubatch = mstate->get_ubatch();
|
||||
const auto & ubatch = mctx->get_ubatch();
|
||||
|
||||
// count the outputs in this ubatch
|
||||
{
|
||||
@@ -1005,7 +996,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
if (n_outputs_all == n_tokens_all) {
|
||||
n_outputs_new = ubatch.n_tokens;
|
||||
} else {
|
||||
GGML_ASSERT(ubatch.output);
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
|
||||
}
|
||||
@@ -1019,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||
@@ -1028,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
pos_min[s] = std::numeric_limits<llama_pos>::max();
|
||||
}
|
||||
|
||||
// TODO: fix sequence indexing
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
const auto & seq_id = ubatch.seq_id[i][0];
|
||||
|
||||
@@ -1058,7 +1047,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
|
||||
if (t_embd && res->get_embd_pooled()) {
|
||||
@@ -1105,27 +1094,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
// extract sequence embeddings (cleared before processing each batch)
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract the rerank score - a single float per sequence
|
||||
// extract the rerank score - n_cls_out floats per sequence
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
||||
continue;
|
||||
}
|
||||
embd_seq_out[seq_id].resize(1);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
||||
const uint32_t n_cls_out = hparams.n_cls_out;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_cls_out);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
@@ -1136,7 +1125,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
}
|
||||
|
||||
n_outputs_prev += n_outputs;
|
||||
} while (mstate->next());
|
||||
} while (mctx->next());
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
n_outputs = n_outputs_all;
|
||||
@@ -1145,7 +1134,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
if (n_outputs > 0) {
|
||||
bool sorted_output = true;
|
||||
|
||||
auto & out_ids = mstate->out_ids();
|
||||
auto & out_ids = balloc->get_out_ids();
|
||||
|
||||
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
|
||||
|
||||
@@ -1222,9 +1211,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
const auto n_vocab = vocab.n_tokens();
|
||||
const auto n_embd = hparams.n_embd;
|
||||
|
||||
// TODO: use a per-batch flag for logits presence instead
|
||||
bool has_logits = !cparams.embeddings;
|
||||
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
|
||||
bool has_logits = true;
|
||||
bool has_embd = cparams.embeddings;
|
||||
|
||||
// TODO: hacky enc-dec support
|
||||
if (model.arch == LLM_ARCH_T5) {
|
||||
@@ -1303,7 +1291,7 @@ ggml_cgraph * llama_context::graph_init() {
|
||||
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
|
||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
|
||||
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
||||
if (n_tokens % n_seqs != 0) {
|
||||
@@ -1319,11 +1307,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||
|
||||
this->n_outputs = n_outputs;
|
||||
|
||||
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
|
||||
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
|
||||
|
||||
this->n_outputs = save_n_outputs;
|
||||
|
||||
@@ -1344,11 +1332,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_context::graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_state_i * mstate) {
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_context_i * mctx) {
|
||||
return model.build_graph(
|
||||
{
|
||||
/*.ctx =*/ ctx,
|
||||
@@ -1360,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
|
||||
/*.backend_cpu =*/ backend_cpu,
|
||||
/*.cvec =*/ &cvec,
|
||||
/*.loras =*/ &loras,
|
||||
/*.mstate =*/ mstate,
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
/*.cb =*/ graph_get_cb(),
|
||||
@@ -2040,19 +2028,21 @@ void llama_context::opt_epoch_iter(
|
||||
batch.logits [pos_batch] = true;
|
||||
}
|
||||
|
||||
const auto n_tokens_all = batch.n_tokens;
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
const uint32_t n_tokens_all = balloc->get_n_tokens();
|
||||
|
||||
n_queued_tokens += n_tokens_all;
|
||||
|
||||
// this indicates we are doing pooled embedding
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
embd_seq.clear();
|
||||
|
||||
uint32_t n_outputs_all = n_tokens_all;
|
||||
|
||||
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
|
||||
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
|
||||
if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
||||
break;
|
||||
}
|
||||
@@ -2065,17 +2055,17 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
uint32_t pos_batch = 0;
|
||||
do {
|
||||
const auto & ubatch = mstate->get_ubatch();
|
||||
const auto & ubatch = mctx->get_ubatch();
|
||||
|
||||
n_outputs = ubatch.n_tokens;
|
||||
|
||||
if (!mstate->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
||||
if (!mctx->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
@@ -2110,7 +2100,7 @@ void llama_context::opt_epoch_iter(
|
||||
ggml_free(ctx_compute_opt);
|
||||
|
||||
pos_batch += ubatch.n_tokens;
|
||||
} while (mstate->next());
|
||||
} while (mctx->next());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class llama_io_read_i;
|
||||
class llama_io_write_i;
|
||||
|
||||
struct llama_memory_i;
|
||||
struct llama_memory_state_i;
|
||||
struct llama_memory_context_i;
|
||||
|
||||
struct llama_context {
|
||||
// init scheduler and compute buffers, reserve worst-case graphs
|
||||
@@ -93,14 +93,14 @@ struct llama_context {
|
||||
int32_t il_end);
|
||||
|
||||
// process a single ubatch with a specific graph type
|
||||
// if memory_state is provided, it will be applied first to the context's memory
|
||||
// if memory_context is provided, it will be applied first to the context's memory
|
||||
// ret contains the status of the graph computation
|
||||
// returns nullptr only if ret != GGML_STATUS_SUCCESS
|
||||
llm_graph_result_ptr process_ubatch(
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
llama_memory_state_i * mstate,
|
||||
ggml_status & ret);
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
llama_memory_context_i * mctx,
|
||||
ggml_status & ret);
|
||||
|
||||
int encode(const llama_batch & batch_inp);
|
||||
int decode(const llama_batch & batch_inp);
|
||||
@@ -197,15 +197,15 @@ public:
|
||||
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
|
||||
|
||||
// reserve a graph with a dummy ubatch of the specified size
|
||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
|
||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
|
||||
|
||||
private:
|
||||
llm_graph_result_ptr graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_state_i * mstate);
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_context_i * mctx);
|
||||
|
||||
llm_graph_cb graph_get_cb() const;
|
||||
|
||||
@@ -247,7 +247,7 @@ private:
|
||||
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
||||
|
||||
// reuse the batch_allocr to avoid unnecessary memory allocations
|
||||
std::unique_ptr<llama_batch_allocr> batch_allocr;
|
||||
std::unique_ptr<llama_batch_allocr> balloc;
|
||||
|
||||
uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
#include "llama-kv-cache-unified.h"
|
||||
#include "llama-kv-cache-unified-iswa.h"
|
||||
#include "llama-kv-cache-recurrent.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
@@ -86,41 +87,33 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
||||
if (pos_bucket) {
|
||||
kv_state->set_input_pos_bucket(pos_bucket, ubatch);
|
||||
mctx->set_input_pos_bucket(pos_bucket, ubatch);
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
|
||||
if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
//GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
|
||||
GGML_ASSERT(out_ids);
|
||||
|
||||
if (!out_ids) {
|
||||
LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
|
||||
} else {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
|
||||
int32_t * data = (int32_t *) out_ids->data;
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
|
||||
int32_t * data = (int32_t *) out_ids->data;
|
||||
|
||||
if (n_outputs == n_tokens) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
data[i] = i;
|
||||
}
|
||||
} else if (ubatch->output) {
|
||||
int32_t n_outputs = 0;
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
if (ubatch->output[i]) {
|
||||
data[n_outputs++] = i;
|
||||
}
|
||||
}
|
||||
// the graph needs to have been passed the correct number of outputs
|
||||
GGML_ASSERT(n_outputs == n_outputs);
|
||||
} else if (n_outputs == 1) {
|
||||
// only keep last output
|
||||
data[0] = n_tokens - 1;
|
||||
} else {
|
||||
GGML_ASSERT(n_outputs == 0);
|
||||
}
|
||||
if (n_outputs == n_tokens) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
data[i] = i;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(ubatch->output);
|
||||
|
||||
int n_outputs = 0;
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
if (ubatch->output[i]) {
|
||||
data[n_outputs++] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -129,127 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
|
||||
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
||||
|
||||
GGML_ASSERT(mean);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
|
||||
|
||||
float * data = (float *) mean->data;
|
||||
memset(mean->data, 0, n_tokens * n_tokens * ggml_element_size(mean));
|
||||
memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
|
||||
|
||||
std::vector<uint64_t> sum(n_tokens, 0);
|
||||
std::vector<uint64_t> sums(n_seqs_unq, 0);
|
||||
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
|
||||
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
|
||||
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
|
||||
|
||||
sum[seq_id] += ubatch->n_seq_tokens;
|
||||
}
|
||||
|
||||
std::vector<float> div(n_tokens, 0.0f);
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const uint64_t s = sum[i];
|
||||
if (s > 0) {
|
||||
div[i] = 1.0f/float(s);
|
||||
sums[seq_idx] += ubatch->n_seq_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
std::vector<float> div(n_seqs_unq, 0.0f);
|
||||
for (int s = 0; s < n_seqs_unq; ++s) {
|
||||
const uint64_t sum = sums[s];
|
||||
if (sum > 0) {
|
||||
div[s] = 1.0f/float(sum);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
|
||||
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
data[seq_idx*n_tokens + i + j] = div[seq_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
||||
if (cparams.embeddings && (
|
||||
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
||||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
||||
|
||||
if (cparams.embeddings && (
|
||||
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
||||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
|
||||
)) {
|
||||
GGML_ASSERT(cls);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
||||
|
||||
uint32_t * data = (uint32_t *) cls->data;
|
||||
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
||||
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
||||
|
||||
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
|
||||
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
|
||||
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
|
||||
|
||||
if (pos == 0) {
|
||||
data[seq_id] = s*n_seq_tokens + i;
|
||||
}
|
||||
data[seq_idx] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
|
||||
GGML_ASSERT(cls);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
||||
|
||||
uint32_t * data = (uint32_t *) cls->data;
|
||||
memset(cls->data, 0, n_tokens * ggml_element_size(cls));
|
||||
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
||||
|
||||
std::vector<int> last_pos(n_tokens, -1);
|
||||
std::vector<int> last_row(n_tokens, -1);
|
||||
std::vector<int> last_pos(n_seqs_unq, -1);
|
||||
std::vector<int> last_row(n_seqs_unq, -1);
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const llama_pos pos = ubatch->pos[i];
|
||||
|
||||
// TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
|
||||
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
||||
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const llama_pos pos = ubatch->pos[s*n_seq_tokens + i];
|
||||
|
||||
if (pos >= last_pos[seq_id]) {
|
||||
last_pos[seq_id] = pos;
|
||||
last_row[seq_id] = s*n_seq_tokens + i;
|
||||
if (pos >= last_pos[seq_idx]) {
|
||||
last_pos[seq_idx] = pos;
|
||||
last_row[seq_idx] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
if (last_row[i] >= 0) {
|
||||
data[i] = last_row[i];
|
||||
for (int s = 0; s < n_seqs_unq; ++s) {
|
||||
if (last_row[s] >= 0) {
|
||||
data[s] = last_row[s];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
const int64_t n_kv = kv_state->get_n_kv();
|
||||
const int64_t n_rs = mctx->get_n_rs();
|
||||
|
||||
if (s_copy) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
||||
int32_t * data = (int32_t *) s_copy->data;
|
||||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
data[i] = kv_state->s_copy(i);
|
||||
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||
data[i] = mctx->s_copy(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -265,89 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
if (kq_mask) {
|
||||
if (cparams.causal_attn) {
|
||||
const int64_t n_kv = ubatch->n_tokens;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
const int64_t n_kv = ubatch->n_tokens;
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
||||
float * data = (float *) kq_mask->data;
|
||||
GGML_ASSERT(kq_mask);
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
||||
float * data = (float *) kq_mask->data;
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
||||
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
||||
float f = -INFINITY;
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
||||
if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
||||
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
||||
|
||||
data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
|
||||
}
|
||||
// TODO: reimplement this like in llama_kv_cache_unified
|
||||
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const int64_t n_seqs = ubatch->n_seqs;
|
||||
const int64_t n_stride = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
||||
|
||||
float * data = (float *) kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int s1 = 0; s1 < n_seqs; ++s1) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s1][0];
|
||||
|
||||
for (int j = 0; j < n_seq_tokens; ++j) {
|
||||
const int32_t tj = s1*n_seq_tokens + j;
|
||||
|
||||
for (int s0 = 0; s0 < n_seqs; ++s0) {
|
||||
for (int i = 0; i < n_seq_tokens; ++i) {
|
||||
const int32_t ti = s0*n_seq_tokens + i;
|
||||
float f = -INFINITY;
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
|
||||
if (ubatch->seq_id[s0][s] == seq_id) {
|
||||
if (hparams.use_alibi) {
|
||||
f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
|
||||
} else {
|
||||
f = 0.0f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < n_stride; ++i) {
|
||||
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -355,51 +282,71 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
||||
|
||||
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (self_kq_mask_swa) {
|
||||
kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
||||
if (cross_kq_mask) {
|
||||
const int64_t n_enc = cross_kq_mask->ne[0];
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
GGML_ASSERT(cross_kq_mask);
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
||||
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
||||
const int64_t n_enc = cross_kq_mask->ne[0];
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
float * data = (float *) cross_kq_mask->data;
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
||||
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_enc; ++i) {
|
||||
float f = -INFINITY;
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int s = 0; s < ubatch->n_seq_id[j]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[j][s];
|
||||
if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) {
|
||||
f = 0.0f;
|
||||
}
|
||||
float * data = (float *) cross_kq_mask->data;
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
float f = -INFINITY;
|
||||
|
||||
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
||||
|
||||
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
||||
f = 0.0f;
|
||||
}
|
||||
data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
||||
}
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (int j = 0; j < n_enc; ++j) {
|
||||
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (s_copy) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
||||
int32_t * data = (int32_t *) s_copy->data;
|
||||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||
data[i] = mctx->get_recr()->s_copy(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,16 +389,12 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||
backend_cpu (params.backend_cpu),
|
||||
cvec (params.cvec),
|
||||
loras (params.loras),
|
||||
mstate (params.mstate),
|
||||
mctx (params.mctx),
|
||||
cross (params.cross),
|
||||
cb_func (params.cb),
|
||||
res (std::make_unique<llm_graph_result>()) {
|
||||
}
|
||||
|
||||
int64_t llm_graph_context::n_pos_per_embd() const {
|
||||
return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
|
||||
}
|
||||
|
||||
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
||||
if (cb_func) {
|
||||
cb_func(ubatch, cur, name, il);
|
||||
@@ -896,11 +839,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_pos() const {
|
||||
auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
|
||||
auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
|
||||
|
||||
auto & cur = inp->pos;
|
||||
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
|
||||
ggml_set_input(cur);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
@@ -923,6 +866,14 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
|
||||
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
|
||||
// but this would make the graph topology depend on the number of output tokens, which can interere with
|
||||
// features that require constant topology such as pipline parallelism
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
|
||||
//if (n_outputs < n_tokens) {
|
||||
// return nullptr;
|
||||
//}
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
|
||||
|
||||
auto & cur = inp->out_ids;
|
||||
@@ -940,7 +891,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
|
||||
|
||||
auto & cur = inp->mean;
|
||||
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
|
||||
ggml_set_input(cur);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
@@ -953,24 +904,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
||||
|
||||
auto & cur = inp->cls;
|
||||
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
||||
ggml_set_input(cur);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
|
||||
auto & cur = inp->s_copy;
|
||||
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
|
||||
ggml_set_input(cur);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
@@ -1016,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
|
||||
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
|
||||
auto & cur = inp->pos_bucket;
|
||||
|
||||
@@ -1047,6 +981,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
||||
return pos_bias;
|
||||
}
|
||||
|
||||
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
||||
|
||||
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
const auto n_rs = mctx_cur->get_recr()->get_n_rs();
|
||||
|
||||
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
||||
ggml_set_input(inp->s_copy);
|
||||
}
|
||||
|
||||
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * q,
|
||||
@@ -1222,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
@@ -1259,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
@@ -1291,36 +1252,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
|
||||
|
||||
{
|
||||
const auto n_kv = kv_state->get_base()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
||||
|
||||
const auto n_kv = kv_state->get_swa()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
ggml_cgraph * gf,
|
||||
@@ -1339,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
||||
const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
||||
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
|
||||
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
@@ -1430,20 +1361,99 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_recurrent_state(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_mem_hybrid * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
const auto kv_head = kv_state->get_head();
|
||||
const auto rs_zero = kv_state->get_rs_z();
|
||||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
|
||||
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
if (arch == LLM_ARCH_GLM4) {
|
||||
// GLM4 seems to have numerical issues with half-precision accumulators
|
||||
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
}
|
||||
|
||||
if (wo_b) {
|
||||
cur = ggml_add(ctx0, cur, wo_b);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
{
|
||||
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rs(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
uint32_t n_kv,
|
||||
uint32_t kv_head,
|
||||
uint32_t kv_size,
|
||||
int32_t rs_zero,
|
||||
bool avoid_copies) const {
|
||||
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
|
||||
|
||||
// Clear a single state which will then be copied to the other cleared states.
|
||||
// Note that this is a no-op when the view is zero-sized.
|
||||
@@ -1474,22 +1484,59 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
|
||||
return output_states;
|
||||
}
|
||||
|
||||
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
|
||||
|
||||
const auto n_rs = mctx_cur->get_n_rs();
|
||||
|
||||
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
||||
ggml_set_input(inp->s_copy);
|
||||
|
||||
return (llm_graph_input_rs *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rs(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies) const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
|
||||
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rs(
|
||||
llm_graph_input_mem_hybrid * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies) const {
|
||||
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
|
||||
|
||||
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
|
||||
const auto token_shift_count = hparams.token_shift_count;
|
||||
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
|
||||
ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
|
||||
|
||||
ggml_tensor * token_shift = build_recurrent_state(
|
||||
gf, token_shift_all, state_copy,
|
||||
hparams.n_embd_k_s(), n_seqs);
|
||||
ggml_tensor * token_shift = build_rs(
|
||||
inp, gf, token_shift_all,
|
||||
hparams.n_embd_r(), n_seqs);
|
||||
|
||||
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
||||
|
||||
@@ -1500,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
||||
ggml_tensor * token_shift,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
||||
|
||||
const auto token_shift_count = hparams.token_shift_count;
|
||||
const auto n_embd = hparams.n_embd;
|
||||
|
||||
const int64_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
const auto kv_head = kv_state->get_head();
|
||||
const auto kv_head = mctx_cur->get_head();
|
||||
|
||||
return ggml_cpy(
|
||||
ctx0,
|
||||
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
||||
ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
|
||||
ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -17,11 +17,12 @@ struct ggml_tensor;
|
||||
struct llama_ubatch;
|
||||
struct llama_cparams;
|
||||
|
||||
struct llama_memory_state_i;
|
||||
struct llama_memory_context_i;
|
||||
|
||||
class llama_kv_cache_unified_state;
|
||||
class llama_kv_cache_unified_iswa_state;
|
||||
class llama_kv_cache_recurrent_state;
|
||||
class llama_kv_cache_unified_context;
|
||||
class llama_kv_cache_unified_iswa_context;
|
||||
class llama_memory_recurrent_context;
|
||||
class llama_memory_hybrid_context;
|
||||
|
||||
// certain models (typically multi-modal) can produce different types of graphs
|
||||
enum llm_graph_type {
|
||||
@@ -94,14 +95,14 @@ public:
|
||||
|
||||
class llm_graph_input_pos : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
|
||||
llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
|
||||
virtual ~llm_graph_input_pos() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * pos = nullptr; // I32 [n_batch]
|
||||
|
||||
const int64_t n_pos_per_embd = 1;
|
||||
const uint32_t n_pos_per_embd = 1;
|
||||
};
|
||||
|
||||
// temperature tuning, used by llama4
|
||||
@@ -135,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_pos_bucket_kv(
|
||||
const llama_hparams & hparams,
|
||||
const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
|
||||
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
|
||||
virtual ~llm_graph_input_pos_bucket_kv() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
@@ -143,7 +144,8 @@ public:
|
||||
ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_kv_cache_unified_state * kv_state;
|
||||
|
||||
const llama_kv_cache_unified_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_out_ids : public llm_graph_input_i {
|
||||
@@ -188,16 +190,16 @@ public:
|
||||
const llama_cparams & cparams;
|
||||
};
|
||||
|
||||
class llm_graph_input_s_copy : public llm_graph_input_i {
|
||||
class llm_graph_input_rs : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||
virtual ~llm_graph_input_s_copy() = default;
|
||||
llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
|
||||
virtual ~llm_graph_input_rs() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [kv_size]
|
||||
|
||||
const llama_kv_cache_recurrent_state * kv_state;
|
||||
const llama_memory_recurrent_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
@@ -237,10 +239,10 @@ public:
|
||||
llm_graph_input_attn_kv_unified(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_state * kv_state) :
|
||||
const llama_kv_cache_unified_context * mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
kv_state(kv_state) {
|
||||
mctx(mctx) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified() = default;
|
||||
|
||||
@@ -254,7 +256,7 @@ public:
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified_state * kv_state;
|
||||
const llama_kv_cache_unified_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
||||
@@ -262,10 +264,10 @@ public:
|
||||
llm_graph_input_attn_kv_unified_iswa(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_unified_iswa_state * kv_state) :
|
||||
const llama_kv_cache_unified_iswa_context * mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
kv_state(kv_state) {
|
||||
mctx(mctx) {
|
||||
}
|
||||
~llm_graph_input_attn_kv_unified_iswa() = default;
|
||||
|
||||
@@ -282,7 +284,7 @@ public:
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_kv_cache_unified_iswa_state * kv_state;
|
||||
const llama_kv_cache_unified_iswa_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_cross : public llm_graph_input_i {
|
||||
@@ -300,6 +302,33 @@ public:
|
||||
const llama_cross * cross = nullptr;
|
||||
};
|
||||
|
||||
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_mem_hybrid(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_memory_hybrid_context * mctx) :
|
||||
hparams(hparams),
|
||||
cparams(cparams),
|
||||
mctx(mctx) {
|
||||
}
|
||||
virtual ~llm_graph_input_mem_hybrid() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_copy; // I32 [kv_size]
|
||||
|
||||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
||||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
@@ -373,10 +402,10 @@ struct llm_graph_params {
|
||||
ggml_backend_sched_t sched;
|
||||
ggml_backend_t backend_cpu;
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_state_i * mstate;
|
||||
const llama_cross * cross;
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
uint32_t n_outputs;
|
||||
|
||||
@@ -425,10 +454,10 @@ struct llm_graph_context {
|
||||
|
||||
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
|
||||
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_state_i * mstate;
|
||||
const llama_cross * cross;
|
||||
const llama_adapter_cvec * cvec;
|
||||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
const llm_graph_cb & cb_func;
|
||||
|
||||
@@ -436,8 +465,6 @@ struct llm_graph_context {
|
||||
|
||||
llm_graph_context(const llm_graph_params & params);
|
||||
|
||||
int64_t n_pos_per_embd() const;
|
||||
|
||||
void cb(ggml_tensor * cur, const char * name, int il) const;
|
||||
|
||||
//
|
||||
@@ -508,13 +535,14 @@ struct llm_graph_context {
|
||||
ggml_tensor * build_inp_out_ids() const;
|
||||
ggml_tensor * build_inp_mean() const;
|
||||
ggml_tensor * build_inp_cls() const;
|
||||
ggml_tensor * build_inp_s_copy() const;
|
||||
|
||||
ggml_tensor * build_inp_cross_embd() const;
|
||||
ggml_tensor * build_inp_pos_bucket_enc() const;
|
||||
ggml_tensor * build_inp_pos_bucket_dec() const;
|
||||
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
|
||||
|
||||
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
|
||||
|
||||
//
|
||||
// attention
|
||||
//
|
||||
@@ -589,22 +617,62 @@ struct llm_graph_context {
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_mem_hybrid * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
//
|
||||
// recurrent
|
||||
//
|
||||
|
||||
ggml_tensor * build_recurrent_state(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
// TODO: avoid notion of "kv"
|
||||
// TODO: move this implementation to llama_memory_recurrent.
|
||||
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
|
||||
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
|
||||
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
|
||||
// `llama_memory_recurrent`
|
||||
ggml_tensor * build_rs(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
uint32_t n_kv,
|
||||
uint32_t kv_head,
|
||||
uint32_t kv_size,
|
||||
int32_t rs_zero,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
llm_graph_input_rs * build_rs_inp() const;
|
||||
|
||||
ggml_tensor * build_rs(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
ggml_tensor * build_rs(
|
||||
llm_graph_input_mem_hybrid * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
ggml_tensor * build_rwkv_token_shift_load(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_rwkv_token_shift_store(
|
||||
|
||||
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
||||
return n_embd_head_v * n_head_kv;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_s() const {
|
||||
uint32_t llama_hparams::n_embd_r() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// for RWKV models
|
||||
return token_shift_count * n_embd;
|
||||
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
|
||||
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_v_s() const {
|
||||
uint32_t llama_hparams::n_embd_s() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// corresponds to RWKV's wkv_states size
|
||||
return n_embd * wkv_head_size;
|
||||
@@ -86,6 +86,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
|
||||
return ssm_d_state * ssm_d_inner;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_recurrent(uint32_t il) const {
|
||||
return recurrent_layer_arr[il];
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_pos_per_embd() const {
|
||||
return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_swa(uint32_t il) const {
|
||||
if (il < n_layer) {
|
||||
return swa_layers[il];
|
||||
|
||||
@@ -115,6 +115,9 @@ struct llama_hparams {
|
||||
uint32_t ssm_d_state = 0;
|
||||
uint32_t ssm_dt_rank = 0;
|
||||
|
||||
// for hybrid state space models
|
||||
std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
|
||||
|
||||
bool ssm_dt_b_c_rms = false;
|
||||
|
||||
float f_clamp_kqv = 0.0f;
|
||||
@@ -181,10 +184,15 @@ struct llama_hparams {
|
||||
|
||||
// dimension of the rolling state embeddings
|
||||
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
||||
uint32_t n_embd_k_s() const;
|
||||
uint32_t n_embd_r() const;
|
||||
|
||||
// dimension of the recurrent state embeddings
|
||||
uint32_t n_embd_v_s() const;
|
||||
uint32_t n_embd_s() const;
|
||||
|
||||
// whether or not the given layer is recurrent (for hybrid models)
|
||||
bool is_recurrent(uint32_t il) const;
|
||||
|
||||
uint32_t n_pos_per_embd() const;
|
||||
|
||||
bool is_swa(uint32_t il) const;
|
||||
};
|
||||
|
||||
@@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
|
||||
return kv_swa->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
|
||||
GGML_UNUSED(embd_pooled);
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
GGML_UNUSED(embd_all);
|
||||
|
||||
// first try simple split
|
||||
do {
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_simple(n_ubatch);
|
||||
|
||||
while (sbatch.n_tokens > 0) {
|
||||
auto ubatch = sbatch.split_simple(n_ubatch);
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(ubatch);
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
@@ -122,20 +125,23 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// if it fails, try equal split
|
||||
do {
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch);
|
||||
|
||||
while (sbatch.n_tokens > 0) {
|
||||
auto ubatch = sbatch.split_equal(n_ubatch);
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(ubatch);
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads_base = kv_base->prepare(ubatches);
|
||||
@@ -150,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(
|
||||
this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
// TODO: if we fail again, we should attempt different splitting strategies
|
||||
// but to do that properly, we first have to refactor the batches to be more flexible
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
|
||||
llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||
@@ -191,52 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
||||
}
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa_state
|
||||
// llama_kv_cache_unified_iswa_context
|
||||
//
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
state_base = kv->get_base()->init_full();
|
||||
state_swa = kv->get_swa ()->init_full();
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv) :
|
||||
ctx_base(kv->get_base()->init_full()),
|
||||
ctx_swa (kv->get_swa ()->init_full()),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
state_base = kv->get_base()->init_update(lctx, optimize);
|
||||
state_swa = kv->get_swa ()->init_update(lctx, optimize);
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
bool optimize) :
|
||||
ctx_base(kv->get_base()->init_update(lctx, optimize)),
|
||||
ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
std::vector<llama_ubatch> ubatches)
|
||||
: status(LLAMA_MEMORY_STATUS_SUCCESS),
|
||||
sbatch(std::move(sbatch)),
|
||||
ubatches(std::move(ubatches)) {
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
||||
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
|
||||
ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
||||
llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
|
||||
|
||||
bool llama_kv_cache_unified_iswa_state::next() {
|
||||
bool llama_kv_cache_unified_iswa_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
state_base->next();
|
||||
state_swa ->next();
|
||||
ctx_base->next();
|
||||
ctx_swa ->next();
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
return false;
|
||||
@@ -245,41 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa_state::apply() {
|
||||
bool llama_kv_cache_unified_iswa_context::apply() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
bool res = true;
|
||||
|
||||
res = res & state_base->apply();
|
||||
res = res & state_swa ->apply();
|
||||
res = res & ctx_base->apply();
|
||||
res = res & ctx_swa ->apply();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return sbatch.out_ids;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
||||
llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
||||
const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
||||
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
||||
const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
|
||||
}
|
||||
|
||||
@@ -31,14 +31,14 @@ public:
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled) override;
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
@@ -72,62 +72,57 @@ private:
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
|
||||
class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
|
||||
public:
|
||||
// used for errors
|
||||
llama_kv_cache_unified_iswa_state(llama_memory_status status);
|
||||
llama_kv_cache_unified_iswa_context(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_unified_iswa_state(
|
||||
// used to create a full-cache context
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv);
|
||||
|
||||
// used to create an update state
|
||||
llama_kv_cache_unified_iswa_state(
|
||||
// used to create an update context
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
// used to create a state from a batch
|
||||
llama_kv_cache_unified_iswa_state(
|
||||
// used to create a batch processing context from a batch
|
||||
llama_kv_cache_unified_iswa_context(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_iswa_state();
|
||||
virtual ~llama_kv_cache_unified_iswa_context();
|
||||
|
||||
//
|
||||
// llama_memory_state_i
|
||||
// llama_memory_context_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
std::vector<int64_t> & out_ids() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_iswa_state specific API
|
||||
// llama_kv_cache_unified_iswa_context specific API
|
||||
//
|
||||
|
||||
const llama_kv_cache_unified_state * get_base() const;
|
||||
const llama_kv_cache_unified_state * get_swa() const;
|
||||
const llama_kv_cache_unified_context * get_base() const;
|
||||
const llama_kv_cache_unified_context * get_swa() const;
|
||||
|
||||
private:
|
||||
llama_memory_status status;
|
||||
|
||||
//llama_kv_cache_unified_iswa * kv;
|
||||
|
||||
llama_sbatch sbatch;
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
llama_memory_state_ptr state_base;
|
||||
llama_memory_state_ptr state_swa;
|
||||
const llama_memory_context_ptr ctx_base;
|
||||
const llama_memory_context_ptr ctx_swa;
|
||||
|
||||
const llama_memory_status status;
|
||||
};
|
||||
|
||||
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
const char * dev_name = "CPU";
|
||||
|
||||
@@ -307,18 +307,24 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||
return cells.seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
||||
const llama_batch & batch,
|
||||
llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled) {
|
||||
GGML_UNUSED(embd_pooled);
|
||||
bool embd_all) {
|
||||
GGML_UNUSED(embd_all);
|
||||
|
||||
do {
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (sbatch.n_tokens > 0) {
|
||||
ubatches.push_back(sbatch.split_simple(n_ubatch));
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_simple(n_ubatch);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
auto heads = prepare(ubatches);
|
||||
@@ -326,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
||||
break;
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_state>(
|
||||
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
||||
return std::make_unique<llama_kv_cache_unified_context>(
|
||||
this, std::move(heads), std::move(ubatches));
|
||||
} while (false);
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_state>(this);
|
||||
llama_memory_context_ptr llama_kv_cache_unified::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
||||
llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
||||
bool do_shift = get_has_shift();
|
||||
|
||||
defrag_info dinfo;
|
||||
@@ -367,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
|
||||
return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
|
||||
}
|
||||
|
||||
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
@@ -644,12 +650,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
|
||||
if (debug > 0) {
|
||||
LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
|
||||
LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
|
||||
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
|
||||
}
|
||||
|
||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||
// for non-SWA cache, this would be always empty
|
||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
||||
@@ -657,27 +657,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
||||
seq_pos_max_rm[s] = -1;
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
||||
for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
|
||||
const uint32_t idx = s*ubatch.n_seq_tokens + j;
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
if (!cells.is_empty(head_cur + i)) {
|
||||
assert(cells.seq_count(head_cur + i) == 1);
|
||||
|
||||
if (!cells.is_empty(head_cur + idx)) {
|
||||
assert(cells.seq_count(head_cur + idx) == 1);
|
||||
const llama_seq_id seq_id = cells.seq_get(head_cur + i);
|
||||
const llama_pos pos = cells.pos_get(head_cur + i);
|
||||
|
||||
const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
|
||||
const llama_pos pos = cells.pos_get(head_cur + idx);
|
||||
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
||||
|
||||
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
||||
cells.rm(head_cur + i);
|
||||
}
|
||||
|
||||
cells.rm(head_cur + idx);
|
||||
}
|
||||
cells.pos_set(head_cur + i, ubatch.pos[i]);
|
||||
|
||||
cells.pos_set(head_cur + idx, ubatch.pos[idx]);
|
||||
|
||||
// TODO: fix indexing [UBATCH_IDX]
|
||||
for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
|
||||
cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
|
||||
}
|
||||
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
||||
cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -696,6 +691,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
|
||||
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
|
||||
}
|
||||
}
|
||||
|
||||
// move the head at the end of the slot
|
||||
head = head_cur + ubatch.n_tokens;
|
||||
}
|
||||
@@ -792,9 +788,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
|
||||
const uint32_t n_seqs = ubatch->n_seqs;
|
||||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
|
||||
float * data = (float *) dst->data;
|
||||
@@ -814,52 +808,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
||||
// xxxxx-----
|
||||
// To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
|
||||
for (uint32_t h = 0; h < 1; ++h) {
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[s][0];
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
const llama_seq_id seq_id = ubatch->seq_id[i][0];
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_tokens; ++j) {
|
||||
const uint32_t idx = s*n_seq_tokens + j;
|
||||
const llama_pos p1 = ubatch->pos[i];
|
||||
|
||||
const llama_pos p1 = ubatch->pos[idx];
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
float f = 0.0f;
|
||||
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
float f = 0.0f;
|
||||
bool masked = false;
|
||||
|
||||
bool masked = false;
|
||||
if (cells.is_empty(j)) {
|
||||
masked = true;
|
||||
} else {
|
||||
const llama_pos p0 = cells.pos_get(j);
|
||||
|
||||
if (cells.is_empty(i)) {
|
||||
masked = true;
|
||||
} else {
|
||||
const llama_pos p0 = cells.pos_get(i);
|
||||
// mask the token if not the same sequence
|
||||
masked = masked || (!cells.seq_has(j, seq_id));
|
||||
|
||||
// mask the token if not the same sequence
|
||||
masked = masked || (!cells.seq_has(i, seq_id));
|
||||
// mask future tokens
|
||||
masked = masked || (causal_attn && p0 > p1);
|
||||
|
||||
// mask future tokens
|
||||
masked = masked || (causal_attn && p0 > p1);
|
||||
// apply SWA if any
|
||||
masked = masked || (is_masked_swa(p0, p1));
|
||||
|
||||
// apply SWA if any
|
||||
masked = masked || (is_masked_swa(p0, p1));
|
||||
|
||||
if (!masked && hparams.use_alibi) {
|
||||
f = -std::abs(p0 - p1);
|
||||
}
|
||||
if (!masked && hparams.use_alibi) {
|
||||
f = -std::abs(p0 - p1);
|
||||
}
|
||||
|
||||
if (masked) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
|
||||
data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
|
||||
}
|
||||
|
||||
if (masked) {
|
||||
f = -INFINITY;
|
||||
}
|
||||
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
|
||||
}
|
||||
}
|
||||
|
||||
// mask padded tokens
|
||||
if (data) {
|
||||
for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
|
||||
for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
||||
for (uint32_t j = 0; j < n_kv; ++j) {
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -887,12 +877,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
|
||||
const int32_t n_kv = dst->ne[0];
|
||||
|
||||
for (int h = 0; h < 1; ++h) {
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
for (int j = 0; j < n_kv; ++j) {
|
||||
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
|
||||
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
|
||||
const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
|
||||
|
||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
||||
data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1430,7 +1420,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
|
||||
// Write key type
|
||||
const int32_t k_type_i = (int32_t)layer.k->type;
|
||||
@@ -1452,7 +1442,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)layer.v->type;
|
||||
@@ -1476,7 +1466,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)layer.v->type;
|
||||
@@ -1509,12 +1499,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
||||
|
||||
seq_rm(dest_seq_id, -1, -1);
|
||||
|
||||
llama_sbatch sbatch;
|
||||
llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
llama_batch_allocr balloc(hparams.n_pos_per_embd());
|
||||
|
||||
ubatch.n_tokens = cell_count;
|
||||
ubatch.n_seq_tokens = cell_count;
|
||||
ubatch.n_seqs = 1;
|
||||
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_pos pos;
|
||||
@@ -1621,7 +1608,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
@@ -1651,7 +1638,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@@ -1681,7 +1668,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
for (const auto & layer : layers) {
|
||||
const uint32_t il = layer.il;
|
||||
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
@@ -1723,37 +1710,36 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
}
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_state
|
||||
// llama_kv_cache_unified_context
|
||||
//
|
||||
|
||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
|
||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
||||
n_kv = kv->get_size();
|
||||
head = 0;
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
|
||||
if (!do_shift && dinfo.empty()) {
|
||||
if (!do_shift && this->dinfo.empty()) {
|
||||
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||
}
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||
llama_kv_cache_unified_context::llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_sbatch sbatch,
|
||||
llama_kv_cache_unified::ubatch_heads heads,
|
||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
|
||||
llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
|
||||
|
||||
bool llama_kv_cache_unified_state::next() {
|
||||
bool llama_kv_cache_unified_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
@@ -1763,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_state::apply() {
|
||||
bool llama_kv_cache_unified_context::apply() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
// no ubatches -> this is a KV cache update
|
||||
@@ -1781,51 +1767,45 @@ bool llama_kv_cache_unified_state::apply() {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return sbatch.out_ids;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_unified_state::get_status() const {
|
||||
llama_memory_status llama_kv_cache_unified_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
|
||||
const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_unified_state::get_n_kv() const {
|
||||
uint32_t llama_kv_cache_unified_context::get_n_kv() const {
|
||||
return n_kv;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_k(ctx, il, n_kv);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_v(ctx, il, n_kv);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
|
||||
return kv->cpy_k(ctx, k_cur, il, head);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
||||
ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
|
||||
return kv->cpy_v(ctx, v_cur, il, head);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
|
||||
void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||
kv->set_input_k_shift(dst);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
|
||||
kv->set_input_kq_mask(dst, ubatch, causal_attn);
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_pos_bucket(dst, ubatch);
|
||||
}
|
||||
|
||||
|
||||
@@ -56,14 +56,14 @@ public:
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled) override;
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
@@ -208,49 +208,46 @@ private:
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_state(llama_memory_status status);
|
||||
llama_kv_cache_unified_context(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_unified_state(
|
||||
// used to create a full-cache context
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv);
|
||||
|
||||
// used to create an update state
|
||||
llama_kv_cache_unified_state(
|
||||
// used to create an update context
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo);
|
||||
|
||||
// used to create a decode state from a batch
|
||||
llama_kv_cache_unified_state(
|
||||
// used to create a batch procesing context from a batch
|
||||
llama_kv_cache_unified_context(
|
||||
llama_kv_cache_unified * kv,
|
||||
llama_sbatch sbatch,
|
||||
ubatch_heads heads,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_unified_state();
|
||||
virtual ~llama_kv_cache_unified_context();
|
||||
|
||||
//
|
||||
// llama_memory_state_i
|
||||
// llama_memory_context_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
std::vector<int64_t> & out_ids() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_unified_state specific API
|
||||
// llama_kv_cache_unified_context specific API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
@@ -275,7 +272,7 @@ private:
|
||||
llama_context * lctx;
|
||||
|
||||
//
|
||||
// update state
|
||||
// update context
|
||||
//
|
||||
|
||||
bool do_shift = false;
|
||||
@@ -283,11 +280,9 @@ private:
|
||||
defrag_info dinfo;
|
||||
|
||||
//
|
||||
// batch processing state
|
||||
// batch processing context
|
||||
//
|
||||
|
||||
llama_sbatch sbatch;
|
||||
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||
// TODO: add unit tests
|
||||
@@ -164,7 +165,7 @@ public:
|
||||
assert(seq_id >= 0);
|
||||
|
||||
seq[i].reset(seq_id);
|
||||
seq_pos[seq_id].erase(pos[i]);
|
||||
seq_pos_dec(seq_id, pos[i]);
|
||||
|
||||
if (seq[i].none()) {
|
||||
pos[i] = -1;
|
||||
@@ -187,7 +188,7 @@ public:
|
||||
seq[i].reset();
|
||||
|
||||
seq[i].set(seq_id);
|
||||
seq_pos[seq_id].insert(pos[i]);
|
||||
seq_pos_inc(seq_id, pos[i]);
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -232,7 +233,7 @@ public:
|
||||
assert(!seq[i].test(seq_id));
|
||||
|
||||
seq[i].set(seq_id);
|
||||
seq_pos[seq_id].insert(pos[i]);
|
||||
seq_pos_inc(seq_id, pos[i]);
|
||||
}
|
||||
|
||||
// return the sequence id of this cell
|
||||
@@ -259,7 +260,9 @@ public:
|
||||
return -1;
|
||||
}
|
||||
|
||||
return *seq_pos[seq_id].begin();
|
||||
assert(seq_pos[seq_id].begin()->second > 0);
|
||||
|
||||
return seq_pos[seq_id].begin()->first;
|
||||
}
|
||||
|
||||
// the maximum position of sequence seq_id currently present in any of the cells
|
||||
@@ -272,7 +275,9 @@ public:
|
||||
return -1;
|
||||
}
|
||||
|
||||
return *seq_pos[seq_id].rbegin();
|
||||
assert(seq_pos[seq_id].rbegin()->second > 0);
|
||||
|
||||
return seq_pos[seq_id].rbegin()->first;
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
@@ -384,22 +389,41 @@ private:
|
||||
//
|
||||
std::vector<llama_pos> shift;
|
||||
|
||||
using bits_t = std::bitset<LLAMA_MAX_SEQ>;
|
||||
using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
|
||||
|
||||
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
||||
std::vector<bits_t> seq;
|
||||
std::vector<seq_set_t> seq;
|
||||
|
||||
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
||||
// the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
|
||||
// if the position p is not present, seq_pos[s][p] is not set
|
||||
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
||||
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
|
||||
//
|
||||
// note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
|
||||
// - during performing a cache reuse via (rm + add)
|
||||
// - some vision models have input embeddings with repeating positions
|
||||
//
|
||||
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
|
||||
|
||||
// helper functions for updating `seq_pos`, once cell at a time:
|
||||
|
||||
void seq_pos_dec(llama_seq_id s, llama_pos p) {
|
||||
auto it = seq_pos[s].find(p);
|
||||
assert(it != seq_pos[s].end());
|
||||
|
||||
if (--it->second == 0) {
|
||||
seq_pos[s].erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
void seq_pos_inc(llama_seq_id s, llama_pos p) {
|
||||
seq_pos[s][p]++;
|
||||
}
|
||||
|
||||
// remove cell i
|
||||
void seq_pos_rm(uint32_t i) {
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
seq_pos[s].erase(pos[i]);
|
||||
seq_pos_dec(s, pos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -408,7 +432,7 @@ private:
|
||||
void seq_pos_add(uint32_t i) {
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
seq_pos[s].insert(pos[i]);
|
||||
seq_pos_inc(s, pos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
246
src/llama-memory-hybrid.cpp
Normal file
246
src/llama-memory-hybrid.cpp
Normal file
@@ -0,0 +1,246 @@
|
||||
#include "llama-memory-hybrid.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-context.h"
|
||||
|
||||
//
|
||||
// llama_memory_hybrid
|
||||
//
|
||||
|
||||
llama_memory_hybrid::llama_memory_hybrid(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
/* layer filters */
|
||||
layer_filter_cb && filter_attn,
|
||||
layer_filter_cb && filter_recr) :
|
||||
hparams(model.hparams),
|
||||
mem_attn(new llama_kv_cache_unified(
|
||||
model,
|
||||
filter_attn == nullptr ?
|
||||
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||
: filter_attn,
|
||||
type_k,
|
||||
type_v,
|
||||
v_trans,
|
||||
offload,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
)),
|
||||
mem_recr(new llama_memory_recurrent(
|
||||
model,
|
||||
filter_recr == nullptr ?
|
||||
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||
: filter_recr,
|
||||
type_r,
|
||||
type_s,
|
||||
offload,
|
||||
rs_size,
|
||||
n_seq_max
|
||||
)) {}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
do {
|
||||
balloc.split_reset();
|
||||
|
||||
// follow the recurrent pattern for creating the ubatch splits
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
while (true) {
|
||||
llama_ubatch ubatch;
|
||||
|
||||
if (embd_all) {
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
}
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
// prepare the recurrent batches first
|
||||
if (!mem_recr->prepare(ubatches)) {
|
||||
// TODO: will the recurrent cache be in an undefined context at this point?
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
// prepare the attention cache
|
||||
auto heads_attn = mem_attn->prepare(ubatches);
|
||||
if (heads_attn.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
|
||||
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
return std::make_unique<llama_memory_hybrid_context>(
|
||||
this, std::move(heads_attn), std::move(ubatches));
|
||||
} while(false);
|
||||
|
||||
return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid::init_full() {
|
||||
return std::make_unique<llama_memory_hybrid_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid::get_can_shift() const {
|
||||
// Shifting is trivially supported for recurrent
|
||||
return mem_attn->get_can_shift();
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::clear(bool data) {
|
||||
mem_attn->clear(data);
|
||||
mem_recr->clear(data);
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
// Try removing from the recurrent cache first since it may fail. If it does
|
||||
// fail, the cache will not have been mutated.
|
||||
if (!mem_recr->seq_rm(seq_id, p0, p1)) {
|
||||
return false;
|
||||
}
|
||||
return mem_attn->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
|
||||
mem_attn->seq_keep(seq_id);
|
||||
mem_recr->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
mem_attn->seq_add(seq_id, p0, p1, shift);
|
||||
mem_recr->seq_add(seq_id, p0, p1, shift);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
mem_attn->seq_div(seq_id, p0, p1, d);
|
||||
mem_recr->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
|
||||
// the min of the total cache is the max of the two caches' min values
|
||||
return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
|
||||
}
|
||||
|
||||
llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
|
||||
// the max of the total cache is the min of the two caches' max values
|
||||
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
||||
mem_attn->state_write(io, seq_id);
|
||||
mem_recr->state_write(io, seq_id);
|
||||
}
|
||||
|
||||
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
mem_attn->state_read(io, seq_id);
|
||||
mem_recr->state_read(io, seq_id);
|
||||
}
|
||||
|
||||
llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
|
||||
return mem_attn.get();
|
||||
}
|
||||
|
||||
llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
|
||||
return mem_recr.get();
|
||||
}
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
|
||||
ctx_attn(mem->get_mem_attn()->init_full()),
|
||||
ctx_recr(mem->get_mem_recr()->init_full()),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
llama_context * lctx,
|
||||
bool optimize) :
|
||||
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
|
||||
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
llama_memory_hybrid_context::llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
std::vector<uint32_t> heads_attn,
|
||||
std::vector<llama_ubatch> ubatches) :
|
||||
ubatches(std::move(ubatches)),
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
|
||||
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
|
||||
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
ctx_attn->next();
|
||||
ctx_recr->next();
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_memory_hybrid_context::apply() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
bool res = true;
|
||||
|
||||
res = res & ctx_attn->apply();
|
||||
res = res & ctx_recr->apply();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_memory_status llama_memory_hybrid_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
|
||||
return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
|
||||
}
|
||||
|
||||
const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
|
||||
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
|
||||
}
|
||||
138
src/llama-memory-hybrid.h
Normal file
138
src/llama-memory-hybrid.h
Normal file
@@ -0,0 +1,138 @@
|
||||
#pragma once
|
||||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cache-unified.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// llama_memory_hybrid
|
||||
//
|
||||
|
||||
// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
|
||||
// support models where each layer may be either attention-based or recurrent
|
||||
|
||||
class llama_memory_hybrid : public llama_memory_i {
|
||||
public:
|
||||
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_memory_hybrid(
|
||||
const llama_model & model,
|
||||
/* attn */
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type,
|
||||
/* recurrent */
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
uint32_t rs_size,
|
||||
/* common */
|
||||
uint32_t n_seq_max,
|
||||
bool offload,
|
||||
/* layer filters */
|
||||
layer_filter_cb && filter_attn = nullptr,
|
||||
layer_filter_cb && filter_recr = nullptr);
|
||||
|
||||
~llama_memory_hybrid() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||
void seq_keep(llama_seq_id seq_id) override;
|
||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||
|
||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||
|
||||
// state write/load
|
||||
|
||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
|
||||
|
||||
//
|
||||
// llama_memory_hybrid specific API
|
||||
//
|
||||
|
||||
llama_kv_cache_unified * get_mem_attn() const;
|
||||
llama_memory_recurrent * get_mem_recr() const;
|
||||
|
||||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const std::unique_ptr<llama_kv_cache_unified> mem_attn;
|
||||
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
||||
};
|
||||
|
||||
class llama_memory_hybrid_context : public llama_memory_context_i {
|
||||
public:
|
||||
// init failure
|
||||
explicit llama_memory_hybrid_context(llama_memory_status status);
|
||||
|
||||
// init full
|
||||
explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
|
||||
|
||||
// init update
|
||||
explicit llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
llama_context * lctx,
|
||||
bool optimize);
|
||||
|
||||
// init success
|
||||
llama_memory_hybrid_context(
|
||||
llama_memory_hybrid * mem,
|
||||
std::vector<uint32_t> heads_attn,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
~llama_memory_hybrid_context() = default;
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_memory_hybrid_context
|
||||
//
|
||||
|
||||
const llama_kv_cache_unified_context * get_attn() const;
|
||||
const llama_memory_recurrent_context * get_recr() const;
|
||||
|
||||
private:
|
||||
// the index of the next ubatch to process
|
||||
size_t i_next = 0;
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
const llama_memory_context_ptr ctx_attn;
|
||||
const llama_memory_context_ptr ctx_recr;
|
||||
|
||||
const llama_memory_status status;
|
||||
};
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "llama-kv-cache-recurrent.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-io.h"
|
||||
@@ -12,27 +12,28 @@
|
||||
#include <stdexcept>
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent
|
||||
// llama_memory_recurrent
|
||||
//
|
||||
|
||||
llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
||||
llama_memory_recurrent::llama_memory_recurrent(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
|
||||
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
|
||||
__func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
|
||||
LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
|
||||
__func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
|
||||
|
||||
head = 0;
|
||||
size = kv_size;
|
||||
size = mem_size;
|
||||
used = 0;
|
||||
|
||||
cells.clear();
|
||||
cells.resize(kv_size);
|
||||
cells.resize(mem_size);
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
@@ -59,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
||||
return it->second;
|
||||
};
|
||||
|
||||
k_l.reserve(n_layer);
|
||||
v_l.reserve(n_layer);
|
||||
r_l.resize(n_layer);
|
||||
s_l.resize(n_layer);
|
||||
|
||||
for (int i = 0; i < n_layer; i++) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
||||
if (filter && !filter(i)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * dev_name = "CPU";
|
||||
|
||||
@@ -84,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
||||
throw std::runtime_error("failed to create ggml context for kv cache");
|
||||
}
|
||||
|
||||
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
|
||||
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
||||
ggml_format_name(k, "cache_k_l%d", i);
|
||||
ggml_format_name(v, "cache_v_l%d", i);
|
||||
k_l.push_back(k);
|
||||
v_l.push_back(v);
|
||||
ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
|
||||
ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
|
||||
ggml_format_name(r, "cache_r_l%d", i);
|
||||
ggml_format_name(s, "cache_s_l%d", i);
|
||||
r_l[i] = r;
|
||||
s_l[i] = s;
|
||||
}
|
||||
|
||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||
@@ -107,17 +110,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
||||
}
|
||||
|
||||
{
|
||||
const size_t memory_size_k = size_k_bytes();
|
||||
const size_t memory_size_v = size_v_bytes();
|
||||
const size_t memory_size_r = size_r_bytes();
|
||||
const size_t memory_size_s = size_s_bytes();
|
||||
|
||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::clear(bool data) {
|
||||
void llama_memory_recurrent::clear(bool data) {
|
||||
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
@@ -135,7 +138,7 @@ void llama_kv_cache_recurrent::clear(bool data) {
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
uint32_t new_head = size;
|
||||
|
||||
if (p0 < 0) {
|
||||
@@ -154,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
||||
if (0 <= seq_id) {
|
||||
int32_t & tail_id = cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
const kv_cell & cell = cells[tail_id];
|
||||
const auto & cell = cells[tail_id];
|
||||
// partial intersection is invalid
|
||||
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
||||
return false;
|
||||
@@ -202,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
if (seq_id_src == seq_id_dst) {
|
||||
return;
|
||||
}
|
||||
@@ -216,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
||||
}
|
||||
|
||||
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
||||
kv_cell & tail_src = cells[seq_id_src];
|
||||
kv_cell & tail_dst = cells[seq_id_dst];
|
||||
auto & tail_src = cells[seq_id_src];
|
||||
auto & tail_dst = cells[seq_id_dst];
|
||||
if (tail_dst.tail >= 0) {
|
||||
// clear destination seq_id if it wasn't empty
|
||||
kv_cell & cell_dst = cells[tail_dst.tail];
|
||||
auto & cell_dst = cells[tail_dst.tail];
|
||||
|
||||
cell_dst.seq_id.erase(seq_id_dst);
|
||||
tail_dst.tail = -1;
|
||||
@@ -231,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
||||
}
|
||||
}
|
||||
if (tail_src.tail >= 0) {
|
||||
kv_cell & cell_src = cells[tail_src.tail];
|
||||
auto & cell_src = cells[tail_src.tail];
|
||||
|
||||
cell_src.seq_id.insert(seq_id_dst);
|
||||
tail_dst.tail = tail_src.tail;
|
||||
@@ -239,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
||||
void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
|
||||
uint32_t new_head = size;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
@@ -271,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||
if (shift == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -293,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
||||
if (0 <= seq_id && seq_id < (int64_t) size) {
|
||||
const int32_t tail_id = cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
kv_cell & cell = cells[tail_id];
|
||||
auto & cell = cells[tail_id];
|
||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||
cell.pos += shift;
|
||||
}
|
||||
@@ -301,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||
if (d == 1) {
|
||||
return;
|
||||
}
|
||||
@@ -323,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
||||
if (0 <= seq_id && seq_id < (int64_t) size) {
|
||||
const int32_t tail_id = cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
kv_cell & cell = cells[tail_id];
|
||||
auto & cell = cells[tail_id];
|
||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||
cell.pos /= d;
|
||||
}
|
||||
@@ -331,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
||||
}
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos result = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
@@ -347,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos result = -1;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
@@ -359,45 +362,45 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
|
||||
GGML_UNUSED(embd_pooled);
|
||||
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
||||
|
||||
llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
||||
while (sbatch.n_tokens > 0) {
|
||||
while (true) {
|
||||
llama_ubatch ubatch;
|
||||
|
||||
if (embd_pooled) {
|
||||
// Pooled embeddings cannot be split across ubatches (yet)
|
||||
ubatch = sbatch.split_seq(n_ubatch);
|
||||
if (embd_all) {
|
||||
// if all tokens are output, split by sequence
|
||||
ubatch = balloc.split_seq(n_ubatch);
|
||||
} else {
|
||||
ubatch = sbatch.split_equal(n_ubatch);
|
||||
ubatch = balloc.split_equal(n_ubatch);
|
||||
}
|
||||
|
||||
ubatches.push_back(ubatch);
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
ubatches.push_back(std::move(ubatch)); // NOLINT
|
||||
}
|
||||
|
||||
if (!prepare(ubatches)) {
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
|
||||
return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
||||
llama_memory_context_ptr llama_memory_recurrent::init_full() {
|
||||
return std::make_unique<llama_memory_recurrent_context>(this);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
|
||||
llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
|
||||
GGML_UNUSED(lctx);
|
||||
GGML_UNUSED(optimize);
|
||||
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
||||
return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
// simply remember the full state because it is very small for this type of cache
|
||||
// TODO: optimize
|
||||
auto org_cells = cells;
|
||||
@@ -421,10 +424,9 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
||||
return success;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
const uint32_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const uint32_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
@@ -444,9 +446,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
|
||||
// everything should fit if all seq_ids are smaller than the max
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const uint32_t n_seq_id = ubatch.n_seq_id[s];
|
||||
const uint32_t i = s*n_seq_tokens; // first token of sequence set s
|
||||
const uint32_t n_seq_id = ubatch.n_seq_id[i];
|
||||
|
||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
||||
const llama_seq_id seq_id = ubatch.seq_id[i][j];
|
||||
|
||||
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
||||
// too big seq_id
|
||||
@@ -455,9 +459,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
return false;
|
||||
}
|
||||
if (j > 0) {
|
||||
kv_cell & seq = cells[seq_id];
|
||||
auto & seq = cells[seq_id];
|
||||
if (seq.tail >= 0) {
|
||||
kv_cell & cell = cells[seq.tail];
|
||||
auto & cell = cells[seq.tail];
|
||||
// clear cells from seq_ids that become shared
|
||||
// (should not normally happen, but let's handle it anyway)
|
||||
cell.seq_id.erase(seq_id);
|
||||
@@ -477,7 +481,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
std::vector<int32_t> tails_verif;
|
||||
tails_verif.assign(size, -1);
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
kv_cell & cell = cells[i];
|
||||
auto & cell = cells[i];
|
||||
for (llama_seq_id seq_id : cell.seq_id) {
|
||||
if (tails_verif[seq_id] != -1) {
|
||||
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
||||
@@ -498,28 +502,29 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||
kv_cell & cell = cells[next_empty_cell];
|
||||
auto & cell = cells[next_empty_cell];
|
||||
if (cell.is_empty()) { break; }
|
||||
next_empty_cell += 1;
|
||||
}
|
||||
|
||||
// find usable cell range
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
kv_cell & seq_meta = cells[seq_id];
|
||||
const uint32_t i = s*n_seq_tokens;
|
||||
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
auto & seq_meta = cells[seq_id];
|
||||
bool has_cell = false;
|
||||
if (seq_meta.tail >= 0) {
|
||||
kv_cell & cell = cells[seq_meta.tail];
|
||||
auto & cell = cells[seq_meta.tail];
|
||||
GGML_ASSERT(cell.has_seq_id(seq_id));
|
||||
// does this seq_id "own" the cell?
|
||||
if (cell.seq_id.size() == 1) { has_cell = true; }
|
||||
}
|
||||
if (!has_cell) {
|
||||
kv_cell & empty_cell = cells[next_empty_cell];
|
||||
auto & empty_cell = cells[next_empty_cell];
|
||||
GGML_ASSERT(empty_cell.is_empty());
|
||||
// copy old tail into the empty cell
|
||||
if (seq_meta.tail >= 0) {
|
||||
kv_cell & orig_cell = cells[seq_meta.tail];
|
||||
auto & orig_cell = cells[seq_meta.tail];
|
||||
empty_cell.pos = orig_cell.pos;
|
||||
empty_cell.src = orig_cell.src;
|
||||
orig_cell.seq_id.erase(seq_id);
|
||||
@@ -529,10 +534,10 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
seq_meta.tail = next_empty_cell;
|
||||
// find next empty cell
|
||||
if (s + 1 < n_seqs) {
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
for (uint32_t j = 0; j < size; ++j) {
|
||||
next_empty_cell += 1;
|
||||
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||
kv_cell & cell = cells[next_empty_cell];
|
||||
auto & cell = cells[next_empty_cell];
|
||||
if (cell.is_empty()) { break; }
|
||||
}
|
||||
}
|
||||
@@ -543,19 +548,20 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
|
||||
// gather and re-order
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const uint32_t i = s*n_seq_tokens;
|
||||
const int32_t dst_id = s + min;
|
||||
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
||||
const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
|
||||
if (dst_id != src_id) {
|
||||
kv_cell & dst_cell = cells[dst_id];
|
||||
kv_cell & src_cell = cells[src_id];
|
||||
auto & dst_cell = cells[dst_id];
|
||||
auto & src_cell = cells[src_id];
|
||||
|
||||
std::swap(dst_cell.pos, src_cell.pos);
|
||||
std::swap(dst_cell.src, src_cell.src);
|
||||
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
||||
|
||||
// swap tails
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
int32_t & tail = cells[i].tail;
|
||||
for (uint32_t j = 0; j < size; ++j) {
|
||||
int32_t & tail = cells[j].tail;
|
||||
if (tail == src_id) {
|
||||
tail = dst_id;
|
||||
} else if (tail == dst_id) {
|
||||
@@ -567,20 +573,21 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
|
||||
// update the pos of the used seqs
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
||||
const uint32_t i = s*n_seq_tokens;
|
||||
const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
|
||||
const int32_t cell_id = s + min;
|
||||
kv_cell & cell = cells[cell_id];
|
||||
auto & cell = cells[cell_id];
|
||||
|
||||
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
||||
// What should happen when the pos backtracks or skips a value?
|
||||
// Clearing the state mid-batch would require special-casing which isn't done.
|
||||
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
||||
__func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
|
||||
__func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
|
||||
}
|
||||
cell.pos = last_pos;
|
||||
cell.seq_id.clear();
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][j];
|
||||
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[i][j];
|
||||
cell.seq_id.insert(seq_id);
|
||||
cells[seq_id].tail = cell_id;
|
||||
}
|
||||
@@ -622,18 +629,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
head = min;
|
||||
n = max - min + 1;
|
||||
used = std::count_if(cells.begin(), cells.end(),
|
||||
[](const kv_cell & cell){ return !cell.is_empty(); });
|
||||
[](const mem_cell & cell){ return !cell.is_empty(); });
|
||||
|
||||
// sanity check
|
||||
return n >= n_seqs;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::get_can_shift() const {
|
||||
bool llama_memory_recurrent::get_can_shift() const {
|
||||
// shifting the pos is trivial for recurrent models
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t llama_kv_cache_recurrent::total_size() const {
|
||||
size_t llama_memory_recurrent::total_size() const {
|
||||
size_t size = 0;
|
||||
for (const auto & buf : bufs) {
|
||||
size += ggml_backend_buffer_get_size(buf.get());
|
||||
@@ -642,27 +649,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
|
||||
return size;
|
||||
}
|
||||
|
||||
size_t llama_kv_cache_recurrent::size_k_bytes() const {
|
||||
size_t size_k_bytes = 0;
|
||||
size_t llama_memory_recurrent::size_r_bytes() const {
|
||||
size_t size_r_bytes = 0;
|
||||
|
||||
for (const auto & k : k_l) {
|
||||
size_k_bytes += ggml_nbytes(k);
|
||||
for (const auto & r : r_l) {
|
||||
if (r != nullptr) {
|
||||
size_r_bytes += ggml_nbytes(r);
|
||||
}
|
||||
}
|
||||
|
||||
return size_k_bytes;
|
||||
return size_r_bytes;
|
||||
}
|
||||
|
||||
size_t llama_kv_cache_recurrent::size_v_bytes() const {
|
||||
size_t size_v_bytes = 0;
|
||||
size_t llama_memory_recurrent::size_s_bytes() const {
|
||||
size_t size_s_bytes = 0;
|
||||
|
||||
for (const auto & v : v_l) {
|
||||
size_v_bytes += ggml_nbytes(v);
|
||||
for (const auto & s : s_l) {
|
||||
if (s != nullptr) {
|
||||
size_s_bytes += ggml_nbytes(s);
|
||||
}
|
||||
}
|
||||
|
||||
return size_v_bytes;
|
||||
return size_s_bytes;
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
||||
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
||||
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
||||
uint32_t cell_count = 0;
|
||||
|
||||
@@ -700,7 +711,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
|
||||
state_write_data(io, cell_ranges);
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
uint32_t cell_count;
|
||||
io.read_to(&cell_count, sizeof(cell_count));
|
||||
|
||||
@@ -719,7 +730,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
||||
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
||||
for (const auto & range : cell_ranges) {
|
||||
for (uint32_t i = range.first; i < range.second; ++i) {
|
||||
const auto & cell = cells[i];
|
||||
@@ -738,98 +749,93 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
||||
const uint32_t v_trans = 0;
|
||||
void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
||||
const uint32_t s_trans = 0;
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
||||
io.write(&v_trans, sizeof(v_trans));
|
||||
io.write(&n_layer, sizeof(n_layer));
|
||||
io.write(&s_trans, sizeof(s_trans));
|
||||
io.write(&n_layer, sizeof(n_layer));
|
||||
|
||||
std::vector<uint8_t> tmp_buf;
|
||||
|
||||
// Iterate and write all the keys first, each row is a cell
|
||||
// Get whole range at a time
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
|
||||
// Write key type
|
||||
const int32_t k_type_i = (int32_t)k_l[il]->type;
|
||||
io.write(&k_type_i, sizeof(k_type_i));
|
||||
const int32_t r_type_i = (int32_t)r_l[il]->type;
|
||||
io.write(&r_type_i, sizeof(r_type_i));
|
||||
|
||||
// Write row size of key
|
||||
const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
||||
io.write(&k_size_row, sizeof(k_size_row));
|
||||
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
io.write(&r_size_row, sizeof(r_size_row));
|
||||
|
||||
// Read each range of cells of k_size length each into tmp_buf and write out
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * k_size_row;
|
||||
io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
|
||||
const size_t buf_size = range_size * r_size_row;
|
||||
io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (!v_trans) {
|
||||
if (!s_trans) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
||||
io.write(&v_type_i, sizeof(v_type_i));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
io.write(&s_type_i, sizeof(s_type_i));
|
||||
|
||||
// Write row size of value
|
||||
const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
||||
io.write(&v_size_row, sizeof(v_size_row));
|
||||
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
io.write(&s_size_row, sizeof(s_size_row));
|
||||
|
||||
// Read each range of cells of v_size length each into tmp_buf and write out
|
||||
// Read each range of cells of s_size length each into tmp_buf and write out
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * v_size_row;
|
||||
io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
|
||||
const size_t buf_size = range_size * s_size_row;
|
||||
io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// When v is transposed, we also need the element size and get the element ranges from each row
|
||||
const uint32_t kv_size = size;
|
||||
const uint32_t mem_size = size;
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_s = hparams.n_embd_s();
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
||||
io.write(&v_type_i, sizeof(v_type_i));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
io.write(&s_type_i, sizeof(s_type_i));
|
||||
|
||||
// Write element size
|
||||
const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
|
||||
io.write(&v_size_el, sizeof(v_size_el));
|
||||
const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
|
||||
io.write(&s_size_el, sizeof(s_size_el));
|
||||
|
||||
// Write GQA embedding size
|
||||
io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
||||
io.write(&n_embd_s, sizeof(n_embd_s));
|
||||
|
||||
// For each row, we get the element values of each cell
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
||||
const size_t buf_size = range_size * v_size_el;
|
||||
io.write_tensor(v_l[il], src_offset, buf_size);
|
||||
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
|
||||
const size_t buf_size = range_size * s_size_el;
|
||||
io.write_tensor(s_l[il], src_offset, buf_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
if (dest_seq_id != -1) {
|
||||
// single sequence
|
||||
|
||||
seq_rm(dest_seq_id, -1, -1);
|
||||
|
||||
llama_sbatch sbatch;
|
||||
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||
llama_batch_allocr balloc(hparams.n_pos_per_embd());
|
||||
|
||||
batch.n_tokens = cell_count;
|
||||
batch.n_seq_tokens = cell_count;
|
||||
batch.n_seqs = 1;
|
||||
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
llama_pos pos;
|
||||
@@ -843,12 +849,12 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
||||
return false;
|
||||
}
|
||||
|
||||
batch.pos[i] = pos;
|
||||
ubatch.pos[i] = pos;
|
||||
}
|
||||
batch.n_seq_id[0] = 1;
|
||||
batch.seq_id[0] = &dest_seq_id;
|
||||
ubatch.n_seq_id[0] = 1;
|
||||
ubatch.seq_id[0] = &dest_seq_id;
|
||||
|
||||
if (!find_slot(batch)) {
|
||||
if (!find_slot(ubatch)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
@@ -856,8 +862,8 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
||||
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||
// Assume that this is one contiguous block of cells
|
||||
GGML_ASSERT(head + cell_count <= size);
|
||||
GGML_ASSERT(cells[head].pos == batch.pos[0]);
|
||||
GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
|
||||
GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
|
||||
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
||||
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
||||
} else {
|
||||
@@ -871,7 +877,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
||||
clear(true);
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
kv_cell & cell = cells[i];
|
||||
auto & cell = cells[i];
|
||||
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
@@ -885,7 +891,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
||||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
|
||||
// TODO: llama_kv_cache_recurrent should have a notion of max sequences
|
||||
// TODO: llama_memory_recurrent should have a notion of max sequences
|
||||
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
||||
if (seq_id < 0) {
|
||||
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
||||
@@ -917,10 +923,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
||||
uint32_t v_trans;
|
||||
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
||||
uint32_t s_trans;
|
||||
uint32_t n_layer;
|
||||
io.read_to(&v_trans, sizeof(v_trans));
|
||||
io.read_to(&s_trans, sizeof(s_trans));
|
||||
io.read_to(&n_layer, sizeof(n_layer));
|
||||
|
||||
if (n_layer != hparams.n_layer) {
|
||||
@@ -931,102 +937,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
||||
return false;
|
||||
}
|
||||
if (false != (bool) v_trans) {
|
||||
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
||||
if (false != (bool) s_trans) {
|
||||
LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
||||
const int32_t k_type_i = (int32_t) k_l[il]->type;
|
||||
if (k_type_i != k_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
||||
int32_t r_type_i_ref;
|
||||
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
||||
if (r_type_i != r_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read row size of key
|
||||
uint64_t k_size_row_ref;
|
||||
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
||||
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
||||
if (k_size_row != k_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
||||
uint64_t r_size_row_ref;
|
||||
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
if (r_size_row != r_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
||||
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
|
||||
}
|
||||
}
|
||||
|
||||
if (!v_trans) {
|
||||
if (!s_trans) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read row size of value
|
||||
uint64_t v_size_row_ref;
|
||||
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
||||
const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
||||
if (v_size_row != v_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
||||
uint64_t s_size_row_ref;
|
||||
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
if (s_size_row != s_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For each layer, read the values for each cell (transposed)
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_s = hparams.n_embd_s();
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read element size of value
|
||||
uint32_t v_size_el_ref;
|
||||
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
||||
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
||||
if (v_size_el != v_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
||||
uint32_t s_size_el_ref;
|
||||
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||
const size_t s_size_el = ggml_type_size(s_l[il]->type);
|
||||
if (s_size_el != s_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read GQA embedding size
|
||||
uint32_t n_embd_v_gqa_ref;
|
||||
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
||||
// Read state embedding size
|
||||
uint32_t n_embd_s_ref;
|
||||
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||
if (n_embd_s != n_embd_s_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * v_size_el;
|
||||
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * s_size_el;
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1036,25 +1040,22 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
||||
}
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent_state
|
||||
// llama_memory_recurrent_context
|
||||
//
|
||||
|
||||
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
|
||||
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
|
||||
llama_memory_recurrent_context::llama_memory_recurrent_context(
|
||||
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
|
||||
}
|
||||
|
||||
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_recurrent * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
|
||||
llama_memory_recurrent_context::llama_memory_recurrent_context(
|
||||
llama_memory_recurrent * mem,
|
||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
|
||||
|
||||
llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
|
||||
llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
|
||||
|
||||
bool llama_kv_cache_recurrent_state::next() {
|
||||
bool llama_memory_recurrent_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
@@ -1064,54 +1065,48 @@ bool llama_kv_cache_recurrent_state::next() {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent_state::apply() {
|
||||
bool llama_memory_recurrent_context::apply() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
kv->find_slot(ubatches[i_next]);
|
||||
mem->find_slot(ubatches[i_next]);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return sbatch.out_ids;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
|
||||
llama_memory_status llama_memory_recurrent_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
|
||||
const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
|
||||
return is_full ? kv->size : kv->n;
|
||||
uint32_t llama_memory_recurrent_context::get_n_rs() const {
|
||||
return is_full ? mem->size : mem->n;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_recurrent_state::get_head() const {
|
||||
return is_full ? 0 : kv->head;
|
||||
uint32_t llama_memory_recurrent_context::get_head() const {
|
||||
return is_full ? 0 : mem->head;
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
|
||||
return is_full ? 0 : kv->rs_z;
|
||||
int32_t llama_memory_recurrent_context::get_rs_z() const {
|
||||
return is_full ? 0 : mem->rs_z;
|
||||
}
|
||||
|
||||
uint32_t llama_kv_cache_recurrent_state::get_size() const {
|
||||
return kv->size;
|
||||
uint32_t llama_memory_recurrent_context::get_size() const {
|
||||
return mem->size;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
|
||||
return kv->k_l[il];
|
||||
ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
|
||||
return mem->r_l[il];
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
|
||||
return kv->v_l[il];
|
||||
ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
|
||||
return mem->s_l[il];
|
||||
}
|
||||
|
||||
int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
|
||||
return kv->cells[i + kv->head].src0;
|
||||
int32_t llama_memory_recurrent_context::s_copy(int i) const {
|
||||
return mem->cells[i + mem->head].src0;
|
||||
}
|
||||
@@ -8,35 +8,40 @@
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent
|
||||
// llama_memory_recurrent
|
||||
//
|
||||
|
||||
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
|
||||
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
|
||||
class llama_kv_cache_recurrent : public llama_memory_i {
|
||||
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
|
||||
// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
|
||||
class llama_memory_recurrent : public llama_memory_i {
|
||||
public:
|
||||
llama_kv_cache_recurrent(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max);
|
||||
|
||||
~llama_kv_cache_recurrent() = default;
|
||||
// this callback is used to filter out layers that should not be included in the cache
|
||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||
|
||||
llama_memory_recurrent(
|
||||
const llama_model & model,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max);
|
||||
|
||||
~llama_memory_recurrent() = default;
|
||||
|
||||
//
|
||||
// llama_memory_i
|
||||
//
|
||||
|
||||
llama_memory_state_ptr init_batch(
|
||||
const llama_batch & batch,
|
||||
llama_memory_context_ptr init_batch(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_pooled) override;
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_state_ptr init_full() override;
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||
|
||||
void clear(bool data) override;
|
||||
|
||||
@@ -51,7 +56,7 @@ public:
|
||||
|
||||
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
// find a contiguous slot of kv cells and emplace the ubatch there
|
||||
// find a contiguous slot of memory cells and emplace the ubatch there
|
||||
bool find_slot(const llama_ubatch & ubatch);
|
||||
|
||||
bool get_can_shift() const override;
|
||||
@@ -72,7 +77,7 @@ public:
|
||||
int32_t rs_z = -1;
|
||||
|
||||
// TODO: optimize for recurrent state needs
|
||||
struct kv_cell {
|
||||
struct mem_cell {
|
||||
llama_pos pos = -1;
|
||||
int32_t src = -1; // used to know where states should be copied from
|
||||
int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
|
||||
@@ -88,15 +93,16 @@ public:
|
||||
return seq_id.empty();
|
||||
}
|
||||
|
||||
bool is_same_seq(const kv_cell & other) const {
|
||||
bool is_same_seq(const mem_cell & other) const {
|
||||
return seq_id == other.seq_id;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<kv_cell> cells;
|
||||
std::vector<mem_cell> cells;
|
||||
|
||||
std::vector<ggml_tensor *> k_l; // per layer
|
||||
std::vector<ggml_tensor *> v_l;
|
||||
// per layer
|
||||
std::vector<ggml_tensor *> r_l;
|
||||
std::vector<ggml_tensor *> s_l;
|
||||
|
||||
private:
|
||||
//const llama_model & model;
|
||||
@@ -109,8 +115,8 @@ private:
|
||||
|
||||
size_t total_size() const;
|
||||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
size_t size_r_bytes() const;
|
||||
size_t size_s_bytes() const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
@@ -119,57 +125,50 @@ private:
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
|
||||
class llama_memory_recurrent_context : public llama_memory_context_i {
|
||||
public:
|
||||
// used for errors
|
||||
llama_kv_cache_recurrent_state(llama_memory_status status);
|
||||
llama_memory_recurrent_context(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_recurrent_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_recurrent * kv);
|
||||
// used to create a full-cache or update context
|
||||
llama_memory_recurrent_context(
|
||||
llama_memory_recurrent * mem);
|
||||
|
||||
// used to create a state from a batch
|
||||
llama_kv_cache_recurrent_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_recurrent * kv,
|
||||
llama_sbatch sbatch,
|
||||
// used to create a batch processing context from a batch
|
||||
llama_memory_recurrent_context(
|
||||
llama_memory_recurrent * mem,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_recurrent_state();
|
||||
virtual ~llama_memory_recurrent_context();
|
||||
|
||||
//
|
||||
// llama_memory_state_i
|
||||
// llama_memory_context_i
|
||||
//
|
||||
|
||||
bool next() override;
|
||||
bool apply() override;
|
||||
|
||||
std::vector<int64_t> & out_ids() override;
|
||||
|
||||
llama_memory_status get_status() const override;
|
||||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent_state specific API
|
||||
// llama_memory_recurrent_context specific API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
uint32_t get_n_rs() const;
|
||||
uint32_t get_head() const;
|
||||
int32_t get_rs_z() const;
|
||||
uint32_t get_size() const;
|
||||
|
||||
ggml_tensor * get_k_l(int32_t il) const;
|
||||
ggml_tensor * get_v_l(int32_t il) const;
|
||||
ggml_tensor * get_r_l(int32_t il) const;
|
||||
ggml_tensor * get_s_l(int32_t il) const;
|
||||
|
||||
int32_t s_copy(int i) const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
|
||||
llama_kv_cache_recurrent * kv;
|
||||
|
||||
llama_sbatch sbatch;
|
||||
llama_memory_recurrent * mem;
|
||||
|
||||
size_t i_next = 0;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user