mirror of https://github.com/ggml-org/llama.cpp.git
synced 2026-05-04 16:14:06 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 408ff524b4 |  |
|  | 5143fa895e |  |
|  | 3a550b5ca4 |  |
|  | a81283820a |  |
|  | c610b6c11b |  |
|  | 5d6688de08 |  |
|  | 4fd1242bef |  |
|  | b2426e469e |  |
|  | 9e2b1e83c6 |  |
|  | fb15d649ed |  |
|  | 856ed0947f |  |
@@ -1263,6 +1263,18 @@ static std::string list_builtin_chat_templates() {
    return msg.str();
}

static bool is_truthy(const std::string & value) {
    return value == "on" || value == "enabled" || value == "1";
}

static bool is_falsey(const std::string & value) {
    return value == "off" || value == "disabled" || value == "0";
}

static bool is_autoy(const std::string & value) {
    return value == "auto" || value == "-1";
}

common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
    // load dynamic backends
    ggml_backend_load_all();

@@ -1544,21 +1556,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.n_chunks = value;
        }
    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
    add_opt(common_arg(
        {"-fa", "--flash-attn"}, "FA",
        string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
        [](common_params & params, const std::string & value) {
            if (value == "on" || value == "enabled" || value == "1") {
                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
            } else if (value == "off" || value == "disabled" || value == "0") {
                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
            } else if (value == "auto" || value == "-1") {
                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
            } else {
                throw std::runtime_error(string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
            }
        }
    ).set_env("LLAMA_ARG_FLASH_ATTN"));
    add_opt(common_arg({ "-fa", "--flash-attn" }, "[on|off|auto]",
        string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')",
                      llama_flash_attn_type_name(params.flash_attn_type)),
        [](common_params & params, const std::string & value) {
            if (is_truthy(value)) {
                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
            } else if (is_falsey(value)) {
                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
            } else if (is_autoy(value)) {
                params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
            } else {
                throw std::runtime_error(
                    string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
            }
        }).set_env("LLAMA_ARG_FLASH_ATTN"));
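The old boolean handler is replaced by a tri-state one built on the three helpers above. A minimal sketch of how the same parsing convention could be reused elsewhere (the wrapper below is hypothetical, not part of the patch):

```cpp
// Hypothetical convenience wrapper over is_truthy/is_falsey/is_autoy,
// shown only to illustrate the tri-state convention introduced here.
enum class tristate { enabled, disabled, auto_detect };

static tristate parse_tristate(const std::string & value) {
    if (is_truthy(value)) { return tristate::enabled;     }
    if (is_falsey(value)) { return tristate::disabled;    }
    if (is_autoy(value))  { return tristate::auto_detect; }
    throw std::invalid_argument("expected 'on', 'off', or 'auto', got: '" + value + "'");
}
```

With the registration above, the option accepts e.g. `-fa auto` on the command line, or `LLAMA_ARG_FLASH_ATTN=on` via the environment.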
    add_opt(common_arg(
        {"-p", "--prompt"}, "PROMPT",
        "prompt to start generation with; for system message, use -sys",
@@ -3134,13 +3146,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            common_log_set_file(common_log_main(), value.c_str());
        }
    ));
    add_opt(common_arg(
        {"--log-colors"},
        "Enable colored logging",
        [](common_params &) {
            common_log_set_colors(common_log_main(), true);
        }
    ).set_env("LLAMA_LOG_COLORS"));
    add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
        "Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
        "'auto' enables colors when output is to a terminal",
        [](common_params &, const std::string & value) {
            if (is_truthy(value)) {
                common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
            } else if (is_falsey(value)) {
                common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
            } else if (is_autoy(value)) {
                common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
            } else {
                throw std::invalid_argument(
                    string_format("error: unknown value for --log-colors: '%s'\n", value.c_str()));
            }
        }).set_env("LLAMA_LOG_COLORS"));
    add_opt(common_arg(
        {"-v", "--verbose", "--log-verbose"},
        "Set verbosity level to infinity (i.e. log all messages, useful for debugging)",
@@ -623,6 +623,7 @@ const char * common_chat_format_name(common_chat_format format) {
        case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
        case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
        case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
        case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
        default:
            throw std::runtime_error("Unknown chat format");
    }
@@ -1184,6 +1185,67 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
    });
    return data;
}

static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
    common_chat_params data;

    // Generate the prompt using the apply() function with the template
    data.prompt = apply(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2;

    // Handle thinking tags appropriately based on inputs.enable_thinking
    if (string_ends_with(data.prompt, "<think>\n")) {
        if (!inputs.enable_thinking) {
            data.prompt += "</think>";
        } else {
            data.thinking_forced_open = true;
        }
    }

    // When tools are present, build grammar for the <TOOLCALL> format, similar to CommandR, but without tool call ID
    if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
        data.grammar_lazy = true;
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            auto schemas = json::array();
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                schemas.push_back({
                    { "type", "object" },
                    { "properties",
                      {
                          { "name",
                            {
                                { "type", "string" },
                                { "const", function.at("name") },
                            } },
                          { "arguments", function.at("parameters") },
                      } },
                    { "required", json::array({ "name", "arguments" }) },
                });
            });
            auto schema = json{
                { "type", "array" },
                { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
                { "minItems", 1 },
            };
            if (!inputs.parallel_tool_calls) {
                schema["maxItems"] = 1;
            }
            builder.add_rule("root",
                std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
                "\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) +
                " \"</TOOLCALL>\"");
        });
        data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
            // If thinking_forced_open, then we capture the </think> tag in the grammar,
            // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
            std::string(data.thinking_forced_open ?
                "[\\s\\S]*?(</think>\\s*)" :
                "(?:<think>[\\s\\S]*?</think>\\s*)?") +
            "(<TOOLCALL>)[\\s\\S]*" });
    }
    return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
    if (!builder.syntax().parse_tool_calls) {
        builder.add_content(builder.consume_rest());
@@ -1830,7 +1892,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
            // If thinking_forced_open, then we capture the </think> tag in the grammar,
            // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
            std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
                "(\\s*"
                "\\s*("
                    "(?:<tool_call>"
                    "|<function"
                    "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
@@ -2060,6 +2122,33 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
    }
}

static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
    // Parse thinking tags
    builder.try_parse_reasoning("<think>", "</think>");
    if (!builder.syntax().parse_tool_calls) {
        builder.add_content(builder.consume_rest());
        return;
    }

    // Look for tool calls
    static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
    if (auto res = builder.try_find_regex(tool_call_regex)) {
        builder.move_to(res->groups[0].end);

        // Expect JSON array of tool calls
        auto tool_calls_data = builder.consume_json();
        if (tool_calls_data.json.is_array()) {
            if (!builder.try_consume_literal("</TOOLCALL>")) {
                throw common_chat_msg_partial_exception("Incomplete tool call");
            }
            builder.add_tool_calls(tool_calls_data.json);
        } else {
            throw common_chat_msg_partial_exception("Incomplete tool call");
        }
    }
    builder.add_content(builder.consume_rest());
}
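For reference, a hypothetical assistant turn in this format (tool name and arguments invented for illustration): `try_parse_reasoning` lifts the `<think>` block into the reasoning field, and the JSON array between `<TOOLCALL>` and `</TOOLCALL>` becomes the parsed tool calls:

```
<think>
The user asked for the weather, so a tool call is needed.
</think>
<TOOLCALL>[{"name": "get_weather", "arguments": {"city": "Paris"}}]</TOOLCALL>
```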

static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
    // Parse thinking tags first - this handles the main reasoning content
    builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@@ -2293,6 +2382,11 @@ static common_chat_params common_chat_templates_apply_jinja(
        return common_chat_params_init_seed_oss(tmpl, params, inputs);
    }

    // Nemotron v2
    if (src.find("<SPECIAL_10>") != std::string::npos) {
        return common_chat_params_init_nemotron_v2(tmpl, params);
    }

    // Use generic handler when mixing tools + JSON schema.
    // TODO: support that mix in handlers below.
    if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2454,6 +2548,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
        case COMMON_CHAT_FORMAT_SEED_OSS:
            common_chat_parse_seed_oss(builder);
            break;
        case COMMON_CHAT_FORMAT_NEMOTRON_V2:
            common_chat_parse_nemotron_v2(builder);
            break;
        default:
            throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
    }

@@ -112,6 +112,7 @@ enum common_chat_format {
    COMMON_CHAT_FORMAT_GRANITE,
    COMMON_CHAT_FORMAT_GPT_OSS,
    COMMON_CHAT_FORMAT_SEED_OSS,
    COMMON_CHAT_FORMAT_NEMOTRON_V2,

    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@@ -4,17 +4,52 @@
#include <condition_variable>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <sstream>
#include <thread>
#include <vector>

#if defined(_WIN32)
#    include <io.h>
#    include <windows.h>
#    define isatty _isatty
#    define fileno _fileno
#else
#    include <unistd.h>
#endif // defined(_WIN32)

int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;

void common_log_set_verbosity_thold(int verbosity) {
    common_log_verbosity_thold = verbosity;
}

// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
    // Check NO_COLOR environment variable (https://no-color.org/)
    if (const char * no_color = std::getenv("NO_COLOR")) {
        if (no_color[0] != '\0') {
            return false;
        }
    }

    // Check TERM environment variable
    if (const char * term = std::getenv("TERM")) {
        if (std::strcmp(term, "dumb") == 0) {
            return false;
        }
    }

    // Check if stdout and stderr are connected to a terminal
    // We check both because log messages can go to either
    bool stdout_is_tty = isatty(fileno(stdout));
    bool stderr_is_tty = isatty(fileno(stderr));

    return stdout_is_tty || stderr_is_tty;
}

static int64_t t_us() {
    return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}
@@ -353,6 +388,11 @@ struct common_log * common_log_init() {

struct common_log * common_log_main() {
    static struct common_log log;
    static std::once_flag init_flag;
    std::call_once(init_flag, [&]() {
        // Set default to auto-detect colors
        log.set_colors(common_log_should_use_colors_auto());
    });

    return &log;
}
@@ -380,8 +420,19 @@ void common_log_set_file(struct common_log * log, const char * file) {
    log->set_file(file);
}

void common_log_set_colors(struct common_log * log, bool colors) {
    log->set_colors(colors);
void common_log_set_colors(struct common_log * log, log_colors colors) {
    if (colors == LOG_COLORS_AUTO) {
        log->set_colors(common_log_should_use_colors_auto());
        return;
    }

    if (colors == LOG_COLORS_DISABLED) {
        log->set_colors(false);
        return;
    }

    GGML_ASSERT(colors == LOG_COLORS_ENABLED);
    log->set_colors(true);
}

void common_log_set_prefix(struct common_log * log, bool prefix) {
common/log.h (14 lines changed)

@@ -24,6 +24,12 @@
#define LOG_DEFAULT_DEBUG 1
#define LOG_DEFAULT_LLAMA 0

enum log_colors {
    LOG_COLORS_AUTO     = -1,
    LOG_COLORS_DISABLED =  0,
    LOG_COLORS_ENABLED  =  1,
};
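Callers now select a mode through the enum rather than a bare bool. A small usage sketch built only from the functions shown in this diff:

```cpp
// Resolve colors from the terminal/environment at startup,
// mirroring the new default behavior:
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);

// Or force them off unconditionally:
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
```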

// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
// set via common_log_set_verbosity()
extern int common_log_verbosity_thold;
@@ -65,10 +71,10 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
//  D - debug (stderr, V = LOG_DEFAULT_DEBUG)
//

void common_log_set_file      (struct common_log * log, const char * file);       // not thread-safe
void common_log_set_colors    (struct common_log * log, bool colors);             // not thread-safe
void common_log_set_prefix    (struct common_log * log, bool prefix);             // whether to output prefix to each log
void common_log_set_timestamps(struct common_log * log, bool timestamps);         // whether to output timestamps in the prefix
void common_log_set_file      (struct common_log * log, const char * file);       // not thread-safe
void common_log_set_colors    (struct common_log * log, log_colors colors);       // not thread-safe
void common_log_set_prefix    (struct common_log * log, bool prefix);             // whether to output prefix to each log
void common_log_set_timestamps(struct common_log * log, bool timestamps);         // whether to output timestamps in the prefix

// helper macros for logging
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
@@ -5122,6 +5122,15 @@ class Gemma3Model(TextModel):
        return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("Gemma3TextModel")
class EmbeddingGemma(Gemma3Model):
    model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        self._try_set_pooling_type()


@ModelBase.register("Gemma3ForConditionalGeneration")
class Gemma3VisionModel(MmprojModel):
    def set_gguf_parameters(self):
@@ -333,17 +333,17 @@ static void print_params(struct my_llama_hparams * params) {
}

static void print_tensor_info(const struct ggml_context * ctx) {
    for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
    for (auto * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
        LOG_INF("%s: Allocating ", __func__);
        int64_t total = 1;
        int i = 0;
        for (; i < ggml_n_dims(t); ++i) {
            if (i > 0) LOG("x ");
            LOG("[%" PRId64 "] ", t->ne[i]);
            if (i > 0) { LOG_INF("x "); }
            LOG_INF("[%" PRId64 "] ", t->ne[i]);
            total *= t->ne[i];
        }
        if (i > 1) LOG("= [%" PRId64 "] ", total);
        LOG("float space for %s\n", ggml_get_name(t));
        if (i > 1) { LOG_INF("= [%" PRId64 "] ", total); }
        LOG_INF("float space for %s\n", ggml_get_name(t));
    }
}
@@ -7,7 +7,7 @@ base_model:
Recommended way to run this model:

```sh
llama-server -hf {namespace}/{model_name}-GGUF
llama-server -hf {namespace}/{model_name}-GGUF --embeddings
```

Then the endpoint can be accessed at http://localhost:8080/embedding, for example:
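A plausible request against that endpoint (illustrative only; the exact JSON shape of the `/embedding` route is not shown in this diff):

```sh
curl http://localhost:8080/embedding \
    -H "Content-Type: application/json" \
    -d '{"content": "Hello, world"}'
```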
@@ -570,6 +570,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
//
// n/d = (mulhi(n, mp) + n) >> L;
static const uint3 init_fastdiv_values(uint32_t d) {
    GGML_ASSERT(d != 0);

    // compute L = ceil(log2(d));
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
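The hunk above cuts off mid-function, but the scheme is standard 32-bit magic-number division. A self-contained sketch of how such values could be computed and consumed on the device side — assuming, as the kernels in this diff suggest (`ne2.z`, `fastdiv`, `fastmodulo`), that the `uint3` packs `{mp, L, d}`:

```cuda
#include <cstdint>

// Host: precompute {mp, L, d} so that for all 32-bit n:
//   n / d == (__umulhi(n, mp) + n) >> L
// Valid for 0 < d <= 2^31 (keeps 32 + L <= 63 for the 64-bit constant).
static uint3 init_fastdiv_values_sketch(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        ++L;
    }
    // mp = ceil(2^(32+L) / d) mod 2^32; the truncated-away 2^32
    // reappears on the device as the "+ n" term.
    const uint32_t mp = (uint32_t) (((uint64_t{ 1 } << (32 + L)) + d - 1) / d);
    return make_uint3(mp, L, d);
}

// Device: division and remainder without hardware integer division.
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fd) {
    return (__umulhi(n, fd.x) + n) >> fd.y;
}

static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fd) {
    return n - fastdiv(n, fd) * fd.z;
}
```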
@@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
    const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
    const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
    const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
    const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
    const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
    const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
    const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
    const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {

    constexpr int qk = ggml_cuda_type_traits<type>::qk;
    constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;

    // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
    const int channel_dst = blockIdx.y;
    const int channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]          : channel_dst / channel_ratio;
    const int channel_y   = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
    const int sample_dst  = blockIdx.z;
    const int sample_x    = sample_dst / sample_ratio;
    const int sample_y    = sample_dst;
    const uint32_t channel_dst = blockIdx.y;
    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
    const uint32_t sample_dst  = blockIdx.z;
    const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
    const uint32_t sample_y    = sample_dst;

    // partial sum for each thread
    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);

    const int channel_ratio = nchannels_dst / nchannels_x;
    const int sample_ratio  = nsamples_dst  / nsamples_x;
    const uint3 nchannels_y_fd  = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
    const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);

    const int device    = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;
@@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst(

    GGML_ASSERT(!ids || ncols_dst == 1);
    switch (ncols_dst) {
        case 1:
        {
        case 1: {
            constexpr int c_ncols_dst = 1;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            break;
        }
        case 2:
        {
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 2: {
            constexpr int c_ncols_dst = 2;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            break;
        }
        case 3:
        {
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 3: {
            constexpr int c_ncols_dst = 3;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            break;
        }
        case 4:
        {
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 4: {
            constexpr int c_ncols_dst = 4;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            break;
        }
        case 5:
        {
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 5: {
            constexpr int c_ncols_dst = 5;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            break;
        }
        case 6:
        {
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 6: {
            constexpr int c_ncols_dst = 6;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            break;
        }
        case 7:
        {
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 7: {
            constexpr int c_ncols_dst = 7;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            break;
        }
        case 8:
        {
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 8: {
            constexpr int c_ncols_dst = 8;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            break;
        }
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        default:
            GGML_ABORT("fatal error");
            break;
@@ -1,26 +1,27 @@
#include "quantize.cuh"
#include <cstdint>

__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
static __global__ void quantize_q8_1(
    const float * __restrict__ x, void * __restrict__ vy,
    const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
    const int64_t ne0, const int ne1, const int ne2) {
    const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
    const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i0 >= ne0) {
        return;
    }

    const int64_t i3 = fastdiv(blockIdx.z, ne2);
    const int64_t i2 = blockIdx.z - i3*ne2.z;
    const int64_t i1 = blockIdx.y;
    const int64_t i2 = blockIdx.z % ne2;
    const int64_t i3 = blockIdx.z / ne2;

    const int64_t & i00 = i0;
    const int64_t & i01 = i1;
    const int64_t & i02 = i2;
    const int64_t & i03 = i3;

    const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
    const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;

    block_q8_1 * y = (block_q8_1 *) vy;

@@ -31,10 +32,10 @@ static __global__ void quantize_q8_1(
    float amax = fabsf(xi);
    float sum  = xi;

    amax = warp_reduce_max(amax);
    sum  = warp_reduce_sum(sum);
    amax = warp_reduce_max<QK8_1>(amax);
    sum  = warp_reduce_sum<QK8_1>(sum);

    const float  d = amax / 127;
    const float  d = amax / 127.0f;
    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);

    y[ib].qs[iqs] = q;
@@ -43,8 +44,7 @@ static __global__ void quantize_q8_1(
        return;
    }

    reinterpret_cast<half&>(y[ib].ds.x) = d;
    reinterpret_cast<half&>(y[ib].ds.y) = sum;
    y[ib].ds = make_half2(d, sum);
}

template <mmq_q8_1_ds_layout ds_layout>
@@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda(
    GGML_ASSERT(!ids);
    GGML_ASSERT(ne0 % QK8_1 == 0);

    const uint3 ne2_fastdiv = init_fastdiv_values(ne2);

    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
    GGML_UNUSED(type_src0);
}
@@ -407,6 +407,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
@@ -1439,6 +1440,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,  mul_mm_id_map0_f16_ne20_4,  has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,  mul_mm_id_map0_f16_ne20_6,  has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,  mul_mm_id_map0_f16_ne20_8,  has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10, mul_mm_id_map0_f16_ne20_10, has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
@@ -3979,6 +3981,7 @@ static int ggml_metal_encode_node(
    case  4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
    case  6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
    case  8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
    case 10: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10].pipeline; break;
    case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
    default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
}

@@ -7618,6 +7618,7 @@ template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;

template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
@@ -1166,50 +1166,51 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo
    ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const
}

struct gguf_writer {
    std::vector<int8_t> & buf;
struct gguf_writer_base {
    size_t written_bytes {0u};

    gguf_writer(std::vector<int8_t> & buf) : buf(buf) {}
    ~gguf_writer_base(void) {}

    // we bet on devirtualization
    virtual void write(int8_t val) = 0;
    virtual void write(const std::vector<int8_t> & val) = 0;
    virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0;

    template <typename T>
    void write(const T & val) const {
    void write(const T & val) {
        for (size_t i = 0; i < sizeof(val); ++i) {
            buf.push_back(reinterpret_cast<const int8_t *>(&val)[i]);
            write(reinterpret_cast<const int8_t *>(&val)[i]);
        }
    }

    void write(const std::vector<int8_t> & val) const {
        buf.insert(buf.end(), val.begin(), val.end());
    }

    void write(const bool & val) const {
    void write(const bool & val) {
        const int8_t val8 = val ? 1 : 0;
        write(val8);
    }

    void write(const std::string & val) const {
    void write(const std::string & val) {
        {
            const uint64_t n = val.length();
            write(n);
        }
        for (size_t i = 0; i < val.length(); ++i) {
            buf.push_back(reinterpret_cast<const int8_t *>(val.data())[i]);
            write((val.data())[i]);
        }
    }

    void write(const char * val) const {
    void write(const char * val) {
        write(std::string(val));
    }

    void write(const enum ggml_type & val) const {
    void write(const enum ggml_type & val) {
        write(int32_t(val));
    }

    void write(const enum gguf_type & val) const {
    void write(const enum gguf_type & val) {
        write(int32_t(val));
    }

    void write(const struct gguf_kv & kv) const {
    void write(const struct gguf_kv & kv) {
        const uint64_t ne = kv.get_ne();

        write(kv.get_key());
@@ -1250,7 +1251,7 @@ struct gguf_writer {
        }
    }

    void write_tensor_meta(const struct gguf_tensor_info & info) const {
    void write_tensor_meta(const struct gguf_tensor_info & info) {
        write(info.t.name);

        const uint32_t n_dims = ggml_n_dims(&info.t);
@@ -1263,14 +1264,33 @@ struct gguf_writer {
        write(info.offset);
    }

    void pad(const size_t alignment) const {
        while (buf.size() % alignment != 0) {
    void pad(const size_t alignment) {
        while (written_bytes % alignment != 0) {
            const int8_t zero = 0;
            write(zero);
        }
    }
};

    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const {
// vector buffer based writer
struct gguf_writer_buf final : public gguf_writer_base {
    std::vector<int8_t> & buf;

    gguf_writer_buf(std::vector<int8_t> & buf) : buf(buf) {}

    using gguf_writer_base::write;

    void write(const int8_t val) override {
        buf.push_back(val);
        written_bytes++;
    }

    void write(const std::vector<int8_t> & val) override {
        buf.insert(buf.end(), val.begin(), val.end());
        written_bytes += val.size();
    }

    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {
        GGML_ASSERT(buf.size() - offset_data == info.offset);

        GGML_ASSERT(ggml_is_contiguous(&info.t));
@@ -1284,14 +1304,58 @@ struct gguf_writer {
            GGML_ASSERT(info.t.data);
            memcpy(buf.data() + offset, info.t.data, nbytes);
        }
        written_bytes += nbytes;

        pad(alignment);
    }
};

void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
    const struct gguf_writer gw(buf);
// file based writer
struct gguf_writer_file final : public gguf_writer_base {
    FILE * file;

    gguf_writer_file(FILE* file) : file(file) {}

    using gguf_writer_base::write;

    void write(const int8_t val) override {
        const auto real_val = static_cast<uint8_t>(val);
        const auto ret = fputc(real_val, file);
        written_bytes++;
        if (ret != real_val) {
            throw std::runtime_error("unexpected fputc result '" + std::to_string(ret) + "' instead of '" + std::to_string((int)real_val) + "'");
        }
    }

    void write(const std::vector<int8_t> & val) override {
        const auto ret = fwrite(val.data(), 1, val.size(), file);
        written_bytes += val.size();
        if (ret != val.size()) {
            throw std::runtime_error("unexpected fwrite number of bytes written, '" + std::to_string(ret) + "' instead of '" + std::to_string(val.size()) + "'");
        }
    }

    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {
        GGML_ASSERT(written_bytes - offset_data == info.offset);

        GGML_ASSERT(ggml_is_contiguous(&info.t));
        const size_t nbytes = ggml_nbytes(&info.t);

        std::vector<int8_t> buf(nbytes);
        if (info.t.buffer) {
            ggml_backend_tensor_get(&info.t, buf.data(), 0, nbytes);
        } else {
            GGML_ASSERT(info.t.data);
            memcpy(buf.data(), info.t.data, nbytes);
        }
        write(buf);

        pad(alignment);
    }
};

template <typename writer_t>
static void gguf_write_out(const struct gguf_context * ctx, writer_t & gw, bool only_meta) {
    const int64_t n_kv      = gguf_get_n_kv(ctx);
    const int64_t n_tensors = gguf_get_n_tensors(ctx);

@@ -1321,7 +1385,7 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & bu
        return;
    }

    const size_t offset_data = gw.buf.size();
    const size_t offset_data = gw.written_bytes;

    // write tensor data
    for (int64_t i = 0; i < n_tensors; ++i) {
@@ -1329,6 +1393,11 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & bu
    }
}

void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
    gguf_writer_buf gw(buf);
    gguf_write_out(ctx, gw, only_meta);
}

bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
    FILE * file = ggml_fopen(fname, "wb");

@@ -1337,11 +1406,17 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
        return false;
    }

    std::vector<int8_t> buf;
    gguf_write_to_buf(ctx, buf, only_meta);
    const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size();
    try {
        gguf_writer_file gw(file);
        gguf_write_out(ctx, gw, only_meta);
    } catch (const std::runtime_error& ex) {
        GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what());
        fclose(file);
        return false;
    }

    fclose(file);
    return ok;
    return true;
}

size_t gguf_get_meta_size(const struct gguf_context * ctx) {
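The `// we bet on devirtualization` comment refers to calling the virtual `write` overloads through a template parameter of a concrete `final` type: since `gguf_write_out` is instantiated per writer, the compiler knows the dynamic type and can resolve (and usually inline) the calls statically. A stripped-down illustration of the same pattern, with hypothetical types not taken from the patch:

```cpp
#include <cstdint>
#include <vector>

struct writer_base {
    virtual void put(int8_t v) = 0;
    virtual ~writer_base() = default;
};

struct vec_writer final : writer_base {
    std::vector<int8_t> out;
    void put(int8_t v) override { out.push_back(v); }
};

// writer_t is deduced as the concrete final class, so the optimizer can
// devirtualize (and inline) every gw.put() call in this instantiation.
template <typename writer_t>
void write_all(writer_t & gw, const std::vector<int8_t> & data) {
    for (int8_t v : data) {
        gw.put(v);
    }
}
```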
@@ -340,6 +340,7 @@ class MODEL_ARCH(IntEnum):
    GEMMA2           = auto()
    GEMMA3           = auto()
    GEMMA3N          = auto()
    GEMMA_EMBEDDING  = auto()
    STARCODER2       = auto()
    RWKV6            = auto()
    RWKV6QWEN2       = auto()
@@ -674,6 +675,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.GEMMA2:          "gemma2",
    MODEL_ARCH.GEMMA3:          "gemma3",
    MODEL_ARCH.GEMMA3N:         "gemma3n",
    MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
    MODEL_ARCH.STARCODER2:      "starcoder2",
    MODEL_ARCH.RWKV6:           "rwkv6",
    MODEL_ARCH.RWKV6QWEN2:      "rwkv6qwen2",
@@ -1719,6 +1721,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.LAUREL_R,
        MODEL_TENSOR.LAUREL_POST_NORM,
    ],
    MODEL_ARCH.GEMMA_EMBEDDING: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_Q_NORM,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_K_NORM,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_POST_NORM,
        MODEL_TENSOR.FFN_PRE_NORM,
        MODEL_TENSOR.FFN_POST_NORM,
    ],
    MODEL_ARCH.STARCODER2: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
@@ -14,6 +14,7 @@ class TensorNameMap:
            "transformer.word_embeddings",               # falcon
            "word_embeddings",                           # bloom
            "model.embed_tokens",                        # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 plamo2 granite-hybrid
            "embed_tokens",                              # embeddinggemma
            "tok_embeddings",                            # llama-pth
            "embeddings.word_embeddings",                # bert nomic-bert
            "language_model.embedding.word_embeddings",  # persimmon
@@ -141,6 +142,7 @@ class TensorNameMap:
            "rwkv.blocks.{bid}.ln1",                        # rwkv6
            "model.layers.{bid}.ln1",                       # rwkv7
            "model.layers.{bid}.input_layernorm",           # llama4
            "layers.{bid}.input_layernorm",                 # embeddinggemma
            "transformer_encoder.{bid}.attention_norm",     # neobert
            "model.layers.{bid}.operator_norm",             # lfm2
            "model.transformer.blocks.{bid}.attn_norm",     # llada
@@ -179,6 +181,7 @@ class TensorNameMap:
        # Attention query
        MODEL_TENSOR.ATTN_Q: (
            "model.layers.{bid}.self_attn.q_proj",          # llama-hf nemotron olmoe olmo2 phimoe
            "layers.{bid}.self_attn.q_proj",                # embeddinggemma
            "model.layers.{bid}.self_attn.q_proj_no_perm",  # llama-custom
            "layers.{bid}.attention.wq",                    # llama-pth
            "encoder.layer.{bid}.attention.self.query",     # bert
@@ -197,6 +200,7 @@ class TensorNameMap:
        # Attention key
        MODEL_TENSOR.ATTN_K: (
            "model.layers.{bid}.self_attn.k_proj",          # llama-hf nemotron olmoe olmo2 phimoe
            "layers.{bid}.self_attn.k_proj",                # embeddinggemma
            "model.layers.{bid}.self_attn.k_proj_no_perm",  # llama-custom
            "layers.{bid}.attention.wk",                    # llama-pth
            "encoder.layer.{bid}.attention.self.key",       # bert
@@ -216,6 +220,7 @@ class TensorNameMap:
        # Attention value
        MODEL_TENSOR.ATTN_V: (
            "model.layers.{bid}.self_attn.v_proj",          # llama-hf nemotron olmoe olmo2 phimoe
            "layers.{bid}.self_attn.v_proj",                # embeddinggemma
            "layers.{bid}.attention.wv",                    # llama-pth
            "encoder.layer.{bid}.attention.self.value",     # bert
            "transformer.layer.{bid}.attention.v_lin",      # distillbert
@@ -239,6 +244,7 @@ class TensorNameMap:
            "transformer.h.{bid}.self_attention.dense",     # falcon
            "h.{bid}.self_attention.dense",                 # bloom
            "model.layers.{bid}.self_attn.o_proj",          # llama-hf nemotron olmoe olmo2 phimoe
            "layers.{bid}.self_attn.o_proj",                # embeddinggemma
            "model.layers.{bid}.self_attn.out_proj",        # lfm2
            "model.layers.{bid}.self_attn.linear_attn",     # deci
            "layers.{bid}.attention.wo",                    # llama-pth
@@ -277,6 +283,7 @@ class TensorNameMap:

        MODEL_TENSOR.ATTN_POST_NORM: (
            "model.layers.{bid}.post_attention_layernorm",       # gemma2 olmo2 # ge
            "layers.{bid}.post_attention_layernorm",             # embeddinggemma
            "model.layers.{bid}.post_self_attn_layernorm",       # glm-4-0414
            "model.layers.layers.{bid}.post_mixer_norm.weight",  # plamo2
        ),
@@ -320,12 +327,14 @@ class TensorNameMap:
        # Post feed-forward norm
        MODEL_TENSOR.FFN_PRE_NORM: (
            "model.layers.{bid}.pre_feedforward_layernorm",  # gemma2
            "layers.{bid}.pre_feedforward_layernorm",        # embeddinggemma
            "model.layers.{bid}.pre_ff_layernorm.weight",
        ),

        # Post feed-forward norm
        MODEL_TENSOR.FFN_POST_NORM: (
            "model.layers.{bid}.post_feedforward_layernorm",  # gemma2 olmo2
            "layers.{bid}.post_feedforward_layernorm",        # embeddinggemma
            "model.layers.{bid}.post_mlp_layernorm",          # glm-4-0414
            "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
            "model.layers.{bid}.feed_forward.up_proj",
@@ -362,6 +371,7 @@ class TensorNameMap:
            "transformer.h.{bid}.mlp.dense_h_to_4h",    # falcon
            "h.{bid}.mlp.dense_h_to_4h",                # bloom
            "model.layers.{bid}.mlp.up_proj",           # llama-hf refact nemotron olmo2
            "layers.{bid}.mlp.up_proj",                 # embeddinggemma
            "layers.{bid}.feed_forward.w3",             # llama-pth
            "encoder.layer.{bid}.intermediate.dense",   # bert
            "transformer.layer.{bid}.ffn.lin1",         # distillbert
@@ -421,6 +431,7 @@ class TensorNameMap:
        # Feed-forward gate
        MODEL_TENSOR.FFN_GATE: (
            "model.layers.{bid}.mlp.gate_proj",         # llama-hf refact olmo2
            "layers.{bid}.mlp.gate_proj",               # embeddinggemma
            "layers.{bid}.feed_forward.w1",             # llama-pth
            "transformer.h.{bid}.mlp.w2",               # qwen
            "transformer.h.{bid}.mlp.c_fc2",            # jais
@@ -461,6 +472,7 @@ class TensorNameMap:
            "transformer.h.{bid}.mlp.dense_4h_to_h",    # falcon
            "h.{bid}.mlp.dense_4h_to_h",                # bloom
            "model.layers.{bid}.mlp.down_proj",         # llama-hf nemotron olmo2
            "layers.{bid}.mlp.down_proj",               # embeddinggemma
            "layers.{bid}.feed_forward.w2",             # llama-pth
            "encoder.layer.{bid}.output.dense",         # bert
            "transformer.layer.{bid}.ffn.lin2",         # distillbert
@@ -513,6 +525,7 @@ class TensorNameMap:
            "model.layers.{bid}.self_attn.q_layernorm",         # persimmon
            "model.layers.{bid}.self_attn.query_layernorm",     # hunyuan
            "model.layers.{bid}.self_attn.q_norm",              # cohere olmoe chameleon olmo2
            "layers.{bid}.self_attn.q_norm",                    # embeddinggemma
            "transformer.blocks.{bid}.attn.q_ln",               # sea-lion
            "encoder.layer.{bid}.attention.self.layer_norm_q",  # jina-bert-v2
            "transformer.layers.{bid}.attn.q_norm",             # openelm
@@ -525,6 +538,7 @@ class TensorNameMap:
            "model.layers.{bid}.self_attn.k_layernorm",         # persimmon
            "model.layers.{bid}.self_attn.key_layernorm",       # hunyuan
            "model.layers.{bid}.self_attn.k_norm",              # cohere olmoe chameleon olmo2
            "layers.{bid}.self_attn.k_norm",                    # embeddinggemma
            "transformer.blocks.{bid}.attn.k_ln",               # sea-lion
            "encoder.layer.{bid}.attention.self.layer_norm_k",  # jina-bert-v2
            "transformer.layers.{bid}.attn.k_norm",             # openelm
models/templates/NVIDIA-Nemotron-Nano-v2.jinja (new file, 162 lines)

@@ -0,0 +1,162 @@
{%- set ns = namespace(enable_thinking=true) -%}
{%- for message in messages -%}
{%- set content = message['content'] -%}
{%- if message['role'] == 'user' or message['role'] == 'system' -%}
{%- if '/think' in content -%}
{%- set ns.enable_thinking = true -%}
{%- elif '/no_think' in content -%}
{%- set ns.enable_thinking = false -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}

{%- if messages[0]['role'] != 'system' -%}
{%- set ns.non_tool_system_content = '' -%}
{{- '<SPECIAL_10>System
' -}}
{%- else -%}
{%- set ns.non_tool_system_content = (messages[0]['content'] | default('', true)).replace('/think', '').replace('/no_think', '').strip() -%}
{{- '<SPECIAL_10>System
' + ns.non_tool_system_content }}
{%- endif -%}

{%- if tools -%}
{%- if ns.non_tool_system_content is defined and ns.non_tool_system_content != '' -%}
{{- '

' -}}
{%- endif -%}
{{- 'You can use the following tools to assist the user if required:' -}}
{{- '
<AVAILABLE_TOOLS>[' -}}
{%- for tool in tools -%}
{{- (tool.function if tool.function is defined else tool) | tojson -}}
{{- ', ' if not loop.last else '' -}}
{%- endfor -%}
{{- ']</AVAILABLE_TOOLS>

' -}}
{{- 'If you decide to call any tool(s), use the following format:
' -}}
{{- '<TOOLCALL>[{{"name": "tool_name1", "arguments": "tool_args1"}}, ' -}}
{{- '{{"name": "tool_name2", "arguments": "tool_args2"}}]</TOOLCALL>

' -}}
{{- 'The user will execute tool-calls and return responses from tool(s) in this format:
' -}}
{{- '<TOOL_RESPONSE>[{{"tool_response1"}}, {{"tool_response2"}}]</TOOL_RESPONSE>

' -}}
{{- 'Based on the tool responses, you can call additional tools if needed, correct tool calls if any errors are found, or just respond to the user.' -}}
{%- endif -%}
{{- '

' -}}
{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%}
{%- if messages[-1]['role'] == 'assistant' -%}
{%- set ns.last_turn_assistant_content = (messages[-1]['content'] | default('', true)).strip() -%}
{%- set ns.last_turn_assistant_tool_calls = messages[-1]['tool_calls'] if 'tool_calls' in messages[-1] else [] -%}
{%- set messages = messages[:-1] -%}
{%- endif -%}

{%- for message in messages %}
{%- set content = message['content'] %}
{%- if message['role'] == 'user' -%}
{{- '<SPECIAL_11>User
' + (content | default('', true)).replace('/think', '').replace('/no_think', '').strip() + '
' }}
{%- elif message['role'] == 'tool' -%}
{%- if loop.first or (messages[loop.index0 - 1].role != 'tool') -%}
{{- '<SPECIAL_11>User
' + '<TOOL_RESPONSE>[' }}
{%- endif -%}
{{- message['content'] -}}
{{- ', ' if not loop.last and (messages[loop.index0 + 1].role == 'tool') else '' -}}
{%- if loop.last or (messages[loop.index0 + 1].role != 'tool') -%}
{{- ']</TOOL_RESPONSE>' -}}
{%- endif -%}
{%- elif message['role'] == 'assistant' -%}
{%- if content and '</think>' in content -%}
{%- set content = (content.split('</think>')[1] | default('', true)).strip() %}
{%- endif -%}
{{- '<SPECIAL_11>Assistant
' + ((content | default('', true)).strip() if content is not none else '') }}
{%- if message.tool_calls -%}
{%- if (content | default('', true)).strip() != '' -%}
{{- '
' -}}
{%- endif -%}
{{- '<TOOLCALL>[' -}}
{%- for call in message.tool_calls -%}
{%- set fn = call.function if call.function is defined else call -%}
{{- '{"name": "' + fn.name + '", "arguments": ' -}}
{%- if fn.arguments is string -%}
{{- fn.arguments -}}
{%- else -%}
{{- fn.arguments | tojson -}}
{%- endif -%}
{{- '}' + (', ' if not loop.last else '') -}}
{%- endfor -%}
{{- ']</TOOLCALL>' -}}
{%- endif -%}
{{- '
<SPECIAL_12>
' -}}
{%- endif -%}
{%- endfor -%}

{%- if add_generation_prompt -%}
{{- '<SPECIAL_11>Assistant
' -}}
{%- if ns.enable_thinking is defined and ns.enable_thinking is false -%}
{{- '<think></think>' -}}
{%- else -%}
{{- '<think>
' -}}
{%- endif -%}
{%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%}
{{- ns.last_turn_assistant_content -}}
{%- endif -%}
{%- else -%}
{%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%}
{{- '<SPECIAL_11>Assistant
' -}}
{%- if ns.enable_thinking is defined and ns.enable_thinking is false -%}
{{- '<think></think>' -}}
{%- else -%}
{{- '<think>
' -}}
{%- endif -%}
{{- ns.last_turn_assistant_content -}}
{%- if continue_final_message is defined -%}
{%- if continue_final_message is false -%}
{{- '
<SPECIAL_12>
' -}}
{%- endif -%}
{%- else -%}
{{- '
<SPECIAL_12>
' -}}
{%- endif -%}
{%- endif -%}
{%- if ns.last_turn_assistant_tool_calls is defined and ns.last_turn_assistant_tool_calls | length > 0 -%}
{{- '<SPECIAL_11>Assistant
' -}}
{{- '<TOOLCALL>[' -}}
{%- for call in ns.last_turn_assistant_tool_calls -%}
{%- set fn = call.function if call.function is defined else call -%}
{{- '{"name": "' + fn.name + '", "arguments": ' -}}
{%- if fn.arguments is string -%}
{{- fn.arguments -}}
{%- else -%}
{{- fn.arguments | tojson -}}
{%- endif -%}
{{- '}' + (', ' if not loop.last else '') -}}
{%- endfor -%}
{{- ']</TOOLCALL>' -}}
{{- '<SPECIAL_12>

' -}}
{%- endif -%}
{%- endif -%}
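To make the control flow above concrete, this is approximately what the template renders for a single-turn conversation with thinking enabled and `add_generation_prompt=true` (hypothetical message content):

```
<SPECIAL_10>System
You are a helpful assistant.

<SPECIAL_11>User
What is the capital of France?
<SPECIAL_11>Assistant
<think>
```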
scripts/jinja/jinja-tester.py (new file, executable, 504 lines)

@@ -0,0 +1,504 @@
#!/usr/bin/env python3
import sys
import json
import argparse
import jinja2.ext as jinja2_ext
from PySide6.QtWidgets import (
    QApplication,
    QMainWindow,
    QWidget,
    QVBoxLayout,
    QHBoxLayout,
    QLabel,
    QPlainTextEdit,
    QTextEdit,
    QPushButton,
    QFileDialog,
)
from PySide6.QtGui import QColor, QColorConstants, QTextCursor, QTextFormat
from PySide6.QtCore import Qt, QRect, QSize
from jinja2 import TemplateSyntaxError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from datetime import datetime


def format_template_content(template_content):
    """Format the Jinja template content using Jinja2's lexer."""
    if not template_content.strip():
        return template_content

    env = ImmutableSandboxedEnvironment()
    tc_rstrip = template_content.rstrip()
    tokens = list(env.lex(tc_rstrip))
    result = ""
    indent_level = 0
    i = 0

    while i < len(tokens):
        token = tokens[i]
        _, token_type, token_value = token

        if token_type == "block_begin":
            block_start = i
            # Collect all tokens for this block construct
            construct_content = token_value
            end_token_type = token_type.replace("_begin", "_end")
            j = i + 1
            while j < len(tokens) and tokens[j][1] != end_token_type:
                construct_content += tokens[j][2]
                j += 1

            if j < len(tokens):  # Found the end token
                construct_content += tokens[j][2]
                i = j  # Skip to the end token

                # Check for control structure keywords for indentation
                stripped_content = construct_content.strip()
                instr = block_start + 1
                while tokens[instr][1] == "whitespace":
                    instr = instr + 1

                instruction_token = tokens[instr][2]
                start_control_tokens = ["if", "for", "macro", "call", "block"]
                end_control_tokens = ["end" + t for t in start_control_tokens]
                is_control_start = any(
                    instruction_token.startswith(kw) for kw in start_control_tokens
                )
                is_control_end = any(
                    instruction_token.startswith(kw) for kw in end_control_tokens
                )

                # Adjust indentation for control structures
                # For control end blocks, decrease indent BEFORE adding the content
                if is_control_end:
                    indent_level = max(0, indent_level - 1)

                # Remove all previous whitespace before this block
                result = result.rstrip()

                # Add proper indent, but only if this is not the first token
                added_newline = False
                if result:  # Only add newline and indent if there's already content
                    result += (
                        "\n" + "  " * indent_level
                    )  # Use 2 spaces per indent level
                    added_newline = True
                else:  # For the first token, don't add any indent
                    result += ""

                # Add the block content
                result += stripped_content

                # Add '-' after '%' if it wasn't there and we added a newline or indent
                if (
                    added_newline
                    and stripped_content.startswith("{%")
                    and not stripped_content.startswith("{%-")
                ):
                    # Add '-' at the beginning
                    result = (
                        result[: result.rfind("{%")]
                        + "{%-"
                        + result[result.rfind("{%") + 2 :]
                    )
                if stripped_content.endswith("%}") and not stripped_content.endswith(
                    "-%}"
                ):
                    # Only add '-' if this is not the last token or if there's content after
                    if i + 1 < len(tokens) and tokens[i + 1][1] != "eof":
                        result = result[:-2] + "-%}"

                # For control start blocks, increase indent AFTER adding the content
                if is_control_start:
                    indent_level += 1
            else:
                # Malformed template, just add the token
                result += token_value
        elif token_type == "variable_begin":
            # Collect all tokens for this variable construct
            construct_content = token_value
            end_token_type = token_type.replace("_begin", "_end")
            j = i + 1
            while j < len(tokens) and tokens[j][1] != end_token_type:
                construct_content += tokens[j][2]
                j += 1

            if j < len(tokens):  # Found the end token
                construct_content += tokens[j][2]
                i = j  # Skip to the end token

                # For variable constructs, leave them alone
                # Do not add indent or whitespace before or after them
                result += construct_content
            else:
                # Malformed template, just add the token
                result += token_value
        elif token_type == "data":
            # Handle data (text between Jinja constructs)
            # For data content, preserve it as is
            result += token_value
        else:
            # Handle any other tokens
            result += token_value

        i += 1

    # Clean up trailing newlines and spaces
    result = result.rstrip()

    # Copy the newline / space count from the original
    if (trailing_length := len(template_content) - len(tc_rstrip)):
        result += template_content[-trailing_length:]

    return result
|
||||
|
||||
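Note: a quick, illustrative call (not part of the script) showing the formatter's intent: block tags are split onto their own lines, nested control structures are indented two spaces, and '-' whitespace-control markers are inserted wherever a newline was added.

demo = "{% if user %}Hello {{ user }}!{% endif %}"
print(format_template_content(demo))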
# ------------------------
# Line Number Widget
# ------------------------
class LineNumberArea(QWidget):
    def __init__(self, editor):
        super().__init__(editor)
        self.code_editor = editor

    def sizeHint(self):
        return QSize(self.code_editor.line_number_area_width(), 0)

    def paintEvent(self, event):
        self.code_editor.line_number_area_paint_event(event)


class CodeEditor(QPlainTextEdit):
    def __init__(self):
        super().__init__()
        self.line_number_area = LineNumberArea(self)

        self.blockCountChanged.connect(self.update_line_number_area_width)
        self.updateRequest.connect(self.update_line_number_area)
        self.cursorPositionChanged.connect(self.highlight_current_line)

        self.update_line_number_area_width(0)
        self.highlight_current_line()

    def line_number_area_width(self):
        digits = len(str(self.blockCount()))
        space = 3 + self.fontMetrics().horizontalAdvance("9") * digits
        return space

    def update_line_number_area_width(self, _):
        self.setViewportMargins(self.line_number_area_width(), 0, 0, 0)

    def update_line_number_area(self, rect, dy):
        if dy:
            self.line_number_area.scroll(0, dy)
        else:
            self.line_number_area.update(
                0, rect.y(), self.line_number_area.width(), rect.height()
            )

        if rect.contains(self.viewport().rect()):
            self.update_line_number_area_width(0)

    def resizeEvent(self, event):
        super().resizeEvent(event)
        cr = self.contentsRect()
        self.line_number_area.setGeometry(
            QRect(cr.left(), cr.top(), self.line_number_area_width(), cr.height())
        )

    def line_number_area_paint_event(self, event):
        from PySide6.QtGui import QPainter

        painter = QPainter(self.line_number_area)
        painter.fillRect(event.rect(), QColorConstants.LightGray)

        block = self.firstVisibleBlock()
        block_number = block.blockNumber()
        top = int(
            self.blockBoundingGeometry(block).translated(self.contentOffset()).top()
        )
        bottom = top + int(self.blockBoundingRect(block).height())

        while block.isValid() and top <= event.rect().bottom():
            if block.isVisible() and bottom >= event.rect().top():
                number = str(block_number + 1)
                painter.setPen(QColorConstants.Black)
                painter.drawText(
                    0,
                    top,
                    self.line_number_area.width() - 2,
                    self.fontMetrics().height(),
                    Qt.AlignmentFlag.AlignRight,
                    number,
                )
            block = block.next()
            top = bottom
            bottom = top + int(self.blockBoundingRect(block).height())
            block_number += 1

    def highlight_current_line(self):
        extra_selections = []
        if not self.isReadOnly():
            selection = QTextEdit.ExtraSelection()
            line_color = QColorConstants.Yellow.lighter(160)
            selection.format.setBackground(line_color)  # pyright: ignore[reportAttributeAccessIssue]
            selection.format.setProperty(QTextFormat.Property.FullWidthSelection, True)  # pyright: ignore[reportAttributeAccessIssue]
            selection.cursor = self.textCursor()  # pyright: ignore[reportAttributeAccessIssue]
            selection.cursor.clearSelection()  # pyright: ignore[reportAttributeAccessIssue]
            extra_selections.append(selection)
        self.setExtraSelections(extra_selections)

    def highlight_position(self, lineno: int, col: int, color: QColor):
        block = self.document().findBlockByLineNumber(lineno - 1)
        if block.isValid():
            cursor = QTextCursor(block)
            text = block.text()
            start = block.position() + max(0, col - 1)
            cursor.setPosition(start)
            if col <= len(text):
                cursor.movePosition(
                    QTextCursor.MoveOperation.NextCharacter,
                    QTextCursor.MoveMode.KeepAnchor,
                )

            extra = QTextEdit.ExtraSelection()
            extra.format.setBackground(color.lighter(160))  # pyright: ignore[reportAttributeAccessIssue]
            extra.cursor = cursor  # pyright: ignore[reportAttributeAccessIssue]

            self.setExtraSelections(self.extraSelections() + [extra])

    def highlight_line(self, lineno: int, color: QColor):
        block = self.document().findBlockByLineNumber(lineno - 1)
        if block.isValid():
            cursor = QTextCursor(block)
            cursor.select(QTextCursor.SelectionType.LineUnderCursor)

            extra = QTextEdit.ExtraSelection()
            extra.format.setBackground(color.lighter(160))  # pyright: ignore[reportAttributeAccessIssue]
            extra.cursor = cursor  # pyright: ignore[reportAttributeAccessIssue]

            self.setExtraSelections(self.extraSelections() + [extra])

    def clear_highlighting(self):
        self.highlight_current_line()


# ------------------------
# Main App
# ------------------------
class JinjaTester(QMainWindow):
    def __init__(self):
        super().__init__()
        self.setWindowTitle("Jinja Template Tester")
        self.resize(1200, 800)

        central = QWidget()
        main_layout = QVBoxLayout(central)

        # -------- Top input area --------
        input_layout = QHBoxLayout()

        # Template editor with label
        template_layout = QVBoxLayout()
        template_label = QLabel("Jinja2 Template")
        template_layout.addWidget(template_label)
        self.template_edit = CodeEditor()
        template_layout.addWidget(self.template_edit)
        input_layout.addLayout(template_layout)

        # JSON editor with label
        json_layout = QVBoxLayout()
        json_label = QLabel("Context (JSON)")
        json_layout.addWidget(json_label)
        self.json_edit = CodeEditor()
        self.json_edit.setPlainText("""
{
  "add_generation_prompt": true,
  "bos_token": "",
  "eos_token": "",
  "messages": [
    {
      "role": "user",
      "content": "What is the capital of Poland?"
    }
  ]
}
""".strip())
        json_layout.addWidget(self.json_edit)
        input_layout.addLayout(json_layout)

        main_layout.addLayout(input_layout)

        # -------- Rendered output area --------
        output_label = QLabel("Rendered Output")
        main_layout.addWidget(output_label)
        self.output_edit = QPlainTextEdit()
        self.output_edit.setReadOnly(True)
        main_layout.addWidget(self.output_edit)

        # -------- Render button and status --------
        btn_layout = QHBoxLayout()

        # Load template button
        self.load_btn = QPushButton("Load Template")
        self.load_btn.clicked.connect(self.load_template)
        btn_layout.addWidget(self.load_btn)

        # Format template button
        self.format_btn = QPushButton("Format")
        self.format_btn.clicked.connect(self.format_template)
        btn_layout.addWidget(self.format_btn)

        self.render_btn = QPushButton("Render")
        self.render_btn.clicked.connect(self.render_template)
        btn_layout.addWidget(self.render_btn)
        main_layout.addLayout(btn_layout)

        # Status label below buttons
        self.status_label = QLabel("Ready")
        main_layout.addWidget(self.status_label)

        self.setCentralWidget(central)

    def render_template(self):
        self.template_edit.clear_highlighting()
        self.output_edit.clear()

        template_str = self.template_edit.toPlainText()
        json_str = self.json_edit.toPlainText()

        # Parse JSON context
        try:
            context = json.loads(json_str) if json_str.strip() else {}
        except Exception as e:
            self.status_label.setText(f"❌ JSON Error: {e}")
            return

        def raise_exception(text: str) -> str:
            raise RuntimeError(text)

        env = ImmutableSandboxedEnvironment(
            trim_blocks=True,
            lstrip_blocks=True,
            extensions=[jinja2_ext.loopcontrols],
        )
        env.filters["tojson"] = (
            lambda x,
            indent=None,
            separators=None,
            sort_keys=False,
            ensure_ascii=False: json.dumps(
                x,
                indent=indent,
                separators=separators,
                sort_keys=sort_keys,
                ensure_ascii=ensure_ascii,
            )
        )
        env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
        env.globals["raise_exception"] = raise_exception
        try:
            template = env.from_string(template_str)
            output = template.render(context)
            self.output_edit.setPlainText(output)
            self.status_label.setText("✅ Render successful")
        except TemplateSyntaxError as e:
            self.status_label.setText(f"❌ Syntax Error (line {e.lineno}): {e.message}")
            if e.lineno:
                self.template_edit.highlight_line(e.lineno, QColor("red"))
        except Exception as e:
            # Catch all runtime errors
            # Try to extract template line number
            lineno = None
            tb = e.__traceback__
            while tb:
                frame = tb.tb_frame
                if frame.f_code.co_filename == "<template>":
                    lineno = tb.tb_lineno
                    break
                tb = tb.tb_next

            error_msg = f"Runtime Error: {type(e).__name__}: {e}"
            if lineno:
                error_msg = f"Runtime Error at line {lineno} in template: {type(e).__name__}: {e}"
                self.template_edit.highlight_line(lineno, QColor("orange"))

            self.output_edit.setPlainText(error_msg)
            self.status_label.setText(f"❌ {error_msg}")

    def load_template(self):
        """Load a Jinja template from a file using a file dialog."""
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "Load Jinja Template",
            "",
            "Template Files (*.jinja *.j2 *.html *.txt);;All Files (*)",
        )

        if file_path:
            try:
                with open(file_path, "r", encoding="utf-8") as file:
                    content = file.read()
                self.template_edit.setPlainText(content)
                self.status_label.setText(f"✅ Loaded template from {file_path}")
            except Exception as e:
                self.status_label.setText(f"❌ Error loading file: {str(e)}")

    def format_template(self):
        """Format the Jinja template using Jinja2's lexer for proper parsing."""
        try:
            template_content = self.template_edit.toPlainText()
            if not template_content.strip():
                self.status_label.setText("⚠️ Template is empty")
                return

            formatted_content = format_template_content(template_content)
            self.template_edit.setPlainText(formatted_content)
            self.status_label.setText("✅ Template formatted")
        except Exception as e:
            self.status_label.setText(f"❌ Error formatting template: {str(e)}")


if __name__ == "__main__":
    if len(sys.argv) > 1:
        # CLI mode
        parser = argparse.ArgumentParser(description="Jinja Template Tester")
        parser.add_argument(
            "--template", required=True, help="Path to Jinja template file"
        )
        parser.add_argument("--context", required=True, help="JSON string for context")
        parser.add_argument(
            "--action",
            choices=["format", "render"],
            default="render",
            help="Action to perform",
        )
        args = parser.parse_args()

        # Load template
        with open(args.template, "r", encoding="utf-8") as f:
            template_content = f.read()

        # Load JSON
        context = json.loads(args.context)
        # Add missing variables
        context.setdefault("bos_token", "")
        context.setdefault("eos_token", "")
        context.setdefault("add_generation_prompt", False)

        env = ImmutableSandboxedEnvironment()

        if args.action == "format":
            formatted = format_template_content(template_content)
            print(formatted)  # noqa: NP100
        elif args.action == "render":
            template = env.from_string(template_content)
            output = template.render(context)
            print(output)  # noqa: NP100

    else:
        # GUI mode
        app = QApplication(sys.argv)
        window = JinjaTester()
        window.show()
        sys.exit(app.exec())
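Note: a hedged example of driving the CLI mode defined above from another Python process; the flags are exactly the ones registered with argparse, while the template path is illustrative.

import json
import subprocess
import sys

ctx = {"messages": [{"role": "user", "content": "Hi"}], "add_generation_prompt": True}
subprocess.run(
    [
        sys.executable,
        "scripts/jinja/jinja-tester.py",
        "--template", "models/templates/NVIDIA-Nemotron-Nano-v2.jinja",
        "--context", json.dumps(ctx),
        "--action", "render",
    ],
    check=True,
)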
2
scripts/jinja/requirements.txt
Normal file
@@ -0,0 +1,2 @@
PySide6
jinja2
@@ -45,6 +45,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_GEMMA2,          "gemma2"          },
    { LLM_ARCH_GEMMA3,          "gemma3"          },
    { LLM_ARCH_GEMMA3N,         "gemma3n"         },
    { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
    { LLM_ARCH_STARCODER2,      "starcoder2"      },
    { LLM_ARCH_MAMBA,           "mamba"           },
    { LLM_ARCH_MAMBA2,          "mamba2"          },
@@ -1038,6 +1039,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
        },
    },
    {
        LLM_ARCH_GEMMA_EMBEDDING,
        {
            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
            { LLM_TENSOR_OUTPUT,         "output" },
            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_Q_NORM,    "blk.%d.attn_q_norm" },
            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_K_NORM,    "blk.%d.attn_k_norm" },
            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
        },
    },
    {
        LLM_ARCH_STARCODER2,
        {

@@ -49,6 +49,7 @@ enum llm_arch {
    LLM_ARCH_GEMMA2,
    LLM_ARCH_GEMMA3,
    LLM_ARCH_GEMMA3N,
    LLM_ARCH_GEMMA_EMBEDDING,
    LLM_ARCH_STARCODER2,
    LLM_ARCH_MAMBA,
    LLM_ARCH_MAMBA2,

@@ -258,6 +258,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
    }
}

static void print_mask(float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) {
    LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__);
    const char * swa_type_str = (swa_type == LLAMA_SWA_TYPE_NONE)      ? "LLAMA_SWA_TYPE_NONE" :
                                (swa_type == LLAMA_SWA_TYPE_STANDARD)  ? "LLAMA_SWA_TYPE_STANDARD" :
                                (swa_type == LLAMA_SWA_TYPE_CHUNKED)   ? "LLAMA_SWA_TYPE_CHUNKED" :
                                (swa_type == LLAMA_SWA_TYPE_SYMMETRIC) ? "LLAMA_SWA_TYPE_SYMMETRIC" : "unknown";
    LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str);
    LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__);
    LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__);

    LLAMA_LOG_DEBUG("    ");
    for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
        LLAMA_LOG_DEBUG("%2d", j);
    }
    LLAMA_LOG_DEBUG("\n");

    for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) {
        LLAMA_LOG_DEBUG(" %2d ", i);
        for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) {
            float val = data[i * n_kv + j];
            if (val == -INFINITY) {
                LLAMA_LOG_DEBUG(" ∞");
            } else {
                LLAMA_LOG_DEBUG(" 0");
            }
        }
        LLAMA_LOG_DEBUG("\n");
    }
}

void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
    const int64_t n_kv     = ubatch->n_tokens;
    const int64_t n_tokens = ubatch->n_tokens;
@@ -267,6 +297,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {

    float * data = (float *) kq_mask->data;

    // [TAG_NO_CACHE_ISWA]
    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");

    for (int h = 0; h < 1; ++h) {
        for (int i1 = 0; i1 < n_tokens; ++i1) {
            const llama_seq_id s1 = ubatch->seq_id[i1][0];
@@ -277,21 +310,33 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                    const llama_seq_id s0 = ubatch->seq_id[i0][0];

                    // TODO: reimplement this like in llama_kv_cache
                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
                        if (hparams.use_alibi) {
                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
                        } else {
                            f = 0.0f;
                        }
                        break;
                    if (s0 != s1) {
                        continue; // skip different sequences
                    }

                    if (cparams.causal_attn && ubatch->pos[i0] > ubatch->pos[i1]) {
                        continue; // skip future tokens for causal attention
                    }

                    // TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA]
                    //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
                    //    continue; // skip masked tokens for SWA
                    //}

                    // TODO: reimplement this like in llama_kv_cache_unified
                    if (hparams.use_alibi) {
                        f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
                    } else {
                        f = 0.0f;
                    }
                }

                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
            }
        }
    }
    if (debug) {
        print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type);
    }
}

void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
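Note: the rewritten inner loop replaces one combined condition with early `continue`s. A minimal Python sketch of the resulting mask rule follows (not the llama.cpp implementation; names and sample data are illustrative).

# A query token i1 may attend to a key token i0 only if both belong to the
# same sequence and, under causal attention, i0 is not in the future;
# ALiBi replaces the 0 bias with a negative distance.
NEG_INF = float("-inf")

def mask_value(seq, pos, i0, i1, causal_attn, use_alibi):
    if seq[i0] != seq[i1]:
        return NEG_INF  # skip different sequences
    if causal_attn and pos[i0] > pos[i1]:
        return NEG_INF  # skip future tokens for causal attention
    return -abs(pos[i0] - pos[i1]) if use_alibi else 0.0

seq = [0, 0, 1]
pos = [0, 1, 0]
print([[mask_value(seq, pos, i0, i1, True, False) for i0 in range(3)] for i1 in range(3)])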
@@ -78,6 +78,11 @@ struct llm_graph_params;

class llm_graph_input_i {
public:
    llm_graph_input_i() {
        const char * LLAMA_GRAPH_INPUT_DEBUG = getenv("LLAMA_GRAPH_INPUT_DEBUG");
        debug = LLAMA_GRAPH_INPUT_DEBUG ? atoi(LLAMA_GRAPH_INPUT_DEBUG) : 0;
    }

    virtual ~llm_graph_input_i() = default;

    virtual void set_input(const llama_ubatch * ubatch) = 0;
@@ -90,6 +95,9 @@ public:
        GGML_UNUSED(params);
        return false;
    }
protected:
    // env: LLAMA_GRAPH_INPUT_DEBUG
    int debug = 0;
};

using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;

@@ -1,6 +1,7 @@
#include "llama-hparams.h"

#include "ggml.h"
#include <cassert>

void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
    if (dense_first) {
@@ -178,3 +179,39 @@ uint32_t llama_hparams::n_layer_kv() const {

    return res;
}

bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
    assert(p0 >= 0 && p1 >= 0);

    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            {
            } break;
        case LLAMA_SWA_TYPE_STANDARD:
            {
                if (p1 - p0 >= (int32_t) n_swa) {
                    return true;
                }
            } break;
        case LLAMA_SWA_TYPE_CHUNKED:
            {
                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;

                if (p0 < pos_chunk_start) {
                    return true;
                }
            } break;
        case LLAMA_SWA_TYPE_SYMMETRIC:
            {
                const int32_t half_n_swa = (int32_t) n_swa / 2;
                const int32_t pos_diff = p1 - p0;

                // Mask if outside the symmetric window
                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
                    return true;
                }
            } break;
    }

    return false;
}
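Note: a Python sketch of the same predicate, mirroring the C++ above (not the actual implementation). p0 is the key/value position, p1 the query position; the new SYMMETRIC type keeps |p1 - p0| within half the window on both sides.

def is_masked_swa(n_swa: int, swa_type: str, p0: int, p1: int) -> bool:
    assert p0 >= 0 and p1 >= 0
    if swa_type == "standard":
        return p1 - p0 >= n_swa
    if swa_type == "chunked":
        return p0 < (p1 // n_swa) * n_swa
    if swa_type == "symmetric":
        half = n_swa // 2
        return not (-half <= p1 - p0 <= half)
    return False  # "none"

# e.g. with n_swa = 4, a symmetric window keeps |p1 - p0| <= 2:
print([is_masked_swa(4, "symmetric", p0, 10) for p0 in range(7, 14)])
# -> [True, False, False, False, False, False, True]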
@@ -16,9 +16,10 @@ enum llama_expert_gating_func_type {
};

enum llama_swa_type {
    LLAMA_SWA_TYPE_NONE      = 0,
    LLAMA_SWA_TYPE_STANDARD  = 1,
    LLAMA_SWA_TYPE_CHUNKED   = 2,
    LLAMA_SWA_TYPE_SYMMETRIC = 3,
};

struct llama_hparams_posnet {
@@ -227,6 +228,11 @@ struct llama_hparams {

    // number of layers for which has_kv() returns true
    uint32_t n_layer_kv() const;

    // note that this function uses different SWA parameters from those in the hparams
    // TODO: think of a better place for this function
    // TODO: pack the SWA params in a struct?
    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
};

static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

@@ -1393,29 +1393,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
}

bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
    assert(p0 >= 0 && p1 >= 0);

    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            {
            } break;
        case LLAMA_SWA_TYPE_STANDARD:
            {
                if (p1 - p0 >= (int32_t) n_swa) {
                    return true;
                }
            } break;
        case LLAMA_SWA_TYPE_CHUNKED:
            {
                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;

                if (p0 < pos_chunk_start) {
                    return true;
                }
            } break;
    }

    return false;
    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
}

void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {

@@ -212,6 +212,7 @@ private:
    // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;

    // this is the SWA type of the cache - not to be confused with the model SWA type
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    std::vector<ggml_context_ptr> ctxs;

@@ -1142,6 +1142,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
        case LLM_ARCH_GEMMA_EMBEDDING:
            {
                hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
                hparams.set_swa_pattern(6);

                hparams.causal_attn = false; // embeddings do not use causal attention
                hparams.rope_freq_base_train_swa  = 10000.0f;
                hparams.rope_freq_scale_train_swa = 1.0f;

                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);

                switch (hparams.n_layer) {
                    case 24: type = LLM_TYPE_0_3B; break;
                    default: type = LLM_TYPE_UNKNOWN;
                }
                hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k));

            } break;
        case LLM_ARCH_STARCODER2:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3484,6 +3504,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                }
            } break;
        case LLM_ARCH_GEMMA3:
        case LLM_ARCH_GEMMA_EMBEDDING:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

@@ -11045,6 +11066,137 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
    }
};

struct llm_build_gemma_embedding_iswa : public llm_graph_context {
    llm_build_gemma_embedding_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_k;

        ggml_tensor * cur;
        ggml_tensor * inpL;

        inpL = build_inp_embd(model.tok_embd);

        // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
        if (ubatch.token) {
            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
            cb(inpL, "inp_scaled", -1);
        }

        // inp_pos - contains the positions
        ggml_tensor * inp_pos = build_inp_pos();

        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
        auto * inp_attn = build_attn_inp_kv_iswa();

        ggml_tensor * inp_out_ids = build_inp_out_ids();

        for (int il = 0; il < n_layer; ++il) {
            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

            // norm
            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
            cb(cur, "attn_norm", il);

            // self-attention
            {
                // compute Q and K and RoPE them
                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                cb(Qcur, "Qcur", il);

                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
                cb(Kcur, "Kcur", il);

                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                cb(Vcur, "Vcur", il);

                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
                cb(Qcur, "Qcur_normed", il);

                Qcur = ggml_rope_ext(
                        ctx0, Qcur, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                        ext_factor, attn_factor, beta_fast, beta_slow);

                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
                cb(Kcur, "Kcur_normed", il);

                Kcur = ggml_rope_ext(
                        ctx0, Kcur, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                        ext_factor, attn_factor, beta_fast, beta_slow);

                cb(Qcur, "Qcur", il);
                cb(Kcur, "Kcur", il);
                cb(Vcur, "Vcur", il);

                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);

                cur = build_attn(inp_attn,
                        model.layers[il].wo, NULL,
                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
            }

            if (il == n_layer - 1 && inp_out_ids) {
                cur  = ggml_get_rows(ctx0, cur,  inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

            cur = build_norm(cur,
                    model.layers[il].attn_post_norm, NULL,
                    LLM_NORM_RMS, il);
            cb(cur, "attn_post_norm", il);

            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
            cb(sa_out, "sa_out", il);

            cur = build_norm(sa_out,
                    model.layers[il].ffn_norm, NULL,
                    LLM_NORM_RMS, il);
            cb(cur, "ffn_norm", il);

            // feed-forward network
            {
                cur = build_ffn(cur,
                        model.layers[il].ffn_up,   NULL, NULL,
                        model.layers[il].ffn_gate, NULL, NULL,
                        model.layers[il].ffn_down, NULL, NULL,
                        NULL,
                        LLM_FFN_GELU, LLM_FFN_PAR, il);
                cb(cur, "ffn_out", il);
            }

            cur = build_norm(cur,
                    model.layers[il].ffn_post_norm, NULL,
                    LLM_NORM_RMS, -1);
            cb(cur, "ffn_post_norm", -1);

            cur = ggml_add(ctx0, cur, sa_out);

            cur = build_cvec(cur, il);
            cb(cur, "l_out", il);

            // input for next layer
            inpL = cur;
        }

        cur = inpL;

        cur = build_norm(cur,
                model.output_norm, NULL,
                LLM_NORM_RMS, -1);

        cb(cur, "result_norm", -1);
        res->t_embd = cur;

        ggml_build_forward_expand(gf, cur);
    }
};

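Note: the layer body above follows Gemma's sandwich-norm layout. A schematic Python sketch of the data flow follows (pure functions standing in for the ggml ops; the toy inputs are illustrative, not real tensors).

def gemma_embedding_layer(x, attn, ffn, norms):
    h = norms["attn_post"](attn(norms["attn"](x)))  # pre-norm -> attention -> post-norm
    x = x + h                                       # sa_out residual
    h = norms["ffn_post"](ffn(norms["ffn"](x)))     # pre-norm -> FFN -> post-norm
    return x + h                                    # layer output

identity = lambda v: v
print(gemma_embedding_layer(
    1.0,
    lambda v: v * 0.5,
    lambda v: v * 0.25,
    {"attn": identity, "attn_post": identity, "ffn": identity, "ffn_post": identity},
))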
// TODO: move up next to build_starcoder
struct llm_build_starcoder2 : public llm_graph_context {
    llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
@@ -18481,6 +18633,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
        case LLM_ARCH_NOMIC_BERT_MOE:
        case LLM_ARCH_NEO_BERT:
        case LLM_ARCH_WAVTOKENIZER_DEC:
        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
        case LLM_ARCH_DREAM:
        case LLM_ARCH_LLADA:
            {
@@ -18761,6 +18914,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params);
            } break;
        case LLM_ARCH_GEMMA_EMBEDDING:
            {
                llm = std::make_unique<llm_build_gemma_embedding_iswa>(*this, params);
            } break;
        case LLM_ARCH_STARCODER2:
            {
                llm = std::make_unique<llm_build_starcoder2>(*this, params);
@@ -19161,6 +19318,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_GEMMA2:
        case LLM_ARCH_GEMMA3:
        case LLM_ARCH_GEMMA3N:
        case LLM_ARCH_GEMMA_EMBEDDING:
        case LLM_ARCH_STARCODER2:
        case LLM_ARCH_OPENELM:
        case LLM_ARCH_GPTNEOX:

@@ -34,6 +34,7 @@
#include <memory>
#include <random>
#include <regex>
#include <set>
#include <string>
#include <string_view>
#include <thread>
@@ -6741,8 +6742,90 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
    GGML_ABORT("fatal error");
}

static void list_all_ops() {
    printf("GGML operations:\n");
    std::set<std::string> all_ops;

    for (int i = 1; i < GGML_OP_COUNT; i++) {
        all_ops.insert(ggml_op_name((enum ggml_op)i));
    }
    for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) {
        all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));
    }
    for (int i = 0; i < GGML_GLU_OP_COUNT; i++) {
        all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));
    }
    for (const auto & op : all_ops) {
        printf("  %s\n", op.c_str());
    }
    printf("\nTotal: %zu operations\n", all_ops.size());
}

static void show_test_coverage() {
    std::set<std::string> all_ops;
    for (int i = 1; i < GGML_OP_COUNT; i++) {
        all_ops.insert(ggml_op_name((enum ggml_op)i));
    }
    for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) {
        all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));
    }
    for (int i = 0; i < GGML_GLU_OP_COUNT; i++) {
        all_ops.insert(ggml_glu_op_name((enum ggml_glu_op)i));
    }
    auto test_cases = make_test_cases_eval();
    std::set<std::string> tested_ops;

    ggml_init_params params = {
        /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
        /* .mem_base = */ NULL,
        /* .no_alloc = */ true,
    };

    for (auto & test_case : test_cases) {
        ggml_context * ctx = ggml_init(params);
        if (ctx) {
            test_case->mode = MODE_TEST;
            ggml_tensor * out = test_case->build_graph(ctx);
            if (out && out->op != GGML_OP_NONE) {
                if (out->op == GGML_OP_UNARY) {
                    tested_ops.insert(ggml_unary_op_name(ggml_get_unary_op(out)));
                } else if (out->op == GGML_OP_GLU) {
                    tested_ops.insert(ggml_glu_op_name(ggml_get_glu_op(out)));
                } else {
                    tested_ops.insert(ggml_op_name(out->op));
                }
            }
            ggml_free(ctx);
        }
    }
    std::set<std::string> covered_ops;
    std::set<std::string> uncovered_ops;
    for (const auto & op : all_ops) {
        if (tested_ops.count(op) > 0) {
            covered_ops.insert(op);
        } else {
            uncovered_ops.insert(op);
        }
    }

    printf("Operations covered by tests (%zu):\n", covered_ops.size());
    for (const auto & op : covered_ops) {
        printf("  ✓ %s\n", op.c_str());
    }
    printf("\nOperations without tests (%zu):\n", uncovered_ops.size());
    for (const auto & op : uncovered_ops) {
        printf("  ✗ %s\n", op.c_str());
    }

    printf("\nCoverage Summary:\n");
    printf("  Total operations: %zu\n", all_ops.size());
    printf("  Tested operations: %zu\n", covered_ops.size());
    printf("  Untested operations: %zu\n", uncovered_ops.size());
    printf("  Coverage: %.1f%%\n", (double)covered_ops.size() / all_ops.size() * 100.0);
}

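Note: the coverage report boils down to set arithmetic over op names; a tiny Python sketch of the same bookkeeping follows (the op names are placeholders).

all_ops = {"ADD", "MUL_MAT", "SOFT_MAX", "GELU"}
tested_ops = {"ADD", "MUL_MAT"}

covered = sorted(all_ops & tested_ops)
uncovered = sorted(all_ops - tested_ops)
print(f"Coverage: {len(covered) / len(all_ops) * 100.0:.1f}%")  # -> 50.0%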
static void usage(char ** argv) {
    printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
    printf("Usage: %s [mode] [-o <op,..>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>] [--list-ops] [--show-coverage]\n", argv[0]);
    printf("  valid modes:\n");
    printf("    - test (default, compare with CPU backend for correctness)\n");
    printf("    - grad (compare gradients from backpropagation with method of finite differences)\n");
@@ -6751,6 +6834,8 @@ static void usage(char ** argv) {
    printf("  op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc),\n");
    printf("  optionally including the full test case string (e.g. \"ADD(type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1)\")\n");
    printf("  --output specifies output format (default: console, options: console, sql, csv)\n");
    printf("  --list-ops lists all available GGML operations\n");
    printf("  --show-coverage shows test coverage\n");
}

int main(int argc, char ** argv) {
@@ -6800,6 +6885,12 @@ int main(int argc, char ** argv) {
                usage(argv);
                return 1;
            }
        } else if (strcmp(argv[i], "--list-ops") == 0) {
            list_all_ops();
            return 0;
        } else if (strcmp(argv[i], "--show-coverage") == 0) {
            show_test_coverage();
            return 0;
        } else {
            usage(argv);
            return 1;

@@ -420,6 +420,7 @@ const common_chat_msg message_assist_call_empty_args = simple_assist
const common_chat_msg message_assist_call_cutoff_args      = simple_assist_msg("", "", "special_function", "{\"arg");
const common_chat_msg message_assist_call_thoughts         = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}");
const common_chat_msg message_assist_call_thoughts_unparsed = simple_assist_msg("<think>I'm\nthinking</think>\n\n", "", "special_function", "{\"arg1\": 1}");
const common_chat_msg message_assist_call_thoughts_content = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": 1}");
const common_chat_msg message_assist_call_id               = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789");
const common_chat_msg message_assist_call_idx              = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0");
const common_chat_msg message_assist_thoughts_call_idx     = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0");
@@ -436,6 +437,7 @@ static void test_msgs_oaicompat_json_conversion() {
        message_assist_call,
        message_assist_call_thoughts,
        message_assist_call_thoughts_unparsed,
        message_assist_call_thoughts_content,
        message_assist_call_id,
        message_assist_call_idx,
        message_assist_call_python,
@@ -1755,6 +1757,77 @@ static void test_template_output_parsers() {
            /* is_partial= */ false,
            {COMMON_CHAT_FORMAT_SEED_OSS}));
    }

    {
        auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-Nano-v2.jinja");
        std::vector<std::string> end_tokens{ "<SPECIAL_12>" };

        assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
        assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

        // Test parsing regular content
        assert_msg_equals(message_assist,
            common_chat_parse(
                "Hello, world!\nWhat's up?",
                /* is_partial= */ false,
                {COMMON_CHAT_FORMAT_NEMOTRON_V2}));

        // Test parsing content with thinking
        assert_msg_equals(message_assist_thoughts,
            common_chat_parse(
                "<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
                /* is_partial= */ false,
                {
                    /* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
                }));

        // Test parsing tool calls
        assert_msg_equals(message_assist_call,
            common_chat_parse(
                "<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
                /* is_partial= */ false,
                {COMMON_CHAT_FORMAT_NEMOTRON_V2}));

        // Test parsing tool calls with thinking
        assert_msg_equals(message_assist_call_thoughts,
            common_chat_parse(
                "<think>I'm\nthinking</think><TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
                /* is_partial= */ false,
                {
                    /* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
                }));

        // Test tool calls with extra content
        assert_msg_equals(message_assist_call_content,
            common_chat_parse(
                "<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>Hello, world!\nWhat's up?",
                /* is_partial= */ false,
                {COMMON_CHAT_FORMAT_NEMOTRON_V2}
            ));

        // Test tool calls with extra content AND thinking
        assert_msg_equals(message_assist_call_thoughts_content,
            common_chat_parse(
                "<think>I'm\nthinking</think><TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>Hello, world!\nWhat's up?",
                /* is_partial= */ false,
                {
                    /* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
                }));

        // Test template generation for regular content
        test_templates(tmpls.get(), end_tokens, message_assist, tools,
            "Hello, world!\nWhat's up?\n",
            /* expect_grammar_triggered= */ false);

        // Test template generation for tool calls
        test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
            "<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
            /* expect_grammar_triggered= */ true
        );
    }
}

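Note: for orientation, a rough Python sketch of the Nemotron-v2 output shape these tests exercise — optional <think> reasoning, an optional <TOOLCALL> JSON array, and any remaining text as content. This mirrors the test strings above, not the actual common_chat_parse implementation.

import json
import re

def parse_nemotron_v2(text: str) -> dict:
    msg = {"reasoning": None, "tool_calls": [], "content": ""}
    m = re.match(r"<think>(.*?)</think>", text, re.DOTALL)
    if m:
        msg["reasoning"] = m.group(1)
        text = text[m.end():]
    m = re.search(r"<TOOLCALL>(\[.*?\])</TOOLCALL>", text, re.DOTALL)
    if m:
        msg["tool_calls"] = json.loads(m.group(1))
        text = text[:m.start()] + text[m.end():]
    msg["content"] = text
    return msg

print(parse_nemotron_v2(
    "<think>I'm\nthinking</think><TOOLCALL>[{\"name\": \"special_function\", "
    "\"arguments\": {\"arg1\": 1}}]</TOOLCALL>Hello, world!\nWhat's up?"
))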
static void test_msg_diffs_compute() {