Compare commits

...

18 Commits
b4514 ... b4532

Author SHA1 Message Date
Jeff Bolz
5245729e33 vulkan: fix diag_mask_inf (#11323)
With robustbufferaccess disabled, this shader was showing OOB stores. There
is a bounds check in the code, but the workgrouop dimensions were reversed vs
CUDA and it was running the wrong number of threads. So fix the workgroup
dimensions and disable robustness for this pipeline.
2025-01-23 08:01:17 +01:00
Diego Devesa
6152129d05 main : update README documentation for batch size (#11353)
* main : update README documentation for batch size

* fix formatting

* minor
2025-01-22 19:22:20 +01:00
Georgi Gerganov
16d3df7ab0 readme : add plugin links (#11355) 2025-01-22 19:44:26 +02:00
Diego Devesa
12c2bdf2de server : fix draft context not being released (#11354) 2025-01-22 17:44:40 +01:00
Olivier Chafik
c64d2becb1 minja: sync at 0f5f7f2b37 (#11352) 2025-01-22 16:16:27 +00:00
Jiří Podivín
96f4053934 Adding logprobs to /v1/completions (#11344)
Signed-off-by: Jiri Podivin <jpodivin@redhat.com>
2025-01-22 12:51:32 +01:00
Olivier Chafik
a94f3b2727 common: utils to split / join / repeat strings (from json converter) (#11342)
* Factor string_join, string_split, string_repeat into common

* json: refactor to surface a versatile builder

* Update common.cpp
2025-01-22 09:51:44 +00:00
tc-mb
3e3357fd77 llava : support Minicpm-omni (#11289)
* init

* add readme

* update readme

* no use make

* update readme

* update fix code

* fix editorconfig-checker

* no change convert py

* use clip_image_u8_free
2025-01-22 09:35:48 +02:00
Olivier Chafik
6171c9d258 Add Jinja template support (#11016)
* Copy minja from 58f0ca6dd7

* Add --jinja and --chat-template-file flags

* Add missing <optional> include

* Avoid print in get_hf_chat_template.py

* No designated initializers yet

* Try and work around msvc++ non-macro max resolution quirk

* Update test_chat_completion.py

* Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template

* Refactor test-chat-template

* Test templates w/ minja

* Fix deprecation

* Add --jinja to llama-run

* Update common_chat_format_example to use minja template wrapper

* Test chat_template in e2e test

* Update utils.py

* Update test_chat_completion.py

* Update run.cpp

* Update arg.cpp

* Refactor common_chat_* functions to accept minja template + use_jinja option

* Attempt to fix linkage of LLAMA_CHATML_TEMPLATE

* Revert LLAMA_CHATML_TEMPLATE refactor

* Normalize newlines in test-chat-templates for windows tests

* Forward decl minja::chat_template to avoid eager json dep

* Flush stdout in chat template before potential crash

* Fix copy elision warning

* Rm unused optional include

* Add missing optional include to server.cpp

* Disable jinja test that has a cryptic windows failure

* minja: fix vigogne (https://github.com/google/minja/pull/22)

* Apply suggestions from code review

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Finish suggested renamings

* Move chat_templates inside server_context + remove mutex

* Update --chat-template-file w/ recent change to --chat-template

* Refactor chat template validation

* Guard against missing eos/bos tokens (null token otherwise throws in llama_vocab::impl::token_get_attr)

* Warn against missing eos / bos tokens when jinja template references them

* rename: common_chat_template[s]

* reinstate assert on chat_templates.template_default

* Update minja to b8437df626

* Update minja to https://github.com/google/minja/pull/25

* Update minja from https://github.com/google/minja/pull/27

* rm unused optional header

---------

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-01-21 13:18:51 +00:00
Xuan Son Nguyen
e28245f35f export-lora : fix tok_embd tensor (#11330) 2025-01-21 14:07:12 +01:00
Radoslav Gerganov
6da5bec81c rpc : better caching of the base buffer pointer (#11331)
There is no need to use map, just store the base pointer in the buffer
context.
2025-01-21 15:06:41 +02:00
Eric Curtin
2e2f8f093c linenoise.cpp refactoring (#11301)
More RAII mainly

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
2025-01-21 09:32:35 +00:00
Georgi Gerganov
2139667ec4 metal : fix out-of-bounds write (#11314)
ggml-ci
2025-01-21 08:48:13 +02:00
Georgi Gerganov
80d0d6b4b7 common : add -hfd option for the draft model (#11318)
* common : add -hfd option for the draft model

* cont : fix env var

* cont : more fixes
2025-01-20 22:29:43 +02:00
Jeff Bolz
aea8ddd516 vulkan: fix coopmat2 validation failures (#11284)
mul mat and flash attention shaders were loading f32 types directly into
A/B matrices, which happens to work but is technically invalid usage.
For FA, we can load it as an Accumulator matrix and convert and this
is not in the inner loop and is cheap enough. For mul mat, it's more
efficient to do this conversion in a separate pass and have the input(s)
be f16.

coopmat2 requires SPIR-V 1.6 (related using to LocalSizeId). LocalSizeId
requires maintenance4 be enabled, and SPIR-V 1.6 requires Vulkan 1.3.
2025-01-20 10:38:32 -06:00
Georgi Gerganov
9f7add1cde examples : fix add_special conditions (#11311) 2025-01-20 16:36:08 +02:00
Christopher Nielsen
90d987b105 mmap: add include for cerrno (#11296)
ggml-ci

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2025-01-20 16:02:43 +02:00
Michael Podvitskiy
a4251edd6f cmake: fix shell command quoting in build-info script (#11309) 2025-01-20 16:02:15 +02:00
44 changed files with 4090 additions and 451 deletions

View File

@@ -1361,7 +1361,9 @@ llama-server: \
examples/server/httplib.h \
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)

View File

@@ -16,7 +16,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
## Hot topics
- **Introducing GGUF-my-LoRA** https://github.com/ggerganov/llama.cpp/discussions/10123
- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
- Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123
- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669
- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor)

View File

@@ -44,7 +44,7 @@ if(MSVC)
set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME})
else()
execute_process(
COMMAND sh -c "$@ --version | head -1" _ ${CMAKE_C_COMPILER}
COMMAND sh -c "\"$@\" --version | head -1" _ ${CMAKE_C_COMPILER}
OUTPUT_VARIABLE OUT
OUTPUT_STRIP_TRAILING_WHITESPACE
)

View File

@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-template.hpp
common.cpp
common.h
console.cpp
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
json.hpp
log.cpp
log.h
minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp

View File

@@ -133,7 +133,8 @@ static void common_params_handle_model_default(
const std::string & model_url,
std::string & hf_repo,
std::string & hf_file,
const std::string & hf_token) {
const std::string & hf_token,
const std::string & model_default) {
if (!hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (hf_file.empty()) {
@@ -163,7 +164,7 @@ static void common_params_handle_model_default(
model = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
} else if (model.empty()) {
model = DEFAULT_MODEL_PATH;
model = model_default;
}
}
@@ -299,8 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
// TODO: refactor model params in a common struct
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH);
common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, "");
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, "");
if (params.escape) {
string_process_escapes(params.prompt);
@@ -323,6 +325,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
params.chat_template.c_str(),
params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates"
));
}
return true;
}
@@ -1629,6 +1639,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.hf_repo = value;
}
).set_env("LLAMA_ARG_HF_REPO"));
add_opt(common_arg(
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
"Same as --hf-repo, but for the draft model (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.hf_repo = value;
}
).set_env("LLAMA_ARG_HFD_REPO"));
add_opt(common_arg(
{"-hff", "--hf-file"}, "FILE",
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
@@ -1938,24 +1955,44 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--jinja"},
"use jinja template for chat (default: disabled)",
[](common_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
),
[](common_params & params, const std::string & value) {
if (!common_chat_verify_template(value)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s\n"
"note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
value.c_str()
));
}
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(common_arg(
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
string_format(
"set custom jinja chat template file (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
),
[](common_params & params, const std::string & value) {
std::ifstream file(value);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
}
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(params.chat_template));
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
add_opt(common_arg(
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),

268
common/chat-template.hpp Normal file
View File

@@ -0,0 +1,268 @@
/*
Copyright 2024 Google LLC
Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once
#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
namespace minja {
class chat_template {
public:
private:
bool supports_tools_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments_ = false;
bool requires_typed_content_ = false;
bool supports_system_role_ = true;
bool supports_parallel_tool_calls_ = false;
std::string source_;
std::string bos_token_;
std::string eos_token_;
std::shared_ptr<minja::TemplateNode> template_root_;
std::string try_raw_render(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
try {
auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
return prompt;
} catch (const std::exception & e) {
// fprintf(stderr, "Error: %s\n", e.what());
return "";
}
}
public:
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
{
template_root_ = minja::Parser::parse(source_, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
supports_tools_ = source.find("tools") != std::string::npos;
auto renders_string_arguments =
try_raw_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
if (!renders_string_arguments) {
auto renders_object_arguments =
try_raw_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", {
{"code", "print('Hello, World!')"},
}},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
requires_object_arguments_ = renders_object_arguments;
}
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
supports_system_role_ = try_raw_render({
{{"role", "system"}, {"content", "<System Needle>"}},
{{"role", "user"}, {"content", "Hey"}}
}, {}, false).find("<System Needle>") != std::string::npos;
requires_typed_content_ = try_raw_render({{{"role", "user"}, {"content", "Hey"}}}, {}, false).find("Hey") == std::string::npos
&& try_raw_render({{{"role", "user"}, {"content", {{{"type", "text"}, {"text", "Hey"}}}}}}, {}, false).find("Hey") != std::string::npos;
}
const std::string & source() const { return source_; }
const std::string & bos_token() const { return bos_token_; }
const std::string & eos_token() const { return eos_token_; }
bool supports_tools() const { return supports_tools_; }
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
bool adjust_inputs = true) const
{
json actual_messages;
// First, "fix" messages so they have a chance to be rendered correctly by the template
if (adjust_inputs && (requires_object_arguments_ || !supports_system_role_ || !supports_tools_ || requires_typed_content_)) {
actual_messages = json::array();
auto add_message = [&](const json & msg) {
if (requires_typed_content_ && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
actual_messages.push_back({
{"role", msg.at("role")},
{"content", {{
{"type", "text"},
{"text", msg.at("content")},
}}},
});
} else {
actual_messages.push_back(msg);
}
};
std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
add_message({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (const auto & message_ : messages) {
auto message = message_;
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");
if (message.contains("tool_calls")) {
if (requires_object_arguments_ || !supports_tools_) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
}
}
}
if (!supports_tools_) {
auto content = message.at("content");
auto tool_calls = json::array();
for (const auto & tool_call : message.at("tool_calls")) {
if (tool_call.at("type") != "function") {
continue;
}
const auto & function = tool_call.at("function");
auto tc = json {
{"name", function.at("name")},
{"arguments", function.at("arguments")},
};
if (tool_call.contains("id")) {
tc["id"] = tool_call["id"];
}
tool_calls.push_back(tc);
}
auto obj = json {
{"tool_calls", tool_calls},
};
if (!content.is_null() && content != "") {
obj["content"] = content;
}
message["content"] = obj.dump(2);
message.erase("tool_calls");
}
}
if (!supports_tools_ && role == "tool") {
message["role"] = "user";
auto obj = json {
{"tool_response", {
{"tool", message.at("name")},
{"content", message.at("content")},
}},
};
if (message.contains("tool_call_id")) {
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
}
message["content"] = obj.dump(2);
message.erase("name");
}
if (!message["content"].is_null() && !supports_system_role_) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
add_message(message);
}
flush_sys();
} else {
actual_messages = messages;
}
auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", bos_token_},
{"eos_token", eos_token_},
}));
if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}
return template_root_->render(context);
}
};
} // namespace minja

View File

@@ -12,6 +12,7 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "chat-template.hpp"
#include <algorithm>
#include <cinttypes>
@@ -483,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) {
result << separator;
}
result << values[i];
}
return result.str();
}
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> parts;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
parts.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
parts.push_back(str.substr(start));
return parts;
}
std::string string_repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
std::string string_from(bool value) {
return value ? "true" : "false";
}
@@ -1728,67 +1771,75 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
// Chat template utils
//
std::string common_get_builtin_chat_template(const struct llama_model * model) {
const char * ptr_tmpl = llama_model_chat_template(model);
return ptr_tmpl == nullptr ? "" : ptr_tmpl;
}
bool common_chat_verify_template(const std::string & tmpl) {
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = minja::chat_template(tmpl, "<s>", "</s>");
chat_template.apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass) {
bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
return tmpl.apply(messages, /* tools= */ json(), add_ass);
}
int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat;
for (const auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
if (ptr_tmpl != nullptr) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(
fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass) {
bool add_ass,
bool use_jinja) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
@@ -1796,21 +1847,74 @@ std::string common_chat_format_single(const struct llama_model * model,
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl) {
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant"},
{"user", "Hello"},
{"assistant", "Hi there"},
{"user", "How are you?"},
};
return common_chat_apply_template(model, tmpl, msgs, true);
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
auto vocab = llama_model_get_vocab(model);
std::string default_template_src = chat_template_override;
std::string template_tool_use_src = chat_template_override;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";
}
}
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
}
return std::string();
} else {
return common_token_to_piece(vocab, token, true);
}
};
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
};
}
//

View File

@@ -175,7 +175,11 @@ struct common_params_speculative {
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
std::string model = ""; // draft model for speculative decoding // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string model = ""; // draft model for speculative decoding // NOLINT
std::string model_url = ""; // model url to download // NOLINT
};
struct common_params_vocoder {
@@ -330,6 +334,7 @@ struct common_params {
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
std::vector<std::string> api_keys;
@@ -424,6 +429,10 @@ std::string string_format(const char * fmt, ...);
std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp();
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
std::string string_repeat(const std::string & str, size_t n);
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
template<class T>
@@ -508,12 +517,14 @@ struct llama_model * common_load_model_from_url(
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
std::pair<std::string, std::string> common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & hf_token);
@@ -597,30 +608,43 @@ struct common_chat_msg {
std::string content;
};
// Get the built-in chat template for the model. Return empty string if not present.
std::string common_get_builtin_chat_template(const struct llama_model * model);
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl);
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
namespace minja {
class chat_template;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & chat,
bool add_ass);
bool add_ass,
bool use_jinja);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass);
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(const struct llama_model * model,
const std::string & tmpl);
std::string common_chat_format_example(
const common_chat_template & tmpl, bool use_jinja);
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
//
// KV cache utils

View File

@@ -1,4 +1,6 @@
#include "json-schema-to-grammar.h"
#include "common.h"
#include <algorithm>
#include <fstream>
#include <map>
@@ -11,11 +13,6 @@
using json = nlohmann::ordered_json;
template <typename Iterator>
static std::string join(Iterator begin, Iterator end, const std::string & separator);
static std::string repeat(const std::string & str, size_t n);
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max();
@@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
if (sub_len > 0) {
auto from_sub = from.substr(i + 1);
auto to_sub = to.substr(i + 1);
auto sub_zeros = repeat("0", sub_len);
auto sub_nines = repeat("9", sub_len);
auto sub_zeros = string_repeat("0", sub_len);
auto sub_nines = string_repeat("9", sub_len);
auto to_reached = false;
out << "(";
@@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
auto max_digits = max_s.length();
for (auto digits = min_digits; digits < max_digits; digits++) {
uniform_range(min_s, repeat("9", digits));
min_s = "1" + repeat("0", digits);
uniform_range(min_s, string_repeat("9", digits));
min_s = "1" + string_repeat("0", digits);
out << " | ";
}
uniform_range(min_s, max_s);
@@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
std::ostringstream result;
if (begin != end) {
result << *begin;
for (Iterator it = begin + 1; it != end; ++it) {
result << separator << *it;
}
}
return result.str();
}
static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t start = 0;
size_t end = str.find(delimiter);
while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}
tokens.push_back(str.substr(start));
return tokens;
}
static std::string repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}
std::string result;
result.reserve(str.length() * n);
for (size_t i = 0; i < n; ++i) {
result += str;
}
return result;
}
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
std::smatch match;
std::string result;
@@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {
class SchemaConverter {
private:
friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
@@ -418,7 +373,7 @@ private:
for (size_t i = 0; i < alt_schemas.size(); i++) {
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
}
return join(rules.begin(), rules.end(), " | ");
return string_join(rules, " | ");
}
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
@@ -481,7 +436,7 @@ private:
for (const auto & item : ret) {
results.push_back(to_rule(item));
}
return std::make_pair(join(results.begin(), results.end(), " "), false);
return std::make_pair(string_join(results, " "), false);
};
while (i < length) {
@@ -539,7 +494,7 @@ private:
}
curly_brackets += '}';
i++;
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
int min_times = 0;
int max_times = std::numeric_limits<int>::max();
try {
@@ -854,7 +809,7 @@ public:
return;
}
std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = split(pointer, "/");
std::vector<std::string> tokens = string_split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
@@ -905,7 +860,7 @@ public:
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -1019,10 +974,10 @@ public:
void check_errors() {
if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
}
if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
}
}
@@ -1036,10 +991,27 @@ public:
};
std::string json_schema_to_grammar(const json & schema) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
auto copy = schema;
converter.resolve_refs(copy, "input");
converter.visit(copy, "");
return build_grammar([&](const llama_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
callbacks.add_schema("", copy);
});
}
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
llama_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
},
/* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
return converter.visit(schema, name == "root" ? "" : name);
},
/* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
converter.resolve_refs(schema, "");
}
};
cb(builder);
converter.check_errors();
return converter.format_grammar();
}

View File

@@ -5,4 +5,12 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
struct llama_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
std::function<void(nlohmann::ordered_json &)> resolve_refs;
};
std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);

2812
common/minja.hpp Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -345,8 +345,18 @@ struct lora_merge_ctx {
gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur = inp_base;
for (size_t i = 0; i < adapters.size(); ++i) {
struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)));
struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
struct ggml_tensor * delta;
bool is_tok_embd = string_starts_with(name_base, "token_embd");
if (is_tok_embd) {
printf("%s : detected token embeddings tensor\n", __func__);
delta = ggml_mul_mat(ctx0,
ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32),
ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32));
} else {
delta = ggml_mul_mat(ctx0,
ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))),
ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32));
}
// scale
const float alpha = adapters[i]->alpha;
const float rank = (float) inp_b[i]->ne[0];

View File

@@ -0,0 +1,46 @@
## MiniCPM-o 2.6
Currently, this readme only supports minicpm-omni's image capabilities, and we will update the full-mode support as soon as possible.
### Prepare models and code
Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder.
Clone llama.cpp:
```bash
git clone git@github.com:OpenBMB/llama.cpp.git
cd llama.cpp
git checkout minicpm-omni
```
### Usage of MiniCPM-o 2.6
Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us)
```bash
python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6
python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4
python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
# quantize int4 version
./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
```
Build llama.cpp using `CMake`:
https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md
```bash
cmake -B build
cmake --build build --config Release
```
Inference on Linux or Mac
```
# run f16 version
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# run quantized int4 version
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
# or run in interactive mode
./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
```

View File

@@ -718,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
else if (ctx->minicpmv_version == 3) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
}
else if (ctx->minicpmv_version == 4) {
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
}
ggml_set_name(pos_embed, "pos_embed");
ggml_set_input(pos_embed);
}
@@ -1053,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
n_head = hidden_size/d_head;
num_query = 64;
}
else if (ctx->minicpmv_version == 4) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
@@ -2041,6 +2049,7 @@ static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_imag
images[images.size()-1].push_back(patch);
}
}
clip_image_u8_free(refine_image);
}
return images;
}
@@ -2079,6 +2088,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
clip_image_f32_free(res);
}
}
for (size_t i = 0; i < imgs.size(); ++i) {
for (size_t j = 0; j < imgs[i].size(); ++j) {
if (imgs[i][j] != nullptr) {
clip_image_u8_free(imgs[i][j]);
}
}
}
return true;
}
else if (ctx->has_qwen2vl_merger) {
@@ -2335,6 +2351,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
else if (ctx->minicpmv_version == 3) {
n_patches = 64;
}
else if (ctx->minicpmv_version == 4) {
n_patches = 64;
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
int patch_size = params.patch_size * 2;
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
@@ -2514,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
int bucket_coords_h[70];
int bucket_coords_w[70];
int bucket_coords_h[1024];
int bucket_coords_w[1024];
for (int i = 0; i < pos_h; i++){
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
}
@@ -2543,6 +2562,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
else if (ctx->minicpmv_version == 3) {
embed_dim = 3584;
}
else if (ctx->minicpmv_version == 4) {
embed_dim = 3584;
}
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
@@ -2786,6 +2808,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
else if (ctx->minicpmv_version == 3) {
return 3584;
}
else if (ctx->minicpmv_version == 4) {
return 3584;
}
}
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0];

View File

@@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
return true;
}
static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) {
static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) {
int width = image->nx;
int height = image->ny;
int num_patches = (height / patch_size) * (width / patch_size);
@@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
else {
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
if (has_minicpmv_projector == 2) {
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
else if (has_minicpmv_projector == 3) {
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
}
encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
}
if (!encoded) {
@@ -313,6 +307,9 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
load_image_size->height = img->ny;
clip_add_load_image_size(ctx_clip, load_image_size);
LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height);
delete[] img_res_v.data;
img_res_v.size = 0;
img_res_v.data = nullptr;
}
else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) {
// flat / default llava-1.5 type embedding

View File

@@ -140,6 +140,9 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
else if (has_minicpmv_projector == 3) {
system_prompt = "<|im_start|>user\n";
}
else if (has_minicpmv_projector == 4) {
system_prompt = "<|im_start|>user\n";
}
LOG_INF("%s: image token past: %d\n", __func__, n_past);
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
@@ -227,6 +230,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
else if (has_minicpmv_projector == 3) {
user_prompt = "<|im_start|>user\n" + prompt;
}
else if (has_minicpmv_projector == 4) {
user_prompt = "<|im_start|>user\n" + prompt;
}
}
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
@@ -236,6 +242,9 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
else if (has_minicpmv_projector == 3) {
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
}
else if (has_minicpmv_projector == 4) {
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
}
// generate the response
@@ -308,7 +317,6 @@ int main(int argc, char ** argv) {
const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0) break;
if (strstr(tmp, "###")) break; // Yi-VL behavior
printf("%s", tmp);// mistral llava-1.6
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
fflush(stdout);

View File

@@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073]
default_image_std = [0.26862954, 0.26130258, 0.27577711]
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2)
# with proper
args = ap.parse_args()
@@ -545,12 +545,19 @@ if args.use_f32:
minicpmv_version = args.minicpmv_version
emb_dim = 4096
block_count = 26
if minicpmv_version == 1:
emb_dim = 2304
block_count = 26
elif minicpmv_version == 2:
emb_dim = 4096
block_count = 27
elif minicpmv_version == 3:
emb_dim = 3584
block_count = 27
elif minicpmv_version == 4:
emb_dim = 3584
block_count = 27
default_vision_config = {
"hidden_size": 1152,
@@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config)
if minicpmv_version == 3:
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
elif minicpmv_version == 4:
vision_config = SiglipVisionConfig(**default_vision_config)
model = SiglipVisionTransformer(vision_config)
processor = None
# if model.attn_pool is not None:
@@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None:
fname_middle = "mmproj-"
has_text_encoder = False
has_minicpmv_projector = True
minicpmv_version = 3
minicpmv_version = 4
elif args.vision_only:
fname_middle = "vision-"
has_text_encoder = False
@@ -625,7 +635,6 @@ if has_vision_encoder:
fout.add_uint32("clip.vision.projection_dim", 0)
fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16)
fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
block_count = 26
fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count)
if processor is not None:

View File

@@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
args = ap.parse_args()
# find the model part that includes the the multimodal projector weights
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True)
model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16)
checkpoint = model.state_dict()
# get a list of mm tensor names

View File

@@ -310,9 +310,9 @@ These options help improve the performance and memory usage of the LLaMA models.
### Batch Size
- `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
- `-ub N`, `--ubatch-size N`: Physical batch size. This is the maximum number of tokens that may be processed at a time. Increasing this value may improve performance during prompt processing, at the expense of higher memory usage. Default: `512`.
- `-ub N`, `--ubatch-size N`: physical maximum batch size. This is for pipeline parallelization. Default: `512`.
- `-b N`, `--batch-size N`: Logical batch size. Increasing this value above the value of the physical batch size may improve prompt processing performance when using multiple GPUs with pipeline parallelism. Default: `2048`.
### Prompt Caching

View File

@@ -4,6 +4,7 @@
#include "log.h"
#include "sampling.h"
#include "llama.h"
#include "chat-template.hpp"
#include <cstdio>
#include <cstring>
@@ -84,14 +85,6 @@ static void sigint_handler(int signo) {
}
#endif
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
}
int main(int argc, char ** argv) {
common_params params;
g_params = &params;
@@ -165,6 +158,7 @@ int main(int argc, char ** argv) {
}
const llama_vocab * vocab = llama_model_get_vocab(model);
auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@@ -207,7 +201,7 @@ int main(int argc, char ** argv) {
}
// auto enable conversation mode if chat template is available
const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty();
const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
if (has_chat_template) {
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@@ -225,7 +219,7 @@ int main(int argc, char ** argv) {
// print chat template example in conversation mode
if (params.conversation_mode) {
if (params.enable_chat_template) {
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
} else {
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
}
@@ -269,10 +263,18 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back({role, content});
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
};
{
auto prompt = (params.conversation_mode && params.enable_chat_template)
// format the system prompt in conversation mode (fallback to default if empty)
? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
// otherwise use the prompt as is
: params.prompt;
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -779,7 +781,7 @@ int main(int argc, char ** argv) {
}
if (params.enable_chat_template) {
chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
chat_add_and_format("assistant", assistant_ss.str());
}
is_interacting = true;
LOG("\n");
@@ -844,7 +846,7 @@ int main(int argc, char ** argv) {
bool format_chat = params.conversation_mode && params.enable_chat_template;
std::string user_inp = format_chat
? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
? chat_add_and_format("user", std::move(buffer))
: std::move(buffer);
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);

View File

@@ -103,24 +103,26 @@
*
*/
#include <termios.h>
#include <unistd.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <stdlib.h>
#include <ctype.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/ioctl.h>
#include <unistd.h>
#include <vector>
#include "linenoise.h"
# include "linenoise.h"
#define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100
#define LINENOISE_MAX_LINE 4096
static std::vector<const char*> unsupported_term = {"dumb","cons25","emacs",nullptr};
# include <ctype.h>
# include <errno.h>
# include <stdio.h>
# include <string.h>
# include <sys/file.h>
# include <sys/ioctl.h>
# include <sys/stat.h>
# include <sys/types.h>
# include <termios.h>
# include <unistd.h>
# include <memory>
# include <string>
# include <vector>
# define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100
# define LINENOISE_MAX_LINE 4096
static std::vector<const char *> unsupported_term = { "dumb", "cons25", "emacs" };
static linenoiseCompletionCallback *completionCallback = NULL;
static linenoiseHintsCallback *hintsCallback = NULL;
static linenoiseFreeHintsCallback *freeHintsCallback = NULL;
@@ -166,21 +168,58 @@ int linenoiseHistoryAdd(const char *line);
#define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both.
static void refreshLine(struct linenoiseState *l);
class File {
public:
FILE * file = nullptr;
FILE * open(const std::string & filename, const char * mode) {
file = fopen(filename.c_str(), mode);
return file;
}
int lock() {
if (file) {
fd = fileno(file);
if (flock(fd, LOCK_EX | LOCK_NB) != 0) {
fd = -1;
return 1;
}
}
return 0;
}
~File() {
if (fd >= 0) {
flock(fd, LOCK_UN);
}
if (file) {
fclose(file);
}
}
private:
int fd = -1;
};
__attribute__((format(printf, 1, 2)))
/* Debugging function. */
#if 0
static void lndebug(const char *fmt, ...) {
static FILE *lndebug_fp = NULL;
if (lndebug_fp == NULL) {
lndebug_fp = fopen("/tmp/lndebug.txt", "a");
static File file;
if (file.file == nullptr) {
file.open("/tmp/lndebug.txt", "a");
}
if (lndebug_fp != NULL) {
if (file.file != nullptr) {
va_list args;
va_start(args, fmt);
vfprintf(lndebug_fp, fmt, args);
vfprintf(file.file, fmt, args);
va_end(args);
fflush(lndebug_fp);
fflush(file.file);
}
}
#else
@@ -213,8 +252,11 @@ void linenoiseSetMultiLine(int ml) {
static int isUnsupportedTerm(void) {
char *term = getenv("TERM");
if (term == NULL) return 0;
for (int j = 0; unsupported_term[j]; ++j)
if (!strcasecmp(term, unsupported_term[j])) return 1;
for (size_t j = 0; j < unsupported_term.size(); ++j) {
if (!strcasecmp(term, unsupported_term[j])) {
return 1;
}
}
return 0;
}
@@ -334,17 +376,6 @@ static void linenoiseBeep(void) {
fflush(stderr);
}
/* ============================== Completion ================================ */
/* Free a list of completion option populated by linenoiseAddCompletion(). */
static void freeCompletions(linenoiseCompletions *lc) {
size_t i;
for (i = 0; i < lc->len; i++)
free(lc->cvec[i]);
if (lc->cvec != NULL)
free(lc->cvec);
}
/* Called by completeLine() and linenoiseShow() to render the current
* edited line with the proposed completion. If the current completion table
* is already available, it is passed as second argument, otherwise the
@@ -353,9 +384,9 @@ static void freeCompletions(linenoiseCompletions *lc) {
* Flags are the same as refreshLine*(), that is REFRESH_* macros. */
static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) {
/* Obtain the table of completions if the caller didn't provide one. */
linenoiseCompletions ctable = { 0, NULL };
linenoiseCompletions ctable;
if (lc == NULL) {
completionCallback(ls->buf,&ctable);
completionCallback(ls->buf, &ctable);
lc = &ctable;
}
@@ -364,16 +395,17 @@ static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseComple
struct linenoiseState saved = *ls;
ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]);
ls->buf = lc->cvec[ls->completion_idx];
refreshLineWithFlags(ls,flags);
refreshLineWithFlags(ls, flags);
ls->len = saved.len;
ls->pos = saved.pos;
ls->buf = saved.buf;
} else {
refreshLineWithFlags(ls,flags);
refreshLineWithFlags(ls, flags);
}
/* Free the completions table if needed. */
if (lc != &ctable) freeCompletions(&ctable);
if (lc == &ctable) {
ctable.to_free = false;
}
}
/* This is an helper function for linenoiseEdit*() and is called when the
@@ -391,11 +423,11 @@ static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseComple
* possible completions, and the caller should read for the next characters
* from stdin. */
static int completeLine(struct linenoiseState *ls, int keypressed) {
linenoiseCompletions lc = { 0, NULL };
linenoiseCompletions lc;
int nwritten;
char c = keypressed;
completionCallback(ls->buf,&lc);
completionCallback(ls->buf, &lc);
if (lc.len == 0) {
linenoiseBeep();
ls->in_completion = 0;
@@ -406,7 +438,7 @@ static int completeLine(struct linenoiseState *ls, int keypressed) {
ls->in_completion = 1;
ls->completion_idx = 0;
} else {
ls->completion_idx = (ls->completion_idx+1) % (lc.len+1);
ls->completion_idx = (ls->completion_idx + 1) % (lc.len + 1);
if (ls->completion_idx == lc.len) linenoiseBeep();
}
c = 0;
@@ -420,8 +452,7 @@ static int completeLine(struct linenoiseState *ls, int keypressed) {
default:
/* Update buffer and return */
if (ls->completion_idx < lc.len) {
nwritten = snprintf(ls->buf,ls->buflen,"%s",
lc.cvec[ls->completion_idx]);
nwritten = snprintf(ls->buf, ls->buflen, "%s", lc.cvec[ls->completion_idx]);
ls->len = ls->pos = nwritten;
}
ls->in_completion = 0;
@@ -430,13 +461,12 @@ static int completeLine(struct linenoiseState *ls, int keypressed) {
/* Show completion or original buffer */
if (ls->in_completion && ls->completion_idx < lc.len) {
refreshLineWithCompletion(ls,&lc,REFRESH_ALL);
refreshLineWithCompletion(ls, &lc, REFRESH_ALL);
} else {
refreshLine(ls);
}
}
freeCompletions(&lc);
return c; /* Return last read character */
}
@@ -462,53 +492,25 @@ void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) {
* user typed <tab>. See the example.c source code for a very easy to
* understand example. */
void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) {
size_t len = strlen(str);
char *copy, **cvec;
copy = (char*) malloc(len + 1);
if (copy == NULL) return;
memcpy(copy,str,len+1);
cvec = (char**) realloc(lc->cvec,sizeof(char*)*(lc->len+1));
if (cvec == NULL) {
free(copy);
const size_t len = strlen(str);
auto copy = std::make_unique<char[]>(len + 1);
if (!copy) {
return;
}
memcpy(copy.get(), str, len + 1);
char ** cvec = static_cast<char **>(std::realloc(lc->cvec, sizeof(char *) * (lc->len + 1)));
if (cvec == nullptr) {
return;
}
lc->cvec = cvec;
lc->cvec[lc->len++] = copy;
}
/* =========================== Line editing ================================= */
/* We define a very simple "append buffer" structure, that is an heap
* allocated string where we can append to. This is useful in order to
* write all the escape sequences in a buffer and flush them to the standard
* output in a single call, to avoid flickering effects. */
struct abuf {
char *b;
int len;
};
static void abInit(struct abuf *ab) {
ab->b = NULL;
ab->len = 0;
}
static void abAppend(struct abuf *ab, const char *s, int len) {
char *new_ptr = (char*) realloc(ab->b,ab->len+len);
if (new_ptr == NULL) return;
memcpy(new_ptr+ab->len,s,len);
ab->b = new_ptr;
ab->len += len;
}
static void abFree(struct abuf *ab) {
free(ab->b);
lc->cvec[lc->len++] = copy.release();
}
/* Helper of refreshSingleLine() and refreshMultiLine() to show hints
* to the right of the prompt. */
static void refreshShowHints(struct abuf * ab, struct linenoiseState * l, int plen) {
static void refreshShowHints(std::string & ab, struct linenoiseState * l, int plen) {
char seq[64];
if (hintsCallback && plen+l->len < l->cols) {
int color = -1, bold = 0;
@@ -522,10 +524,11 @@ static void refreshShowHints(struct abuf * ab, struct linenoiseState * l, int pl
snprintf(seq,64,"\033[%d;%d;49m",bold,color);
else
seq[0] = '\0';
abAppend(ab,seq,strlen(seq));
abAppend(ab,hint,hintlen);
ab.append(seq);
ab.append(hint, hintlen);
if (color != -1 || bold != 0)
abAppend(ab,"\033[0m",4);
ab.append("\033[0m");
/* Call the function to free the hint returned. */
if (freeHintsCallback) freeHintsCallback(hint);
}
@@ -546,8 +549,7 @@ static void refreshSingleLine(struct linenoiseState *l, int flags) {
char *buf = l->buf;
size_t len = l->len;
size_t pos = l->pos;
struct abuf ab;
std::string ab;
while((plen+pos) >= l->cols) {
buf++;
len--;
@@ -557,35 +559,34 @@ static void refreshSingleLine(struct linenoiseState *l, int flags) {
len--;
}
abInit(&ab);
/* Cursor to left edge */
snprintf(seq,sizeof(seq),"\r");
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
if (flags & REFRESH_WRITE) {
/* Write the prompt and the current buffer content */
abAppend(&ab,l->prompt,strlen(l->prompt));
ab.append(l->prompt);
if (maskmode == 1) {
while (len--) abAppend(&ab,"*",1);
while (len--) {
ab.append("*");
}
} else {
abAppend(&ab,buf,len);
ab.append(buf, len);
}
/* Show hits if any. */
refreshShowHints(&ab,l,plen);
refreshShowHints(ab, l, plen);
}
/* Erase to right */
snprintf(seq,sizeof(seq),"\x1b[0K");
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
if (flags & REFRESH_WRITE) {
/* Move cursor to original position. */
snprintf(seq,sizeof(seq),"\r\x1b[%dC", (int)(pos+plen));
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
}
if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */
abFree(&ab);
(void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */
}
/* Multi line low level line refresh.
@@ -604,26 +605,23 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
int col; /* colum position, zero-based. */
int old_rows = l->oldrows;
int fd = l->ofd, j;
struct abuf ab;
std::string ab;
l->oldrows = rows;
/* First step: clear all the lines used before. To do so start by
* going to the last row. */
abInit(&ab);
if (flags & REFRESH_CLEAN) {
if (old_rows-rpos > 0) {
lndebug("go down %d", old_rows-rpos);
snprintf(seq,64,"\x1b[%dB", old_rows-rpos);
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
}
/* Now for every row clear it, go up. */
for (j = 0; j < old_rows-1; j++) {
lndebug("clear+up");
snprintf(seq,64,"\r\x1b[0K\x1b[1A");
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
}
}
@@ -631,21 +629,22 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
/* Clean the top line. */
lndebug("clear");
snprintf(seq,64,"\r\x1b[0K");
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
}
if (flags & REFRESH_WRITE) {
/* Write the prompt and the current buffer content */
abAppend(&ab,l->prompt,strlen(l->prompt));
ab.append(l->prompt);
if (maskmode == 1) {
unsigned int i;
for (i = 0; i < l->len; i++) abAppend(&ab,"*",1);
for (unsigned int i = 0; i < l->len; ++i) {
ab.append("*");
}
} else {
abAppend(&ab,l->buf,l->len);
ab.append(l->buf, l->len);
}
/* Show hits if any. */
refreshShowHints(&ab,l,plen);
refreshShowHints(ab, l, plen);
/* If we are at the very end of the screen with our prompt, we need to
* emit a newline and move the prompt to the first column. */
@@ -654,9 +653,9 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
(l->pos+plen) % l->cols == 0)
{
lndebug("<newline>");
abAppend(&ab,"\n",1);
ab.append("\n");
snprintf(seq,64,"\r");
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
rows++;
if (rows > (int)l->oldrows) l->oldrows = rows;
}
@@ -669,7 +668,7 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
if (rows-rpos2 > 0) {
lndebug("go-up %d", rows-rpos2);
snprintf(seq,64,"\x1b[%dA", rows-rpos2);
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
}
/* Set column. */
@@ -679,14 +678,12 @@ static void refreshMultiLine(struct linenoiseState *l, int flags) {
snprintf(seq,64,"\r\x1b[%dC", col);
else
snprintf(seq,64,"\r");
abAppend(&ab,seq,strlen(seq));
ab.append(seq);
}
lndebug("\n");
l->oldpos = l->pos;
if (write(fd,ab.b,ab.len) == -1) {} /* Can't recover from write error. */
abFree(&ab);
(void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */
}
/* Calls the two low level functions refreshSingleLine() or
@@ -1313,16 +1310,17 @@ int linenoiseHistorySetMaxLen(int len) {
* otherwise -1 is returned. */
int linenoiseHistorySave(const char *filename) {
mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO);
FILE *fp;
int j;
fp = fopen(filename,"w");
File file;
file.open(filename, "w");
umask(old_umask);
if (fp == NULL) return -1;
if (file.file == NULL) {
return -1;
}
chmod(filename,S_IRUSR|S_IWUSR);
for (j = 0; j < history_len; j++)
fprintf(fp,"%s\n",history[j]);
fclose(fp);
for (int j = 0; j < history_len; ++j) {
fprintf(file.file, "%s\n", history[j]);
}
return 0;
}
@@ -1332,12 +1330,14 @@ int linenoiseHistorySave(const char *filename) {
* If the file exists and the operation succeeded 0 is returned, otherwise
* on error -1 is returned. */
int linenoiseHistoryLoad(const char *filename) {
FILE *fp = fopen(filename,"r");
File file;
file.open(filename, "r");
char buf[LINENOISE_MAX_LINE];
if (file.file == NULL) {
return -1;
}
if (fp == NULL) return -1;
while (fgets(buf,LINENOISE_MAX_LINE,fp) != NULL) {
while (fgets(buf, LINENOISE_MAX_LINE, file.file) != NULL) {
char *p;
p = strchr(buf,'\r');
@@ -1345,7 +1345,6 @@ int linenoiseHistoryLoad(const char *filename) {
if (p) *p = '\0';
linenoiseHistoryAdd(buf);
}
fclose(fp);
return 0;
}
#endif

View File

@@ -45,6 +45,7 @@ extern "C" {
#endif
#include <stddef.h> /* For size_t. */
#include <stdlib.h>
extern const char *linenoiseEditMore;
@@ -69,10 +70,23 @@ struct linenoiseState {
int history_index; /* The history index we are currently editing. */
};
typedef struct linenoiseCompletions {
size_t len;
char **cvec;
} linenoiseCompletions;
struct linenoiseCompletions {
size_t len = 0;
char ** cvec = nullptr;
bool to_free = true;
~linenoiseCompletions() {
if (!to_free) {
return;
}
for (size_t i = 0; i < len; ++i) {
free(cvec[i]);
}
free(cvec);
}
};
/* Non blocking API. */
int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt);

View File

@@ -28,6 +28,7 @@
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
#include "llama-cpp.h"
#include "chat-template.hpp"
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
[[noreturn]] static void sigint_handler(int) {
@@ -105,6 +106,7 @@ class Opt {
llama_model_params model_params;
std::string model_;
std::string user;
bool use_jinja = false;
int context_size = -1, ngl = -1;
float temperature = -1;
bool verbose = false;
@@ -156,6 +158,8 @@ class Opt {
} else if (options_parsing &&
(parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
verbose = true;
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
use_jinja = true;
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
help = true;
return 0;
@@ -713,13 +717,31 @@ static void add_message(const char * role, const std::string & text, LlamaData &
}
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) {
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) {
json messages = json::array();
for (const auto & msg : llama_data.messages) {
messages.push_back({
{"role", msg.role},
{"content", msg.content},
});
}
try {
auto result = tmpl.apply(messages, /* tools= */ json(), append);
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} catch (const std::exception & e) {
printe("failed to render the chat template: %s\n", e.what());
return -1;
}
}
int result = llama_chat_apply_template(
llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
}
@@ -729,10 +751,12 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
// Function to tokenize the prompt
static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) {
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> & prompt_tokens, const LlamaData & llama_data) {
const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0;
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first,
true) < 0) {
printe("failed to tokenize the prompt\n");
return -1;
@@ -778,7 +802,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
std::vector<llama_token> tokens;
if (tokenize_prompt(vocab, prompt, tokens) < 0) {
if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) {
return 1;
}
@@ -869,8 +893,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
}
// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) {
const int new_len = apply_chat_template(llama_data, append);
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
if (new_len < 0) {
printe("failed to apply the chat template\n");
return -1;
@@ -929,9 +953,11 @@ static int get_user_input(std::string & user_input, const std::string & user) {
}
// Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user) {
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
GGML_ASSERT(chat_templates.template_default);
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
@@ -942,7 +968,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
add_message("user", user.empty() ? user_input : user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) {
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
return 1;
}
@@ -957,7 +983,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) {
}
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) {
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
return 1;
}
}
@@ -1017,7 +1043,7 @@ int main(int argc, const char ** argv) {
return 1;
}
if (chat_loop(llama_data, opt.user)) {
if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
return 1;
}

View File

@@ -126,7 +126,7 @@ The project is under active development, and we are [looking for feedback and co
| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') |
| `--grammar-file FNAME` | file to read grammar from |
| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object<br/>For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead |
| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) |
**Example-specific params**

View File

@@ -267,6 +267,11 @@ struct server_task {
params.speculative.n_min = std::max(params.speculative.n_min, 2);
params.speculative.n_max = std::max(params.speculative.n_max, 0);
// Use OpenAI API logprobs only if n_probs wasn't provided
if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){
params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
}
if (data.contains("lora")) {
if (data.at("lora").is_array()) {
params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora"));
@@ -1688,6 +1693,8 @@ struct server_context {
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
common_chat_templates chat_templates;
~server_context() {
// Clear any sampling context
for (server_slot & slot : slots) {
@@ -1728,13 +1735,16 @@ struct server_context {
add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
if (!params_base.speculative.model.empty()) {
if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
auto params_dft = params_base;
params_dft.devices = params_base.speculative.devices;
params_dft.hf_file = params_base.speculative.hf_file;
params_dft.hf_repo = params_base.speculative.hf_repo;
params_dft.model = params_base.speculative.model;
params_dft.model_url = params_base.speculative.model_url;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
@@ -1762,16 +1772,44 @@ struct server_context {
// force F16 KV cache for the draft model for extra performance
cparams_dft.type_k = GGML_TYPE_F16;
cparams_dft.type_v = GGML_TYPE_F16;
// the context is not needed - we will create one for each slot
llama_init_dft.context.reset();
}
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
return true;
}
bool validate_builtin_chat_template() const {
bool validate_builtin_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}};
const char * tmpl = llama_model_chat_template(model);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
if (use_jinja) {
auto templates = common_chat_templates_from_model(model, "");
GGML_ASSERT(templates.template_default);
try {
templates.template_default->apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
if (templates.template_tool_use) {
templates.template_tool_use->apply({{
{"role", "user"},
{"content", "test"},
}}, json(), true);
}
return true;
} catch (const std::exception & e) {
SRV_ERR("failed to apply template: %s\n", e.what());
return false;
}
} else {
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
}
}
void init() {
@@ -3656,9 +3694,12 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model },
{ "chat_template", common_get_builtin_chat_template(ctx_server.model) },
{ "chat_template", ctx_server.chat_templates.template_default->source() },
{ "build_info", build_info },
};
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
}
res_ok(res, data);
};
@@ -3886,7 +3927,10 @@ int main(int argc, char ** argv) {
return;
}
json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
auto body = json::parse(req.body);
const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default;
json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja);
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
@@ -4296,7 +4340,7 @@ int main(int argc, char ** argv) {
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (params.chat_template.empty()) {
if (!ctx_server.validate_builtin_chat_template()) {
if (!ctx_server.validate_builtin_chat_template(params.use_jinja)) {
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
params.chat_template = "chatml";
}
@@ -4304,8 +4348,8 @@ int main(int argc, char ** argv) {
// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
common_chat_format_example(ctx_server.model, params.chat_template).c_str());
ctx_server.chat_templates.template_default->source().c_str(),
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1));

View File

@@ -4,22 +4,26 @@ from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
[
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
global server
server.jinja = jinja
server.chat_template = chat_template
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,

View File

@@ -72,13 +72,14 @@ class ServerProcess:
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
response_format: str | None = None
lora_files: List[str] | None = None
disable_ctx_shift: int | None = False
draft_min: int | None = None
draft_max: int | None = None
no_webui: bool | None = None
jinja: bool | None = None
chat_template: str | None = None
chat_template_file: str | None = None
# session variables
process: subprocess.Popen | None = None
@@ -169,8 +170,12 @@ class ServerProcess:
server_args.extend(["--draft-min", self.draft_min])
if self.no_webui:
server_args.append("--no-webui")
if self.jinja:
server_args.append("--jinja")
if self.chat_template:
server_args.extend(["--chat-template", self.chat_template])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")

View File

@@ -16,6 +16,8 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "minja.hpp"
#include "chat-template.hpp"
#include <random>
#include <sstream>
@@ -349,7 +351,7 @@ static llama_tokens format_infill(
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
@@ -377,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
chat.push_back({role, content});
}
const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat;
@@ -576,14 +578,23 @@ static json oaicompat_completion_params_parse(const json & body) {
return llama_params;
}
static json oaicompat_chat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const std::string & chat_template) {
static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */
const common_chat_template & tmpl,
bool use_jinja)
{
json llama_params;
// Apply chat template to the list of messages
llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
if (has_tools) {
if (use_jinja) {
LOG_WRN("tools param is not fully supported yet\n");
} else {
throw std::runtime_error("tools param requires --jinja flag");
}
}
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
@@ -606,6 +617,13 @@ static json oaicompat_chat_completion_params_parse(
}
}
// Apply chat template to the list of messages
if (use_jinja) {
llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
} else {
llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
@@ -621,7 +639,7 @@ static json oaicompat_chat_completion_params_parse(
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
static const std::vector<std::string> unsupported_params { "tool_choice" };
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);

View File

@@ -95,13 +95,15 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
// helper function to evaluate a prompt and generate a response
auto generate = [&](const std::string & prompt, bool is_first) {
auto generate = [&](const std::string & prompt) {
std::string response;
const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0;
// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n");
}
@@ -161,7 +163,7 @@ int main(int argc, char ** argv) {
break;
}
const char * tmpl = llama_model_chat_template(model);
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
// add the user input to the message list and format it
messages.push_back({"user", strdup(user.c_str())});
@@ -180,7 +182,7 @@ int main(int argc, char ** argv) {
// generate a response
printf("\033[33m");
std::string response = generate(prompt, prev_len == 0);
std::string response = generate(prompt);
printf("\n\033[0m");
// add the response to the messages

View File

@@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl(
device const half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
@@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
@@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
if (tiisg == 0) {
for (int row = 0; row < 2; ++row) {
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
dst_f32[first_row + row] = sumf1[row];
}
}
@@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl(
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
@@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < 2; ++row) {
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = tot;
@@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl(
const int row = 2*r0 + sgitg;
if (row >= args.ne0) {
return;
}
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
@@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum * 0.5f;
@@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
@@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
@@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < N_DST; ++row) {
for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
@@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;
@@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < 2; ++row) {
for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = all_sum;

View File

@@ -181,7 +181,7 @@ struct ggml_backend_rpc_context {
struct ggml_backend_rpc_buffer_context {
std::shared_ptr<socket_t> sock;
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
void * base_ptr;
uint64_t remote_ptr;
};
@@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
return ctx->base_cache[buffer];
if (ctx->base_ptr != nullptr) {
return ctx->base_ptr;
}
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
rpc_msg_buffer_get_base_rsp response;
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
GGML_ASSERT(status);
void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
ctx->base_cache[buffer] = base_ptr;
return base_ptr;
ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
return ctx->base_ptr;
}
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
if (response.remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
response.remote_size);
return buffer;
} else {

View File

@@ -29,8 +29,6 @@
#include "ggml-vulkan-shaders.hpp"
#define VK_API_VERSION VK_API_VERSION_1_2
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
#define VK_VENDOR_ID_AMD 0x1002
@@ -1614,11 +1612,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1631,21 +1625,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
@@ -2021,7 +2012,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
@@ -2287,6 +2278,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
}
#endif
VkPhysicalDeviceMaintenance4Features maint4_features {};
maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
if (maintenance4_support) {
last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
last_struct = (VkBaseOutStructure *)&maint4_features;
device_extensions.push_back("VK_KHR_maintenance4");
}
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2662,7 +2661,14 @@ void ggml_vk_instance_init() {
vk_instance_initialized = true;
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
uint32_t api_version = vk::enumerateInstanceVersion();
if (api_version < VK_API_VERSION_1_2) {
std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
GGML_ABORT("fatal error");
}
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
@@ -2972,7 +2978,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
}
GGML_ASSERT(src1_type == GGML_TYPE_F32);
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
switch (src0_type) {
case GGML_TYPE_Q4_0:
@@ -3812,8 +3818,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
src1_uma = d_Qy != nullptr;
}
const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
// Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
// Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
!ggml_vk_dim01_contiguous(src0);
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
!ggml_vk_dim01_contiguous(src1);
@@ -4393,8 +4400,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
ids_uma = d_ids != nullptr;
}
const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
// Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
!ggml_vk_dim01_contiguous(src0);
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
!ggml_vk_dim01_contiguous(src1);
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
@@ -4404,7 +4414,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
if (qx_needs_dequant) {
GGML_ABORT("fatal error");
// Fall back to dequant + f16 mulmat
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
}
// Not implemented

View File

@@ -12,7 +12,7 @@ layout (push_constant) uniform parameter
#include "types.comp"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

View File

@@ -166,7 +166,7 @@ void main() {
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;

View File

@@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#if QUANT_K > 1
#define DECODEFUNCA , dequantFuncA
#define MAT_A_TYPE float16_t
#include "dequant_funcs_cm2.comp"
#else
#define DECODEFUNCA
#define MAT_A_TYPE A_TYPE
#endif
#define MAT_B_TYPE B_TYPE
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
@@ -236,16 +232,13 @@ void main() {
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
} else
#endif // !defined(MUL_MAT_ID)
@@ -261,10 +254,8 @@ void main() {
[[dont_unroll]]
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
// Clamping is expensive, so detect different code paths for each combination
// of A and B needing clamping.
@@ -281,16 +272,12 @@ void main() {
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
#endif
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (unclampedA && !unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (!unclampedA && unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
@@ -298,16 +285,12 @@ void main() {
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
#endif
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (!unclampedA && !unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}
}

View File

@@ -316,8 +316,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);

View File

@@ -510,7 +510,8 @@ extern "C" {
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
// Get the default chat template. Returns nullptr if not available
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
// If name is NULL, returns the default chat template
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
// Returns the total number of parameters in the model
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);

77
scripts/get_hf_chat_template.py Executable file
View File

@@ -0,0 +1,77 @@
#!/usr/bin/env python
'''
Fetches the Jinja chat template of a HuggingFace model.
If a model has multiple chat templates, you can specify the variant name.
Syntax:
./scripts/get_hf_chat_template.py model_id [variant]
Examples:
./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
'''
import json
import re
import sys
def get_hf_chat_template(model_id, variant=None):
try:
# Use huggingface_hub library if available.
# Allows access to gated models if the user has access and ran `huggingface-cli login`.
from huggingface_hub import hf_hub_download
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
config_str = f.read()
except ImportError:
import requests
assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
if response.status_code == 401:
raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
response.raise_for_status()
config_str = response.text
try:
config = json.loads(config_str)
except json.JSONDecodeError:
# Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
# (Remove extra '}' near the end of the file)
config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
chat_template = config['chat_template']
if isinstance(chat_template, str):
return chat_template
else:
variants = {
ct['name']: ct['template']
for ct in chat_template
}
def format_variants():
return ', '.join(f'"{v}"' for v in variants.keys())
if variant is None:
if 'default' not in variants:
raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
variant = 'default'
sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n')
elif variant not in variants:
raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
return variants[variant]
def main(args):
if len(args) < 1:
raise ValueError("Please provide a model ID and an optional variant name")
model_id = args[0]
variant = None if len(args) < 2 else args[1]
template = get_hf_chat_template(model_id, variant)
sys.stdout.write(template)
if __name__ == '__main__':
main(sys.argv[1:])

View File

@@ -29,7 +29,7 @@ add_library(llama
unicode-data.cpp
)
target_include_directories(llama PUBLIC . ../include)
target_include_directories(llama PUBLIC . ../include ../common)
target_compile_features (llama PUBLIC cxx_std_17) # don't bump
target_link_libraries(llama PUBLIC ggml)

View File

@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
@@ -1443,10 +1444,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};
LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
std::string LLM_KV::operator()(llm_kv kv) const {
return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
: ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
}
std::string LLM_TN_IMPL::str() const {

View File

@@ -177,6 +177,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV,
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
LLM_KV_TOKENIZER_FIM_PRE_ID,
LLM_KV_TOKENIZER_FIM_SUF_ID,
LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -335,9 +336,10 @@ enum llm_tensor_layer {
};
struct LLM_KV {
LLM_KV(llm_arch arch);
LLM_KV(llm_arch arch, const char * suffix = nullptr);
llm_arch arch;
const char * suffix;
std::string operator()(llm_kv kv) const;
};

View File

@@ -7,6 +7,7 @@
#include <cstring>
#include <climits>
#include <stdexcept>
#include <cerrno>
#ifdef __has_include
#if __has_include(<unistd.h>)

View File

@@ -3955,8 +3955,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
return model->size();
}
const char * llama_model_chat_template(const struct llama_model * model) {
const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
: LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
const auto & it = model->gguf_kv.find(key);
if (it == model->gguf_kv.end()) {
return nullptr;
}

View File

@@ -7,6 +7,16 @@
#include "llama.h"
#include "common.h"
#include "chat-template.hpp"
static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32
static const std::regex nl_regex("\r\n");
return std::regex_replace(s, nl_regex, "\n");
#else
return s;
#endif
}
int main(void) {
std::vector<llama_chat_message> conversation {
@@ -21,156 +31,228 @@ int main(void) {
std::string name;
std::string template_str;
std::string expected_output;
std::string expected_output_jinja;
std::string bos_token = "";
std::string eos_token = "";
bool supported_with_jinja = true;
};
std::vector<TestCase> test_cases {
{
/* .name= */ "teknium/OpenHermes-2.5-Mistral-7B",
/* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
/* .expected_output= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "TheBloke/FusionNet_34Bx2_MoE-AWQ",
/* .template_str= */ "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
/* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
/* .template_str= */ "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <<SYS>>\\n' + messages[idx]['content'] + '\\n<</SYS>>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}",
/* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
/* .expected_output_jinja= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
/* .bos_token= */ "<s>",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "bofenghuang/vigogne-2-70b-chat",
/* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
/* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there</s>[INST] Who are you [/INST]I am an assistant</s>[INST] Another question [/INST]",
/* .expected_output_jinja= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "mlabonne/AlphaMonarch-7B",
/* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
/* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
/* .bos_token= */ "<s>",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "google/gemma-7b-it",
/* .template_str= */ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}",
/* .expected_output= */ "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
/* .expected_output= */ "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
/* .expected_output_jinja= */ "<start_of_turn>user\nYou are a helpful assistant\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
},
{
/* .name= */ "OrionStarAI/Orion-14B-Chat",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "openchat/openchat-3.5-0106",
// The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d
// So we match against the included template but implement the suggested version.
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
/* .expected_output= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
/* .expected_output= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
/* .expected_output_jinja= */ "GPT4 Correct System: You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
},
{
/* .name= */ "deepseek-ai/deepseek-coder-33b-instruct",
/* .template_str= */ "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
/* .expected_output= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
/* .expected_output_jinja= */ "",
},
{
/* .name= */ "eachadea/vicuna-13b-1.1",
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
/* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
/* .expected_output= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "Orca-Vicuna",
// No template included in tokenizer_config.json, so this template likely needs to be manually set.
/* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
/* .expected_output= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there</s>\nUSER: Who are you\nASSISTANT: I am an assistant </s>\nUSER: Another question\nASSISTANT:",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "CohereForAI/c4ai-command-r-plus",
/* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
/* .expected_output= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
/* .expected_output_jinja= */ "",
},
{
/* .name= */ "Llama-3",
/* .template_str= */ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
/* .expected_output= */ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
/* .expected_output_jinja= */ "",
},
{
/* .name= */ "Phi-3-mini",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
/* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
},
{
/* .name= */ "Phi-3-small",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
/* .expected_output_jinja= */ "",
},
{
/* .name= */ "Phi-3-medium",
/* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
/* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
},
{
/* .name= */ "Phi-3-vision",
/* .template_str= */ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
/* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "ChatGLM3",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
/* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
/* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
/* .expected_output_jinja= */ "[gMASK]sop<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
},
{
/* .name= */ "ChatGLM4",
/* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
/* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
/* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
/* .expected_output= */ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "DeepSeek-V2",
/* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
/* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<end▁of▁sentence>User: Who are you\n\nAssistant: I am an assistant <end▁of▁sentence>User: Another question\n\nAssistant:",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "<end▁of▁sentence>",
},
{
/* .name= */ "ibm-granite/granite-3.0-8b-instruct",
/* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}",
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
/* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
/* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
},
{
/* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)",
/* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
/* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] Another question[/INST]",
/* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant</s>[INST] You are a helpful assistant\n\nAnother question[/INST]",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
/* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n",
/* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
/* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]Another question[/INST]",
/* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there</s>[INST]Who are you[/INST] I am an assistant </s>[INST]You are a helpful assistant\n\nAnother question[/INST]",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)",
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there</s>[INST] Who are you[/INST] I am an assistant </s>[INST] Another question[/INST]",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "</s>",
},
{
/* .name= */ "ai-sage/GigaChat-20B-A3B-instruct",
/* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}",
/* .expected_output= */ "<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
/* .supported_with_jinja= */ false, // Requires additional_special_tokens as extra context
},
{
/* .name= */ "Infinigence/Megrez-3B-Instruct",
/* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
/* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "phi-4",
/* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
/* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
/* .expected_output_jinja= */ "",
/* .bos_token= */ "",
/* .eos_token= */ "",
},
};
std::vector<char> formatted_chat(1024);
@@ -190,6 +272,7 @@ int main(void) {
// test invalid chat template
res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size());
assert(res < 0);
const auto add_generation_prompt = true;
for (const auto & test_case : test_cases) {
printf("\n\n=== %s ===\n\n", test_case.name.c_str());
@@ -198,26 +281,59 @@ int main(void) {
test_case.template_str.c_str(),
conversation.data(),
conversation.size(),
true,
add_generation_prompt,
formatted_chat.data(),
formatted_chat.size()
);
formatted_chat.resize(res);
std::string output(formatted_chat.data(), formatted_chat.size());
printf("%s\n", output.c_str());
printf("-------------------------\n");
assert(output == test_case.expected_output);
if (output != test_case.expected_output) {
printf("Expected:\n%s\n", test_case.expected_output.c_str());
printf("-------------------------\n");
printf("Actual:\n%s\n", output.c_str());
fflush(stdout);
assert(output == test_case.expected_output);
}
}
json messages = json::array();
for (const auto & msg : conversation) {
messages.push_back({
{"role", msg.role},
{"content", msg.content},
});
}
for (const auto & test_case : test_cases) {
if (!test_case.supported_with_jinja) {
continue;
}
printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
try {
minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token);
auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt));
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
if (output != expected_output) {
printf("Expected:\n%s\n", expected_output.c_str());
printf("-------------------------\n");
printf("Actual:\n%s\n", output.c_str());
fflush(stdout);
assert(output == expected_output);
}
} catch (const std::exception & e) {
printf("ERROR: %s\n", e.what());
assert(false);
}
}
// test llama_chat_format_single for system message
printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
std::vector<common_chat_msg> chat2;
common_chat_msg sys_msg{"system", "You are a helpful assistant"};
auto fmt_sys = [&](std::string tmpl) {
auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
auto fmt_sys = [&](std::string tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", "");
auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
return output;
};
@@ -241,9 +357,10 @@ int main(void) {
chat2.push_back({"assistant", "I am assistant"});
common_chat_msg new_msg{"user", "How are you"};
auto fmt_single = [&](std::string tmpl) {
auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
auto fmt_single = [&](std::string tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", "");
auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n");
return output;
};
@@ -258,7 +375,5 @@ int main(void) {
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
printf("Test chat templates: OK\n");
return 0;
}