Compare commits

...

17 Commits
b7775 ... b7792

Author SHA1 Message Date
Jeff Bolz
33f890e579 vulkan: support flash attention GQA/split_k with small batches (#18938) 2026-01-21 17:43:43 +01:00
Masato Nakasaka
067b8d7af3 Revert "vulkan: force full subgroups for flash attention to fix intel subgroup crash (#17356)" (#18831)
This reverts commit 980b7cd17e.
2026-01-21 17:13:43 +01:00
Jeff Bolz
50b7f076a5 vulkan: Use mul_mat_vec_id for small values of n (#18918)
Change ggml_vk_mul_mat_vec_id_q_f16 to loop over the batch dimension and
update the indexing calculations in get_offsets.

Mat-vec is faster than mat-mat for small values of n. We don't get the same
reuse of the weights as in the non-ID path, but with this the cost is linear
in n rather than n>1 being far slower than n==1.
2026-01-21 16:22:02 +01:00
Tarek Dakhran
ad8d85bd94 memory : add llama_memory_hybrid_iswa (#18601)
* memory : add llama_memory_hybrid_iswa

* Update src/llama-memory-hybrid-iswa.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-01-21 14:30:23 +02:00
Piotr Wilkin (ilintar)
12a4a47e6a Fix GLM 4.7 Lite MoE gating func (#18980)
* Fix GLM 4.7 MoE gating func

* Update src/models/deepseek2.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-01-21 12:35:20 +01:00
Matthieu Coudron
37c35f0e1c gguf: display strerrno when cant load a model (#18884)
I've had issues loading models with llama-server:
[44039] E gguf_init_from_file: failed to open GGUF file 'mistral-7b-v0.1.Q8_0.gguf'

and I was sure it could access the file. Seems like --models-dir and
--models-presets dont interact like I thought they would but I salvaged
this snippet that helps troubleshooting
[44039] E gguf_init_from_file: failed to open GGUF file 'mistral-7b-v0.1.Q8_0.gguf' (errno No such file or directory)
2026-01-21 08:52:46 +02:00
Oliver Simons
5bd341c9a1 CUDA: Fix builds for older CCCL versions by ifdefing strided_iterator (#18964)
* CUDA: Fix builds for older CCCL versions by ifdefing strided_iterator

Strided iterator was added in [CCCL
3.1](https://github.com/NVIDIA/cccl/releases/tag/v3.1.0), which is packaged into
[CTK
13.1](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id5)

* Unindent as per code review request
2026-01-21 02:34:29 +01:00
Adrien Gallouët
1c7cf94b22 common, server : use the same User-Agent by default (#18957)
This commit also ensures that if a custom User-Agent is used, it will be
the only one sent.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-01-20 18:28:43 +01:00
Xuan-Son Nguyen
2c1f199653 cli : fix reasoning responses in CLI (#18961)
* cli : fix reasoning responses in CLI

* fix build

* fix build (2)
2026-01-20 18:23:25 +01:00
Oliver Simons
d1e3556481 CUDA: Replace init_offsets kernel with iterators in cub-based argsort (#18930)
* CUDA: Replace `init_offsets` with iterators in argsort

This is a QOL improvement, saving us the cost of materializing the
iterator

* Remove unnecessary include from top-k.cu
2026-01-20 20:11:01 +08:00
Adrien Gallouët
08f3f4a8a3 ggml : cleanup path_str() (#18928)
- Remove pragmas as `std::codecvt_utf8` is not used.
- Avoid implicit `strlen()`.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-01-20 11:42:49 +01:00
Georgi Gerganov
271191906c metal : enable FA for MLA heads (#18950) 2026-01-20 12:21:28 +02:00
Daniel Bevenius
7dee9ff59a convert : use n_groups instead of hardcoded values in reshape (#18929)
* convert : use n_groups instead of hardcoded values in reshape

This commit modifies the conversion script for NemotronHModel to use
the 'n_groups' hyperparameter, and allow Python to calculate the the
last dimension, using -1, when reshaping the 'mixer.norm.weight' tensor.

* use self.n_group instead of self.hparams["n_groups"]
2026-01-20 06:55:24 +01:00
Xuan-Son Nguyen
6df686bee6 server : refactor oai_parser_opt, move it to server_chat_params (#18937)
* server_chat_params

* move chat format into CLI

* use meta whenever possible

* clean up, no more chatml fallback
2026-01-19 23:28:01 +01:00
ddh0
1706a6d7c6 convert : support Glm4MoeLite (#18936)
* initial commit for branch

* add glm-4.7-flash, move tokenizer hash

* use `glm4` pretok

* silence flake8 E302 (CI)

* apply review feedback

* add <|user|> as eog

* also add EOG `<|observation|>`

* revert llama-vocab

* inherit vocab from glm4

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-01-19 23:09:20 +01:00
Sigbjørn Skjæret
959ecf7f23 jinja : fix undefined keys and attributes and int/float as bool (#18924)
* fix undefined keys and attributes

* add falsy tests

* as_bool for integers and floats

* more falsy/truthy tests

* --typo
2026-01-19 20:29:43 +01:00
Sigbjørn Skjæret
4037093c66 ci : run test-jinja -py on high perf [no ci] (#18916) 2026-01-19 20:29:15 +01:00
45 changed files with 1425 additions and 658 deletions

View File

@@ -254,7 +254,7 @@ function gg_run_ctest_release {
(time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log
if [ -z ${GG_BUILD_LOW_PERF} ]; then
(time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log
(time ctest --output-on-failure -L 'main|python' ) 2>&1 | tee -a $OUT/${ci}-ctest.log
else
(time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
fi

View File

@@ -129,7 +129,7 @@ static void parse_json_tool_calls(
}
}
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
result_.role = "assistant";
@@ -1611,7 +1611,7 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
builder.finish();
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
@@ -1635,7 +1635,7 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
return msg;
}
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
if (parser.empty()) {
throw std::runtime_error("Failed to parse due to missing parser definition.");
}

View File

@@ -5,7 +5,7 @@
#include "json-partial.h"
#include "regex-partial.h"
#include <nlohmann/json.hpp>
#include <nlohmann/json_fwd.hpp>
#include <optional>
#include <string>
@@ -19,20 +19,20 @@ class common_chat_msg_partial_exception : public std::runtime_error {
class common_chat_msg_parser {
std::string input_;
bool is_partial_;
common_chat_syntax syntax_;
common_chat_parser_params syntax_; // TODO: rename to params
std::string healing_marker_;
size_t pos_ = 0;
common_chat_msg result_;
public:
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
const std::string & input() const { return input_; }
size_t pos() const { return pos_; }
const std::string & healing_marker() const { return healing_marker_; }
const bool & is_partial() const { return is_partial_; }
const common_chat_msg & result() const { return result_; }
const common_chat_syntax & syntax() const { return syntax_; }
const common_chat_parser_params & syntax() const { return syntax_; }
void move_to(size_t pos) {
if (pos > input_.size()) {

View File

@@ -601,18 +601,18 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp
return tmpls->has_explicit_template;
}
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
if (variant != nullptr) {
if (strcmp(variant, "tool_use") == 0) {
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) {
if (!variant.empty()) {
if (variant == "tool_use") {
if (tmpls->template_tool_use) {
return tmpls->template_tool_use->source().c_str();
return tmpls->template_tool_use->source();
}
return nullptr;
return "";
} else {
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str());
}
}
return tmpls->template_default->source().c_str();
return tmpls->template_default->source();
}
common_chat_templates_ptr common_chat_templates_init(

View File

@@ -145,7 +145,7 @@ struct common_chat_templates_inputs {
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
bool enable_thinking = true;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
std::map<std::string, std::string> chat_template_kwargs;
@@ -165,14 +165,21 @@ struct common_chat_params {
std::string parser;
};
struct common_chat_syntax {
// per-message parsing syntax
// should be derived from common_chat_params
struct common_chat_parser_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
bool parse_tool_calls = true;
common_peg_arena parser = {};
common_chat_parser_params() = default;
common_chat_parser_params(const common_chat_params & chat_params) {
format = chat_params.format;
thinking_forced_open = chat_params.thinking_forced_open;
}
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
@@ -191,7 +198,7 @@ common_chat_templates_ptr common_chat_templates_init(
const std::string & eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
struct common_chat_params common_chat_templates_apply(
@@ -213,10 +220,12 @@ std::string common_chat_format_example(
const std::map<std::string, std::string> & chat_template_kwargs);
const char* common_chat_format_name(common_chat_format format);
const char* common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
// used by arg and server
const char * common_reasoning_format_name(common_reasoning_format format);
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);

View File

@@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT;
extern const char * LLAMA_COMPILER;
extern const char * LLAMA_BUILD_TARGET;
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
struct common_control_vector_load_info;
//
@@ -284,6 +286,7 @@ struct common_params_diffusion {
};
// reasoning API response format (not to be confused as chat template's reasoning format)
// only used by server
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`

View File

@@ -314,23 +314,26 @@ static bool common_pull_file(httplib::Client & cli,
// download one single file from remote URL to local path
// returns status code or -1 on error
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers) {
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
auto [cli, parts] = common_http_client(url);
httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
if (!bearer_token.empty()) {
default_headers.insert({"Authorization", "Bearer " + bearer_token});
}
httplib::Headers headers;
for (const auto & h : custom_headers) {
default_headers.emplace(h.first, h.second);
headers.emplace(h.first, h.second);
}
cli.set_default_headers(default_headers);
if (headers.find("User-Agent") == headers.end()) {
headers.emplace("User-Agent", "llama-cpp/" + build_info);
}
if (!bearer_token.empty()) {
headers.emplace("Authorization", "Bearer " + bearer_token);
}
cli.set_default_headers(headers);
const bool file_exists = std::filesystem::exists(path);
@@ -437,10 +440,12 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
const common_remote_params & params) {
auto [cli, parts] = common_http_client(url);
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
for (const auto & header : params.headers) {
headers.emplace(header.first, header.second);
httplib::Headers headers;
for (const auto & h : params.headers) {
headers.emplace(h.first, h.second);
}
if (headers.find("User-Agent") == headers.end()) {
headers.emplace("User-Agent", "llama-cpp/" + build_info);
}
if (params.timeout > 0) {

View File

@@ -805,7 +805,7 @@ value member_expression::execute_impl(context & ctx) {
} else if (is_val<value_string>(property)) {
auto key = property->as_string().str();
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
val = try_builtin_func(ctx, key, object);
val = try_builtin_func(ctx, key, object, true);
} else {
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
}
@@ -814,7 +814,7 @@ value member_expression::execute_impl(context & ctx) {
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
}
auto key = property->as_string().str();
val = try_builtin_func(ctx, key, object);
val = try_builtin_func(ctx, key, object, true);
}
if (ctx.is_get_stats && val && object && property) {

View File

@@ -203,6 +203,9 @@ struct value_int_t : public value_t {
virtual int64_t as_int() const override { return val_int; }
virtual double as_float() const override { return static_cast<double>(val_int); }
virtual string as_string() const override { return std::to_string(val_int); }
virtual bool as_bool() const override {
return val_int != 0;
}
virtual const func_builtins & get_builtins() const override;
};
using value_int = std::shared_ptr<value_int_t>;
@@ -219,6 +222,9 @@ struct value_float_t : public value_t {
if (out.back() == '.') out.push_back('0'); // leave one zero if no decimals
return out;
}
virtual bool as_bool() const override {
return val_flt != 0.0;
}
virtual const func_builtins & get_builtins() const override;
};
using value_float = std::shared_ptr<value_float_t>;

View File

@@ -1,5 +1,6 @@
#pragma once
// TODO: use json_fwd.hpp when possible
#include <nlohmann/json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).

View File

@@ -1078,6 +1078,9 @@ class TextModel(ModelBase):
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
# ref: https://huggingface.co/aari1995/German_Semantic_V3
res = "jina-v2-de"
if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267":
# ref: https://huggingface.co/zai-org/GLM-4.7-Flash
res = "glm4"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@@ -7458,7 +7461,7 @@ class DeepseekModel(TextModel):
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"YoutuForCausalLM",
"YoutuVLForConditionalGeneration"
"YoutuVLForConditionalGeneration",
)
class DeepseekV2Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -8446,6 +8449,32 @@ class Glm4MoeModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("Glm4MoeLiteForCausalLM")
class Glm4MoeLiteModel(DeepseekV2Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
# copied from Glm4MoeModel
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
# Special tokens
# Note: Using <|endoftext|> (151329) for eot causes endless generation
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
special_vocab.add_to_gguf(self.gguf_writer)
@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(TextModel):
model_arch = gguf.MODEL_ARCH.CHATGLM
@@ -9183,7 +9212,7 @@ class NemotronHModel(GraniteHybridModel):
return [(mapped_name, reshaped_data)]
if name.endswith("mixer.norm.weight"):
reshaped_data = data_torch.reshape(8, 512)
reshaped_data = data_torch.reshape(self.n_group, -1)
mapped_name = self.map_tensor_name(name)
return [(mapped_name, reshaped_data)]

View File

@@ -170,6 +170,7 @@ pre_computed_hashes = [
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
# jina-v2-de variants
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
]

View File

@@ -77,39 +77,23 @@
#include "ggml-zendnn.h"
#endif
// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
namespace fs = std::filesystem;
static std::string path_str(const fs::path & path) {
std::string u8path;
try {
#if defined(__cpp_lib_char8_t)
// C++20 and later: u8string() returns std::u8string
std::u8string u8str = path.u8string();
u8path = std::string(reinterpret_cast<const char*>(u8str.c_str()));
const std::u8string u8str = path.u8string();
return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
#else
// C++17: u8string() returns std::string
u8path = path.u8string();
return path.u8string();
#endif
} catch (...) {
return std::string();
}
return u8path;
}
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;

View File

@@ -2,6 +2,9 @@
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
# define STRIDED_ITERATOR_AVAILABLE
# endif
using namespace cub;
#endif // GGML_CUDA_USE_CUB
@@ -14,12 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
}
}
#ifndef STRIDED_ITERATOR_AVAILABLE
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx <= nrows) {
offsets[idx] = idx * ncols;
}
}
#endif // STRIDED_ITERATOR_AVAILABLE
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
@@ -31,19 +36,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * temp_indices = temp_indices_alloc.get();
float * temp_keys = temp_keys_alloc.get();
int * d_offsets = offsets_alloc.get();
static const int block_size = 256;
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
#ifdef STRIDED_ITERATOR_AVAILABLE
auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
#else
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
int * offset_iterator = offsets_alloc.get();
const dim3 offset_grid((nrows + block_size - 1) / block_size);
init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
#endif
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
size_t temp_storage_bytes = 0;
@@ -57,7 +65,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, stream);
offset_iterator, offset_iterator + 1, stream);
}
} else {
if (nrows == 1) {
@@ -66,7 +74,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
stream);
}
}
@@ -80,7 +89,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
}
} else {
if (nrows == 1) {
@@ -89,8 +98,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
stream);
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
offset_iterator + 1, stream);
}
}
}

View File

@@ -4,7 +4,6 @@
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
# include <cuda/iterator>
# define CUB_TOP_K_AVAILABLE
using namespace cub;
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2

View File

@@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek sizes
// TODO: disabled for now, until optmized
op->src[0]->ne[0] != 256 &&
op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {

View File

@@ -2520,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
int32_t nsg = 4;
int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);

View File

@@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(
constexpr short NC = (C/8)/NSG;
// note: do not unroll for large heads
#pragma unroll (DK <= 64 ? NC : 1)
for (short cc = 0; cc < NC; ++cc) {
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
if (DK % 16 != 0) {
@@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
k8x8_t mk[2];
q8x8_t mq[2];
FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
// note: too much unroll can tank the performance for large heads
#pragma unroll (MIN(DK8/2, 4*NSG))
for (short i = 0; i < DK8/2; ++i) {
simdgroup_barrier(mem_flags::mem_none);
simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
pv += 8*NS20;
}
} else {
FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
constexpr short NC = (C/8)/2;
FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
s8x8_t vs[2];
simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS

View File

@@ -991,6 +991,8 @@ struct vk_mat_vec_id_push_constants {
uint32_t fusion_flags;
uint32_t nei0;
uint32_t ne11;
uint32_t expert_i1;
uint32_t nbi1;
};
struct vk_flash_attn_push_constants {
@@ -1516,6 +1518,15 @@ struct vk_quantize_q8_1_push_constants {
uint32_t num_blocks;
};
struct vk_op_flash_attn_split_k_reduce_push_constants {
uint32_t D;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t k_num;
uint32_t sinks;
};
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -3178,15 +3189,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \
} \
} \
@@ -3980,7 +3991,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
if (device->subgroup_clustered && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
@@ -8083,8 +8094,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1];
GGML_ASSERT(nei1 == 1);
const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
@@ -8168,7 +8178,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -8226,7 +8236,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
uint32_t stride_batch_y = ne10*ne11;
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
}
const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
@@ -8262,23 +8272,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
}
// compute
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
fusion_flags,
(uint32_t)nei0, (uint32_t)ne11,
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
d_ids,
},
pc, { groups_x, (uint32_t)nei0, groups_z });
// Loop over the batch dimension
for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
const vk_mat_vec_id_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
(uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
fusion_flags,
(uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
};
ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
{
d_X,
d_Y,
d_D,
d_F0,
d_F1,
d_ids,
},
pc, { groups_x, (uint32_t)nei0, groups_z });
}
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
@@ -8292,7 +8304,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
ggml_tensor * dst = cgraph->nodes[node_idx];
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src2 = dst->src[2];
return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
}
static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -8454,14 +8466,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_ASSERT(0);
}
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
gqa_ratio = qk_ratio;
N = gqa_ratio;
workgroups_y /= N;
workgroups_y /= gqa_ratio;
}
bool small_rows = N <= get_fa_num_small_rows(path);
@@ -8523,6 +8535,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
assert(pipeline);
// Compile early to initialize wg_denoms.
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
uint32_t split_kv = KV;
uint32_t split_k = 1;
@@ -8530,22 +8544,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
// Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0) {
// Try to use split_k when KV is large enough to be worth the overhead.
// Must either be a single batch or be using gqa, we can't mix the two.
if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
// Try to run two workgroups per SM.
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
split_k = CEIL_DIV(KV, split_kv);
workgroups_x = split_k;
}
}
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
// For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
// For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
GGML_ABORT("Requested preallocation size is too large");
}
@@ -8556,7 +8572,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
{
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
}
@@ -8605,7 +8620,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
workgroups_x *= pipeline->wg_denoms[0];
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
@@ -8613,15 +8628,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// there's no more than one tile of rows (i.e. workgroups_x would have been
// one). We reuse workgroups_x to mean the number of splits, so we need to
// cancel out the divide by wg_denoms[0].
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{split_k_buf, sinks_buf, dst_buf},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
ctx->prealloc_split_k_need_sync = true;
} else {
if (gqa_ratio > 1) {
// When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
workgroups_x *= pipeline->wg_denoms[0];
}
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
pc, { workgroups_x, workgroups_y, workgroups_z });

View File

@@ -53,7 +53,7 @@ void main() {
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
@@ -101,9 +101,9 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
@@ -320,7 +320,8 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
@@ -332,7 +333,7 @@ void main() {
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -378,7 +379,7 @@ void main() {
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {

View File

@@ -165,7 +165,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
void init_indices()
@@ -173,12 +173,19 @@ void init_indices()
N = p.N;
KV = p.KV;
i = gl_WorkGroupID.x;
split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
// batch and split_k share gl_WorkGroupID.x
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
split_k_index = gl_WorkGroupID.x % p.k_num;
} else if (p.gqa_ratio > 1) {
i = 0;
gqa_iq1 = gl_WorkGroupID.x;
split_k_index = 0;
} else {
i = gl_WorkGroupID.x;
gqa_iq1 = 0;
split_k_index = 0;
}
Tr = CEIL_DIV(N, Br);

View File

@@ -90,7 +90,7 @@ void main() {
barrier();
}
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
@@ -141,9 +141,9 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
@@ -370,7 +370,8 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
@@ -382,7 +383,7 @@ void main() {
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
@@ -428,7 +429,7 @@ void main() {
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {

View File

@@ -111,7 +111,7 @@ void main() {
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
@@ -138,9 +138,9 @@ void main() {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
uint32_t m_offset = 0;
uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
@@ -272,10 +272,11 @@ void main() {
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
// note: O and Q have swapped coord 1,2.
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
@@ -325,7 +326,7 @@ void main() {
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {

View File

@@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne1;
uint ne2;
uint ne3;
uint k_num;
uint sinks;
@@ -24,15 +25,15 @@ void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint iq3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.z % p.ne2;
const uint i3 = gl_WorkGroupID.z / p.ne2;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2;
uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n;
uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n;
uint lm_stride = p.ne1 * 2;
// Compute the max m value for the row
float m_max = -1.0/0.0;
@@ -99,7 +100,7 @@ void main() {
if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
@@ -115,6 +116,6 @@ void main() {
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
O = clamp(O, -FLT_MAX, FLT_MAX);
data_d[iq3 * D * N + D * n + d] = O;
data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;
}
}

View File

@@ -29,6 +29,8 @@ layout (push_constant) uniform parameter
#ifdef MUL_MAT_ID
uint nei0;
uint ne11;
uint expert_i1;
uint nbi1;
#else
uint ne02;
uint ne12;
@@ -43,7 +45,7 @@ uint expert_id;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.y;
const uint expert_i0 = gl_GlobalInvocationID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
#endif
@@ -60,7 +62,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
batch_idx_a = i03 * p.ne02 + i02;
}
#else
expert_id = data_ids[expert_idx];
expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
#endif
a_offset =
@@ -71,13 +73,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#endif
b_offset =
#ifdef MUL_MAT_ID
(expert_idx % p.ne11) * p.stride_b;
(expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
#else
batch_idx * p.batch_stride_b;
#endif
d_offset =
#ifdef MUL_MAT_ID
expert_idx * p.stride_d;
expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
#else
batch_idx * p.batch_stride_d;
#endif
@@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
}
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
const uint expert_idx = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
const uint expert_i0 = gl_GlobalInvocationID.y;
tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
}
#else
if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {

View File

@@ -734,7 +734,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
FILE * file = ggml_fopen(fname, "rb");
if (!file) {
GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname);
GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno));
return nullptr;
}

View File

@@ -24,6 +24,7 @@ add_library(llama
llama-kv-cache-iswa.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-hybrid-iswa.cpp
llama-memory-recurrent.cpp
llama-mmap.cpp
llama-model-loader.cpp

View File

@@ -7,6 +7,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
#include <cassert>
@@ -510,6 +511,76 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
return res;
}
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
const auto * attn_ctx = mctx->get_attn();
// base tensors may not be allocated if there are no non-SWA attention layers
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
}
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
int32_t * data = (int32_t *) inp_rs->s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_rs; ++i) {
data[i] = mctx->get_recr()->s_copy(i);
}
}
}
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
this->mctx = mctx;
bool res = true;
const auto * attn_ctx = mctx->get_attn();
// base tensors may not be allocated if there are no non-SWA attention layers
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
}
// swa tensors may not be allocated if there are no SWA attention layers
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
}
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
res &= inp_rs->head == mctx->get_recr()->get_head();
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
return res;
}
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
// set the inputs only for the active samplers in the current ubatch
std::unordered_set<llama_seq_id> active_samplers;
@@ -2056,6 +2127,47 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
// build iswa attention input
const auto * attn_ctx = mctx_cur->get_attn();
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
{
const auto n_kv = attn_ctx->get_base()->get_n_kv();
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp_attn->self_kq_mask);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
{
const auto n_kv = attn_ctx->get_swa()->get_n_kv();
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp_attn->self_kq_mask_swa);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}
auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
}
void llm_graph_context::build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const {

View File

@@ -24,6 +24,7 @@ class llama_kv_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
class llama_memory_hybrid_iswa_context;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@@ -397,6 +398,34 @@ public:
const llama_memory_hybrid_context * mctx;
};
class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid_iswa(
const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_iswa_context * mctx) :
inp_attn(std::move(inp_attn)),
inp_rs(std::move(inp_rs)),
cparams(cparams),
mctx(mctx) { }
virtual ~llm_graph_input_mem_hybrid_iswa() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
const llama_cparams cparams;
const llama_memory_hybrid_iswa_context * mctx;
};
class llm_graph_input_sampling : public llm_graph_input_i {
public:
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
@@ -881,6 +910,8 @@ struct llm_graph_context {
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
//
// pooling
//

View File

@@ -0,0 +1,275 @@
#include "llama-memory-hybrid-iswa.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-context.h"
//
// llama_memory_hybrid_iswa
//
llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool swa_full,
uint32_t kv_size,
uint32_t n_ubatch,
uint32_t n_pad,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn,
const layer_filter_cb & filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache_iswa(
model,
type_k,
type_v,
v_trans,
offload,
swa_full,
unified,
kv_size,
n_seq_max,
n_ubatch,
n_pad,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)),
mem_recr(new llama_memory_recurrent(
model,
type_r,
type_s,
offload,
rs_size,
n_seq_max,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
)) {}
llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do {
balloc.split_reset();
// follow the recurrent pattern for creating the ubatch splits
std::vector<llama_ubatch> ubatches;
while (true) {
llama_ubatch ubatch;
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
}
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined context at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
// prepare the attention cache (iswa version returns both base and swa slot infos)
auto sinfos_base = mem_attn->get_base()->prepare(ubatches);
if (sinfos_base.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches);
if (sinfos_swa.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_memory_hybrid_iswa_context>(
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
} while(false);
return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() {
return std::make_unique<llama_memory_hybrid_iswa_context>(this);
}
llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize);
}
bool llama_memory_hybrid_iswa::get_can_shift() const {
// Shifting is trivially supported for recurrent
return mem_attn->get_can_shift();
}
void llama_memory_hybrid_iswa::clear(bool data) {
mem_attn->clear(data);
mem_recr->clear(data);
}
bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
// Try removing from the recurrent cache first since it may fail. If it does
// fail, the cache will not have been mutated.
if (!mem_recr->seq_rm(seq_id, p0, p1)) {
return false;
}
return mem_attn->seq_rm(seq_id, p0, p1);
}
void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) {
mem_attn->seq_keep(seq_id);
mem_recr->seq_keep(seq_id);
}
void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
mem_attn->seq_add(seq_id, p0, p1, shift);
mem_recr->seq_add(seq_id, p0, p1, shift);
}
void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
mem_attn->seq_div(seq_id, p0, p1, d);
mem_recr->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const {
// the min of the total cache is the max of the two caches' min values
return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
}
llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const {
// the max of the total cache is the min of the two caches' max values
return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
for (const auto & buft_size : mem_recr->memory_breakdown()) {
mb[buft_size.first] += buft_size.second;
}
return mb;
}
void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
mem_attn->state_write(io, seq_id, flags);
mem_recr->state_write(io, seq_id, flags);
}
void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
mem_attn->state_read(io, seq_id, flags);
mem_recr->state_read(io, seq_id, flags);
}
llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const {
return mem_attn.get();
}
llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const {
return mem_recr.get();
}
//
// llama_memory_hybrid_iswa_context
//
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {}
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) :
ctx_attn(mem->get_mem_attn()->init_full()),
ctx_recr(mem->get_mem_recr()->init_full()),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
llama_memory_hybrid_iswa * mem,
llama_context * lctx,
bool optimize) :
ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
llama_memory_hybrid_iswa * mem,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)),
ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
bool llama_memory_hybrid_iswa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
ctx_attn->next();
ctx_recr->next();
if (++i_next >= ubatches.size()) {
return false;
}
return true;
}
bool llama_memory_hybrid_iswa_context::apply() {
assert(!llama_memory_status_is_fail(status));
bool res = true;
res = res & ctx_attn->apply();
res = res & ctx_recr->apply();
return res;
}
llama_memory_status llama_memory_hybrid_iswa_context::get_status() const {
return status;
}
const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const {
return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());
}
const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const {
return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
}

View File

@@ -0,0 +1,140 @@
#pragma once
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"
#include <memory>
#include <vector>
//
// llama_memory_hybrid_iswa
//
// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to
// support models where each layer may be either attention-based (with SWA support) or recurrent
class llama_memory_hybrid_iswa : public llama_memory_i {
public:
llama_memory_hybrid_iswa(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool swa_full,
uint32_t kv_size,
uint32_t n_ubatch,
uint32_t n_pad,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn = nullptr,
const layer_filter_cb & filter_recr = nullptr);
~llama_memory_hybrid_iswa() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_memory_hybrid_iswa specific API
//
llama_kv_cache_iswa * get_mem_attn() const;
llama_memory_recurrent * get_mem_recr() const;
private:
const llama_hparams & hparams;
const std::unique_ptr<llama_kv_cache_iswa> mem_attn;
const std::unique_ptr<llama_memory_recurrent> mem_recr;
};
class llama_memory_hybrid_iswa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// init failure
explicit llama_memory_hybrid_iswa_context(llama_memory_status status);
// init full
explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem);
// init update
explicit llama_memory_hybrid_iswa_context(
llama_memory_hybrid_iswa * mem,
llama_context * lctx,
bool optimize);
// init success
llama_memory_hybrid_iswa_context(
llama_memory_hybrid_iswa * mem,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_swa,
std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_iswa_context() = default;
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_memory_hybrid_iswa_context
//
const llama_kv_cache_iswa_context * get_attn() const;
const llama_memory_recurrent_context * get_recr() const;
private:
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
const llama_memory_context_ptr ctx_attn;
const llama_memory_context_ptr ctx_recr;
const llama_memory_status status;
};

View File

@@ -8,6 +8,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
#include "ggml-cpp.h"
@@ -1713,7 +1714,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
// for compatibility with existing DeepSeek V2 and V2.5 GGUFs
// that have no expert_gating_func model parameter set
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) {
// GLM 4.7 Lite
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
} else {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
}
}
if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
@@ -7523,23 +7529,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
};
}
res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ 1,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
// Use hybrid-iswa for hybrid models with SWA
res = new llama_memory_hybrid_iswa(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_swa_full */ params.swa_full,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_ubatch */ cparams.n_ubatch,
/* attn_n_pad */ 1,
/* recurrent_type_r */ GGML_TYPE_F32,
/* recurrent_type_s */ GGML_TYPE_F32,
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ 1,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
}
} else {
llama_memory_i::layer_reuse_cb reuse = nullptr;

View File

@@ -187,6 +187,7 @@ llama_build_and_test(test-chat-parser.cpp)
llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp)
llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-jinja.cpp)
llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(

View File

@@ -8460,6 +8460,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
// Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) {

View File

@@ -54,113 +54,109 @@ static void assert_throws(const std::function<void()> & fn, const std::string &
static void test_reasoning() {
//common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG);
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
});
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = true;
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ true,
/* .thinking_forced_open = */ true,
});
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = true;
params.thinking_forced_open = true;
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, params);
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<think>Cogito</think>", builder.result().content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
const std::string variant("content_only_inline_think");
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
/* .parse_tool_calls = */ false,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
params.parse_tool_calls = false;
const std::string input = "<think>Pense</think>Bonjour";
auto msg = common_chat_parse(input, false, syntax);
auto msg = common_chat_parse(input, false, params);
assert_equals(variant, std::string("Pense"), msg.reasoning_content);
assert_equals(variant, std::string("Bonjour"), msg.content);
}
{
const std::string variant("llama_3_inline_think");
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
/* .parse_tool_calls = */ false,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_LLAMA_3_X;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
params.parse_tool_calls = false;
const std::string input = "<think>Plan</think>Réponse";
auto msg = common_chat_parse(input, false, syntax);
auto msg = common_chat_parse(input, false, params);
assert_equals(variant, std::string("Plan"), msg.reasoning_content);
assert_equals(variant, std::string("Réponse"), msg.content);
}
// Test DeepSeek V3.1 parsing - reasoning content followed by "</think>" and then regular content
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = true;
params.parse_tool_calls = true;
const std::string variant("deepseek_v3_1_reasoning_format_deepseek");
common_chat_msg_parser builder("REASONING</think>ok", /* is_partial= */ false, syntax);
common_chat_msg_parser builder("REASONING</think>ok", /* is_partial= */ false, params);
assert_equals(variant, true, builder.try_parse_reasoning("<think>", "</think>"));
assert_equals(variant, std::string("REASONING"), builder.result().reasoning_content);
assert_equals(variant, std::string("ok"), builder.consume_rest());
}
// Test DeepSeek V3.1 parsing - reasoning_format none - reasoning content followed by "</think>" and then regular content
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_NONE;
params.reasoning_in_content = false;
params.thinking_forced_open = true;
params.parse_tool_calls = true;
const std::string variant("deepseek_v3_1_reasoning_format_none");
const std::string input = "REASONING</think>ok";
auto msg = common_chat_parse(input, false, syntax);
auto msg = common_chat_parse(input, false, params);
assert_equals(variant, std::string("REASONING</think>ok"), msg.content);
assert_equals(variant, std::string(""), msg.reasoning_content);
}
@@ -256,15 +252,14 @@ static void test_deepseek_v3_1_tool_calls() {
//common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG);
// variant: happy path for when it works as the model card says it should
const std::string variant("simple");
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
params.parse_tool_calls = true;
const std::string input = "<tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": \"Tokyo\"}<tool▁call▁end><tool▁calls▁end>";
auto msg = common_chat_parse(input, false, syntax);
auto msg = common_chat_parse(input, false, params);
assert_equals<std::size_t>(variant, 1, msg.tool_calls.size());
assert_equals(variant, std::string("get_time"), msg.tool_calls[0].name);
// JSON arguments are dumped without spaces
@@ -274,16 +269,15 @@ static void test_deepseek_v3_1_tool_calls() {
// variant: simple + thinking open
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = true;
params.parse_tool_calls = true;
const std::string variant("simple_thinking");
const std::string in = "REASONING</think><tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": \"Tokyo\"}<tool▁call▁end><tool▁calls▁end>";
auto m = common_chat_parse(in, false, syntax);
auto m = common_chat_parse(in, false, params);
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments);
@@ -292,16 +286,15 @@ static void test_deepseek_v3_1_tool_calls() {
}
// variant: simple + multiple tool calls
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
params.parse_tool_calls = true;
const std::string variant("simple_multiple_tool_calls");
const std::string in = "CONTENT<tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": \"Paris\"}<tool▁call▁end><tool▁call▁begin>get_weather<tool▁sep>{\"city\": \"Paris\"}<tool▁call▁end><tool▁calls▁end>";
auto m = common_chat_parse(in, false, syntax);
auto m = common_chat_parse(in, false, params);
assert_equals<std::size_t>(variant, 2, m.tool_calls.size());
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[0].arguments);
@@ -314,16 +307,15 @@ static void test_deepseek_v3_1_tool_calls() {
// variant: thinking forced open + tool call in reasoning content
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = true;
params.parse_tool_calls = true;
const std::string variant("thinking_forced_open_tool_call_in_reasoning");
const std::string in = "REASONING<tool▁calls▁begin><tool▁call▁begin>get_time2<tool▁sep>{\"city\": \"Tokyo2\"}<tool▁call▁end><tool▁calls▁end>REASONING</think><tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": \"Tokyo\"}<tool▁call▁end><tool▁calls▁end>";
auto m = common_chat_parse(in, false, syntax);
auto m = common_chat_parse(in, false, params);
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
assert_equals(variant, std::string("get_time"), m.tool_calls[0].name);
assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments);
@@ -336,16 +328,15 @@ static void test_deepseek_v3_1_tool_calls() {
// to make tool calls in reasoning content according to the model card, but it does sometimes, so
// add the reasoning content as regular content and parse the tool calls.
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = true;
params.parse_tool_calls = true;
const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_not_partial");
const std::string in = "REASONING<tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": \"Tokyo\"}<tool▁call▁end><tool▁calls▁end>";
auto m = common_chat_parse(in, false, syntax);
auto m = common_chat_parse(in, false, params);
assert_equals(variant, std::string("REASONING"), m.content);
assert_equals(variant, std::string(""), m.reasoning_content);
assert_equals<std::size_t>(variant, 1, m.tool_calls.size());
@@ -355,16 +346,15 @@ static void test_deepseek_v3_1_tool_calls() {
// variant: thinking forced open + tool call in reasoning content + no closing think + partial
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = true;
params.parse_tool_calls = true;
const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_partial");
const std::string in = "REASONING<tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": \"Tokyo\"}<tool▁call▁end><tool▁calls▁end>";
auto m = common_chat_parse(in, /* is_partial= */ true, syntax);
auto m = common_chat_parse(in, /* is_partial= */ true, params);
assert_equals(variant, std::string("REASONING<tool▁calls▁begin><tool▁call▁begin>get_time<tool▁sep>{\"city\": \"Tokyo\"}<tool▁call▁end><tool▁calls▁end>"), m.reasoning_content);
assert_equals(variant, std::string(""), m.content);
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
@@ -372,32 +362,30 @@ static void test_deepseek_v3_1_tool_calls() {
// variant: thinking not forced open + reasoning + regular content + no tool calls
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = true;
params.parse_tool_calls = true;
const std::string variant("thinking_forced_open_reasoning_regular_content_no_tool_calls");
const std::string in = "REASONING</think>CONTENT";
auto m = common_chat_parse(in, false, syntax);
auto m = common_chat_parse(in, false, params);
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
assert_equals(variant, std::string("CONTENT"), m.content);
assert_equals(variant, std::string("REASONING"), m.reasoning_content);
}
// variant: thinking not forced open + missing reasoning + no tool calls
{
common_chat_syntax syntax = {
/* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
/* .parse_tool_calls = */ true,
};
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
params.parse_tool_calls = true;
const std::string variant("thinking_not_forced_open_missing_reasoning_no_tool_calls");
const std::string in = "CONTENT";
auto m = common_chat_parse(in, false, syntax);
auto m = common_chat_parse(in, false, params);
assert_equals<std::size_t>(variant, 0, m.tool_calls.size());
assert_equals(variant, std::string("CONTENT"), m.content);
assert_equals(variant, std::string(""), m.reasoning_content);

View File

@@ -616,15 +616,15 @@ void test_command7_parser_compare(testing & t) {
auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) {
// Original common_chat_combinator_parser taken from chat.cpp
common_chat_parser_params params;
params.format = COMMON_CHAT_FORMAT_GENERIC;
params.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
params.reasoning_in_content = false;
params.thinking_forced_open = false;
common_chat_msg_parser builder(
input,
/* .is_partial = */ need_more_input,
{
/* .format = */ COMMON_CHAT_FORMAT_GENERIC,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
}
params
);
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");

File diff suppressed because it is too large Load Diff

View File

@@ -191,6 +191,84 @@ static void test_conditionals(testing & t) {
json::object(),
"yes"
);
test_template(t, "is undefined falsy",
"{{ 'yes' if not y else 'no' }}",
json::object(),
"yes"
);
test_template(t, "is undefined attribute falsy",
"{{ 'yes' if not y.x else 'no' }}",
{{"y", true}},
"yes"
);
test_template(t, "is undefined key falsy",
"{{ 'yes' if not y['x'] else 'no' }}",
{{"y", {{}}}},
"yes"
);
test_template(t, "is empty array falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", json::array()}},
"yes"
);
test_template(t, "is empty object falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", json::object()}},
"yes"
);
test_template(t, "is empty string falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", ""}},
"yes"
);
test_template(t, "is 0 falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", 0}},
"yes"
);
test_template(t, "is 0.0 falsy",
"{{ 'yes' if not y else 'no' }}",
{{"y", 0.0}},
"yes"
);
test_template(t, "is non-empty array truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", json::array({""})}},
"yes"
);
test_template(t, "is non-empty object truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", {"x", false}}},
"yes"
);
test_template(t, "is non-empty string truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", "0"}},
"yes"
);
test_template(t, "is 1 truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", 1}},
"yes"
);
test_template(t, "is 1.0 truthy",
"{{ 'yes' if y else 'no' }}",
{{"y", 1.0}},
"yes"
);
}
static void test_loops(testing & t) {

View File

@@ -66,19 +66,25 @@ struct cli_context {
defaults.stream = true; // make sure we always use streaming mode
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
// defaults.return_progress = true; // TODO: show progress
defaults.oaicompat_chat_syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
}
std::string generate_completion(result_timings & out_timings) {
server_response_reader rd = ctx_server.get_response_reader();
auto chat_params = format_chat();
{
// TODO: reduce some copies here in the future
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
task.id = rd.get_new_id();
task.index = 0;
task.params = defaults; // copy
task.cli_input = messages; // copy
task.cli_files = input_files; // copy
task.id = rd.get_new_id();
task.index = 0;
task.params = defaults; // copy
task.cli_prompt = chat_params.prompt; // copy
task.cli_files = input_files; // copy
task.cli = true;
// chat template settings
task.params.chat_parser_params = common_chat_parser_params(chat_params);
task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
rd.post_task({std::move(task)});
}
@@ -156,6 +162,25 @@ struct cli_context {
return content;
}
}
common_chat_params format_chat() {
auto meta = ctx_server.get_meta();
auto & chat_params = meta.chat_params;
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = {}; // TODO
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
inputs.json_schema = ""; // TODO
inputs.grammar = ""; // TODO
inputs.use_jinja = chat_params.use_jinja;
inputs.parallel_tool_calls = false;
inputs.add_generation_prompt = true;
inputs.enable_thinking = chat_params.enable_thinking;
// Apply chat template to the list of messages
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
}
};
int main(int argc, char ** argv) {

View File

@@ -779,7 +779,6 @@ static void handle_media(
// download remote image
// TODO @ngxson : maybe make these params configurable
common_remote_params params;
params.headers.push_back({"User-Agent", "llama.cpp/" + build_info});
params.max_size = 1024 * 1024 * 10; // 10MB
params.timeout = 10; // seconds
SRV_INF("downloading image from '%s'\n", url.c_str());
@@ -831,7 +830,7 @@ static void handle_media(
// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
json & body, /* openai api json semantics */
const oaicompat_parser_options & opt,
const server_chat_params & opt,
std::vector<raw_buffer> & out_files)
{
json llama_params;
@@ -1012,7 +1011,7 @@ json oaicompat_chat_params_parse(
}
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
auto chat_params = common_chat_templates_apply(opt.tmpls.get(), inputs);
/* Append assistant prefilled message */
if (prefill_assistant_message) {

View File

@@ -13,8 +13,6 @@
#include <vector>
#include <cinttypes>
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
using json = nlohmann::ordered_json;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__)
@@ -274,25 +272,26 @@ std::vector<server_tokens> tokenize_input_prompts(
// OAI utils
//
// used by /completions endpoint
json oaicompat_completion_params_parse(const json & body);
struct oaicompat_parser_options {
// global server parameters for chat formatting / parsing
struct server_chat_params {
bool use_jinja;
bool prefill_assistant;
common_reasoning_format reasoning_format;
std::map<std::string,std::string> chat_template_kwargs;
common_chat_templates * tmpls;
std::map<std::string, std::string> chat_template_kwargs; // mapping key --> json value
common_chat_templates_ptr tmpls;
bool allow_image;
bool allow_audio;
bool enable_thinking = true;
std::string media_path;
};
// used by /completions endpoint
json oaicompat_completion_params_parse(const json & body);
// used by /chat/completions endpoint
json oaicompat_chat_params_parse(
json & body, /* openai api json semantics */
const oaicompat_parser_options & opt,
const server_chat_params & opt,
std::vector<raw_buffer> & out_files);
// convert Anthropic Messages API format to OpenAI Chat Completions API format

View File

@@ -534,8 +534,8 @@ public:
server_queue queue_tasks;
server_response queue_results;
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;
// note: chat_params must not be refreshed upon existing sleeping state
server_chat_params chat_params;
~server_context_impl() {
if (!sleeping) {
@@ -688,15 +688,6 @@ private:
llama_init_dft->free_context();
}
chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs);
} catch (const std::exception & e) {
SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what());
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_init(model, "chatml");
}
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
if (!is_resume) {
@@ -845,30 +836,6 @@ private:
model_name = model_path.filename().string();
}
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
SRV_INF("thinking = %d\n", enable_thinking);
oai_parser_opt = {
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* common_chat_templates */ chat_templates.get(),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* media_path */ params_base.media_path,
};
// print sample chat example to make it clear which template is used
// @ngxson modern templates are too long, spam the logs; printing the example is enough
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
// common_chat_templates_source(chat_templates.get()),
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
if (!is_resume) {
return init();
}
@@ -907,6 +874,42 @@ private:
}
}
// populate chat template params
{
common_chat_templates_ptr chat_templates;
try {
chat_templates = common_chat_templates_init(model, params_base.chat_template);
LOG_INF("%s: chat template, example_format: '%s'\n", __func__,
common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
} catch (const std::exception & e) {
SRV_ERR("%s: chat template parsing error: %s\n", __func__, e.what());
SRV_ERR("%s: please consider disabling jinja via --no-jinja, or use a custom chat template via --chat-template\n", __func__);
SRV_ERR("%s: for example: --no-jinja --chat-template chatml\n", __func__);
return false;
}
// thinking is enabled if:
// 1. It's not explicitly disabled (reasoning_budget == 0)
// 2. The chat template supports it
const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
chat_params = {
/* use_jinja */ params_base.use_jinja,
/* prefill_assistant */ params_base.prefill_assistant,
/* reasoning_format */ params_base.reasoning_format,
/* chat_template_kwargs */ params_base.default_template_kwargs,
/* tmpls */ std::move(chat_templates),
/* allow_image */ mctx ? mtmd_support_vision(mctx) : false,
/* allow_audio */ mctx ? mtmd_support_audio (mctx) : false,
/* enable_thinking */ enable_thinking,
/* media_path */ params_base.media_path,
};
}
return true;
}
@@ -1588,32 +1591,14 @@ private:
// tokenize the input if it's set by CLI, return false on error
bool tokenize_cli_input(server_task & task) {
GGML_ASSERT(task.cli_input != nullptr);
try {
auto & opt = oai_parser_opt;
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(task.cli_input);
inputs.tools = {}; // TODO
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
inputs.json_schema = ""; // TODO
inputs.grammar = ""; // TODO
inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = false;
inputs.add_generation_prompt = true;
inputs.reasoning_format = opt.reasoning_format;
inputs.enable_thinking = opt.enable_thinking;
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
// tokenize the resulting prompt
auto & prompt = chat_params.prompt;
auto & prompt = task.cli_prompt;
if (mctx != nullptr) {
task.tokens = process_mtmd_prompt(mctx, prompt, task.cli_files);
} else {
task.tokens = std::move(tokenize_input_prompts(vocab, mctx, prompt, true, true)[0]);
}
task.cli_input.clear();
task.cli_prompt.clear();
task.cli_files.clear();
} catch (const std::exception & e) {
send_error(task, std::string("Failed to format input: ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
@@ -1689,7 +1674,7 @@ private:
{
// special case: if input is provided via CLI, tokenize it first
// otherwise, no need to tokenize as it's already done inside the HTTP thread
if (task.cli_input != nullptr) {
if (task.cli) {
if (!tokenize_cli_input(task)) {
break;
}
@@ -2901,8 +2886,6 @@ server_response_reader server_context::get_response_reader() {
}
server_context_meta server_context::get_meta() const {
auto tool_use_src = common_chat_templates_source(impl->chat_templates.get(), "tool_use");
auto bos_id = llama_vocab_bos(impl->vocab);
auto eos_id = llama_vocab_eos(impl->vocab);
auto bos_token_str = bos_id != LLAMA_TOKEN_NULL ? common_token_to_piece(impl->ctx, bos_id, true) : "";
@@ -2913,14 +2896,13 @@ server_context_meta server_context::get_meta() const {
/* model_name */ impl->model_name,
/* model_path */ impl->params_base.model.path,
/* has_mtmd */ impl->mctx != nullptr,
/* has_inp_image */ impl->oai_parser_opt.allow_image,
/* has_inp_audio */ impl->oai_parser_opt.allow_audio,
/* has_inp_image */ impl->chat_params.allow_image,
/* has_inp_audio */ impl->chat_params.allow_audio,
/* json_webui_settings */ impl->json_webui_settings,
/* slot_n_ctx */ impl->get_slot_n_ctx(),
/* pooling_type */ llama_pooling_type(impl->ctx),
/* chat_template */ common_chat_templates_source(impl->chat_templates.get()),
/* chat_template_tool_use */ tool_use_src ? tool_use_src : "",
/* chat_params */ impl->chat_params,
/* bos_token_str */ bos_token_str,
/* eos_token_str */ eos_token_str,
@@ -3202,8 +3184,8 @@ void server_routes::init_routes() {
// this endpoint can be accessed during sleeping
// the next LOC is to avoid someone accidentally use ctx_server
bool server_ctx; // do NOT delete this line
GGML_UNUSED(server_ctx);
bool ctx_server; // do NOT delete this line
GGML_UNUSED(ctx_server);
res->ok({{"status", "ok"}});
return res;
@@ -3393,8 +3375,8 @@ void server_routes::init_routes() {
// this endpoint can be accessed during sleeping
// the next LOC is to avoid someone accidentally use ctx_server
bool server_ctx; // do NOT delete this line
GGML_UNUSED(server_ctx);
bool ctx_server; // do NOT delete this line
GGML_UNUSED(ctx_server);
task_params tparams;
tparams.sampling = params.sampling;
@@ -3403,6 +3385,9 @@ void server_routes::init_routes() {
{ "n_ctx", meta->slot_n_ctx },
};
std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), "");
std::string tmpl_tools = common_chat_templates_source(meta->chat_params.tmpls.get(), "tool_use");
json props = {
{ "default_generation_settings", default_generation_settings_for_props },
{ "total_slots", params.n_parallel },
@@ -3417,15 +3402,15 @@ void server_routes::init_routes() {
{ "endpoint_metrics", params.endpoint_metrics },
{ "webui", params.webui },
{ "webui_settings", meta->json_webui_settings },
{ "chat_template", meta->chat_template },
{ "chat_template", tmpl_default },
{ "bos_token", meta->bos_token_str },
{ "eos_token", meta->eos_token_str },
{ "build_info", meta->build_info },
{ "is_sleeping", queue_tasks.is_sleeping() },
};
if (params.use_jinja) {
if (!meta->chat_template_tool_use.empty()) {
props["chat_template_tool_use"] = meta->chat_template_tool_use;
if (!tmpl_tools.empty()) {
props["chat_template_tool_use"] = tmpl_tools;
}
}
res->ok(props);
@@ -3446,6 +3431,7 @@ void server_routes::init_routes() {
this->get_api_show = [this](const server_http_req &) {
auto res = create_response();
std::string tmpl_default = common_chat_templates_source(meta->chat_params.tmpls.get(), "");
json data = {
{
"model_info", {
@@ -3454,7 +3440,7 @@ void server_routes::init_routes() {
},
{"modelfile", ""},
{"parameters", ""},
{"template", meta->chat_template},
{"template", tmpl_default},
{"details", {
{"parent_model", ""},
{"format", "gguf"},
@@ -3579,7 +3565,7 @@ void server_routes::init_routes() {
json body = json::parse(req.body);
json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
meta->chat_params,
files);
return handle_completions_impl(
req,
@@ -3595,7 +3581,7 @@ void server_routes::init_routes() {
json body = convert_anthropic_to_oai(json::parse(req.body));
json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
meta->chat_params,
files);
return handle_completions_impl(
req,
@@ -3611,7 +3597,7 @@ void server_routes::init_routes() {
json body = convert_anthropic_to_oai(json::parse(req.body));
json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
meta->chat_params,
files);
json prompt = body_parsed.at("prompt");
@@ -3627,7 +3613,7 @@ void server_routes::init_routes() {
json body = json::parse(req.body);
json data = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
meta->chat_params,
files);
res->ok({{ "prompt", std::move(data.at("prompt")) }});
return res;
@@ -3638,8 +3624,8 @@ void server_routes::init_routes() {
// this endpoint can be accessed during sleeping
// the next LOC is to avoid someone accidentally use ctx_server
bool server_ctx; // do NOT delete this line
GGML_UNUSED(server_ctx);
bool ctx_server; // do NOT delete this line
GGML_UNUSED(ctx_server);
json models = {
{"models", {

View File

@@ -20,9 +20,8 @@ struct server_context_meta {
int slot_n_ctx;
enum llama_pooling_type pooling_type;
// chat template
std::string chat_template;
std::string chat_template_tool_use;
// chat params
server_chat_params & chat_params;
// tokens
std::string bos_token_str;

View File

@@ -68,10 +68,10 @@ json task_params::to_json(bool only_metrics) const {
{"stream", stream},
{"n_probs", sampling.n_probs},
{"min_keep", sampling.min_keep},
{"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
{"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
{"chat_format", common_chat_format_name(chat_parser_params.format)},
{"reasoning_format", common_reasoning_format_name(chat_parser_params.reasoning_format)},
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"thinking_forced_open", chat_parser_params.thinking_forced_open},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@@ -127,10 +127,10 @@ json task_params::to_json(bool only_metrics) const {
{"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
{"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
{"chat_format", common_chat_format_name(chat_parser_params.format)},
{"reasoning_format", common_reasoning_format_name(chat_parser_params.reasoning_format)},
{"reasoning_in_content", chat_parser_params.reasoning_in_content},
{"thinking_forced_open", chat_parser_params.thinking_forced_open},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@@ -291,21 +291,21 @@ task_params server_task::params_from_json_cmpl(
{
auto it = data.find("chat_format");
if (it != data.end()) {
params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format));
params.chat_parser_params.format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.chat_parser_params.format));
} else {
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
params.chat_parser_params.format = defaults.chat_parser_params.format;
}
common_reasoning_format reasoning_format = params_base.reasoning_format;
if (data.contains("reasoning_format")) {
reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get<std::string>());
}
params.oaicompat_chat_syntax.reasoning_format = reasoning_format;
params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false);
params.chat_parser_params.reasoning_format = reasoning_format;
params.chat_parser_params.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY);
params.chat_parser_params.thinking_forced_open = json_value(data, "thinking_forced_open", false);
params.chat_parser_params.parse_tool_calls = json_value(data, "parse_tool_calls", false);
if (data.contains("chat_parser")) {
params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get<std::string>());
params.chat_parser_params.parser.load(data.at("chat_parser").get<std::string>());
}
}
@@ -722,7 +722,7 @@ common_chat_msg task_result_state::update_chat_msg(
auto new_msg = common_chat_parse(
generated_text,
is_partial,
oaicompat_chat_syntax);
chat_parser_params);
if (!new_msg.empty()) {
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;

View File

@@ -78,7 +78,9 @@ struct task_params {
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
// per-request parameters for chat parsing
common_chat_parser_params chat_parser_params;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
@@ -91,7 +93,7 @@ struct task_params {
struct task_result_state {
// tracking diffs for partial tool calls
std::vector<common_chat_msg_diff> diffs;
common_chat_syntax oaicompat_chat_syntax;
common_chat_parser_params chat_parser_params;
common_chat_msg chat_msg;
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;
@@ -100,8 +102,8 @@ struct task_result_state {
bool anthropic_thinking_block_started = false;
bool anthropic_text_block_started = false;
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
task_result_state(const common_chat_parser_params & chat_parser_params)
: chat_parser_params(chat_parser_params) {}
// parse partial tool calls and update the internal state
common_chat_msg update_chat_msg(
@@ -130,8 +132,10 @@ struct server_task {
task_params params;
server_tokens tokens;
// only used by CLI, this delegates the tokenization to the server
json cli_input = nullptr;
// only used by CLI, this allow tokenizing CLI inputs on server side
// we need this because mtmd_context and vocab are not accessible outside of server_context
bool cli = false;
std::string cli_prompt;
std::vector<raw_buffer> cli_files;
server_task_type type;
@@ -228,7 +232,7 @@ struct server_task {
// the task will be moved into queue, then onto slots
// however, the state must be kept by caller (e.g., HTTP thread)
task_result_state create_state() const {
return task_result_state(params.oaicompat_chat_syntax);
return task_result_state(params.chat_parser_params);
}
bool is_parent() const {