From 98d2d2884e28ca8718774caad691ddd525f2f7b2 Mon Sep 17 00:00:00 2001 From: Kwa Jie Hao <31984694+kwajiehao@users.noreply.github.com> Date: Wed, 22 Apr 2026 02:02:49 +0800 Subject: [PATCH] mtmd: Add support for Reka Edge 2603 (#21616) * feat: (vocab) fix stray text appended in llama_decode_text Remove accidental concatenation of the full `text` string when formatting UNK_BYTE hex escapes. Only the closing "]" should be appended. * feat(mtmd): add Yasa2 vision encoder support Add a Yasa2 (ConvNeXtV2-based) vision encoder for reka-edge: - Register PROJECTOR_TYPE_YASA2 and tensor name definitions - Add yasa2_block/yasa2_stage model structs - Implement graph builder with ConvNeXt stages, GRN, adaptive pooling - Wire into clip.cpp switch statements and mtmd.cpp init_vision - Use mtmd_image_preprocessor_fixed_size for image preprocessing * feat(chat): add reka-edge template handler (tools, thinking) - Add chat-reka.cpp/h implementing PEG-based parser for reka-edge format - Add Reka-Edge.jinja chat template - Detect reka-edge template in try_specialized_template() - Add LLAMA_EXAMPLE_MTMD to chat-template-file arg * feat: add reka vlm to gguf conversion script Converts Reka Yasa2 hf checkpoints to GGUF format: - Text decoder: Llama-arch with tiktoken/BPE vocab - Mmproj (--mmproj): ConvNeXt vision backbone + language_projection - Generates 2D sincos positional embeddings for vision encoder * test: add Reka Edge chat template and parser tests - test-chat-template: oracle tests comparing Jinja engine output vs common_chat_templates_apply for text, tools, thinking, images, video - test-chat: PEG parser tests for Reka Edge format, round-trip tests for image/video content parts, common path integration tests * scripts: add Reka Edge mixed quantization helper Q4_0 base quantization with Q8_0 override for the last 8 transformer blocks (layers 24-31) via --tensor-type regex. * fix: adapt chat-reka and tests to upstream API - Use autoparser::generation_params (not templates_params) - Add p.prefix(generation_prompt) to PEG parser - Simplify reasoning parser to match LFM2 pattern - Remove image/video oracle tests (unsupported by oaicompat parser; no other multimodal models test this path) * fix: avoid duplicate tensor loading in yasa2 vision encoder TN_YASA_PATCH_W and TN_PATCH_EMBD both resolve to "v.patch_embd.weight", causing the same tensor to be loaded twice into ctx_data and overflowing the memory pool. Reuse the tensors already loaded by the common section. * chore: update image pre-processing settings The reka-edge model depends on the following settings in an older fork of llama.cpp: 1. Fixed square resize 2. BICUBIC 3. add_padding=false In current llama.cpp, this means setting: - image_resize_algo = RESIZE_ALGO_BICUBIC - image_resize_pad = false * chore: remove reka gguf conversion script * chore: remove reka quantization script * chore: remove unnecessary changes from PR scope This commit removes a couple of unnecessary changes for the PR scope: 1. BPE decoder bug fix - this affects reka edge because there's a bug in our tokenization that doesn't represent tokens as special tokens. However this isn't meant to be a thinking model so when run with --reasoning off the edge case does not affect us 2. --chat-template-file support from llama-mtmd-cli - the focus is on llama-server and the reka edge gguf contains the necessary metadata to detect the chat template 3. reka edge oracle test cases - no other model has similar test cases, so I removed it for standardization * chore: remove unnecessary ggml_cast This commit removes unnecessary ggml_cast after updating the reka vlm -> gguf conversion script on hugging face. * chore: remove redundant code * chore: remove unnecessary ggml_cont calls This commit removes all ggml_cont calls except the four that precede ggml_reshape_3d/ggml_reshape_4d. Those are necessary because ggml_reshape recomputes strides assuming contiguous layout and asserts ggml_is_contiguous. Other operations (ggml_mean, ggml_add, ggml_mul etc.) use stride-based indexing and handle non-contiguous inputs correctly and so we are ok to remove ggml_cont for those. * chore: remove unnecessary ggml_repeat calls This commit removes unnecessary ggml_repeat calls because the underlying ops already broadcast automatically. Every ggml_repeat in yasa2.cpp was expanding a smaller tensor to match a larger one's shape before passing both to an elementwise op (ggml_add, ggml_sub, ggml_mul, or ggml_div). This is unnecessary because all four of these ops already support broadcasting internally. * chore: restore ggml_cont needed for cpu operations * refactor: locate reka chat template handler in chat.cpp * chore: remove unnecessary warmup tokens * chore: add code comments on image_resize_pad * chore: remove custom reka parsing code * chore: revert common/chat.cpp * Uncomment debug logging for PEG input parsing --------- Co-authored-by: Piotr Wilkin (ilintar) --- tests/test-chat.cpp | 95 ++++++++++++++++++ tools/mtmd/CMakeLists.txt | 1 + tools/mtmd/clip-impl.h | 11 +++ tools/mtmd/clip-model.h | 30 ++++++ tools/mtmd/clip.cpp | 69 +++++++++++++ tools/mtmd/models/models.h | 8 ++ tools/mtmd/models/yasa2.cpp | 191 ++++++++++++++++++++++++++++++++++++ tools/mtmd/mtmd.cpp | 13 +++ 8 files changed, 418 insertions(+) create mode 100644 tools/mtmd/models/yasa2.cpp diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index eee28271be..79e1891b8e 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -3595,6 +3595,51 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .run(); } + // Reka Edge + { + auto tst = peg_tester("models/templates/Reka-Edge.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?") + .enable_thinking(false) + .expect(message_assist) + .run(); + tst.test("I'm\nthinking\n\nHello, world!\nWhat's up?") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + tst.test("Hello, world!\nWhat's up?\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call_content) + .run(); + tst.test("I'm\nthinking\n\n\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n\n\n{\"name\": \"special_function_with_opt\", \"arguments\": {\"arg1\": 1, \"arg2\": 2}}\n") + .enable_thinking(false) + .parallel_tool_calls(true) + .tools({ special_function_tool, special_function_tool_with_optional_param }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg") + .enable_thinking(false) + .tools({ special_function_tool }) + .is_partial(true) + .expect(message_assist_call_cutoff_args) + .run(); + } + // Apriel 1.5 { auto tst = peg_tester("models/templates/unsloth-Apriel-1.5.jinja", detailed_debug); @@ -4077,6 +4122,55 @@ static void test_template_output_peg_parsers(bool detailed_debug) { } } +static void test_reka_edge_common_path() { + auto tmpls = read_templates("models/templates/Reka-Edge.jinja"); + + { + common_chat_templates_inputs inputs; + common_chat_msg system_msg; + system_msg.role = "system"; + system_msg.content = "Use tools when needed."; + + common_chat_msg tool_call_msg = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); + + common_chat_msg tool_msg; + tool_msg.role = "tool"; + tool_msg.tool_name = "special_function"; + tool_msg.tool_call_id = "call0"; + tool_msg.content = "Sunny"; + + inputs.messages = { system_msg, message_user, tool_call_msg, tool_msg, message_user }; + inputs.tools = { special_function_tool }; + inputs.enable_thinking = true; + inputs.add_generation_prompt = true; + + auto params = common_chat_templates_apply(tmpls.get(), inputs); + + if (params.prompt.find("\nSunny\n") == std::string::npos) { + throw std::runtime_error("Reka Edge prompt did not render tool response history"); + } + if (params.prompt.rfind("assistant: \n") == std::string::npos) { + throw std::runtime_error("Reka Edge prompt did not render thinking generation prompt"); + } + } + + { + common_chat_templates_inputs inputs; + inputs.messages = { + message_user, + simple_assist_msg("The first point is") + }; + inputs.add_generation_prompt = false; + inputs.enable_thinking = false; + inputs.chat_template_kwargs["continue_final_message"] = "true"; + + auto params = common_chat_templates_apply(tmpls.get(), inputs); + if (string_ends_with(params.prompt, "")) { + throw std::runtime_error("Reka Edge continue_final_message unexpectedly closed the assistant turn"); + } + } +} + // Test the developer role to system workaround with a simple mock template static void test_developer_role_to_system_workaround() { LOG_DBG("%s\n", __func__); @@ -4256,6 +4350,7 @@ int main(int argc, char ** argv) { test_msgs_oaicompat_json_conversion(); test_tools_oaicompat_json_conversion(); test_developer_role_to_system_workaround(); + test_reka_edge_common_path(); test_template_output_peg_parsers(detailed_debug); std::cout << "\n[chat] All tests passed!" << '\n'; } diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 399876128e..35d721d5a4 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -40,6 +40,7 @@ add_library(mtmd models/deepseekocr.cpp models/mobilenetv5.cpp models/youtuvl.cpp + models/yasa2.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 17cb703f7f..61fe82439f 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -242,6 +242,15 @@ #define TN_STD_BIAS "v.std_bias" #define TN_STD_SCALE "v.std_scale" +// yasa2 +#define TN_YASA_PATCH_LN_W "v.patch_ln.weight" +#define TN_YASA_PATCH_LN_B "v.patch_ln.bias" +#define TN_YASA_BACKBONE_LN_W "v.backbone_ln.weight" +#define TN_YASA_BACKBONE_LN_B "v.backbone_ln.bias" +#define TN_YASA_POS_EMBD "v.vision_pos_embed" +#define TN_YASA_STAGE_DOWN_LN "v.stage.%d.down.ln.%s" +#define TN_YASA_STAGE_DOWN_CONV "v.stage.%d.down.conv.%s" +#define TN_YASA_STAGE_BLK "v.stage.%d.blk.%d.%s.%s" // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -290,6 +299,7 @@ enum projector_type { PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, + PROJECTOR_TYPE_YASA2, PROJECTOR_TYPE_KIMIK25, PROJECTOR_TYPE_NEMOTRON_V2_VL, PROJECTOR_TYPE_HUNYUANOCR, @@ -335,6 +345,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, + { PROJECTOR_TYPE_YASA2, "yasa2"}, { PROJECTOR_TYPE_KIMIK25, "kimik25"}, { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"}, { PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"}, diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 9a93584d9b..bf8031b55b 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -268,6 +268,27 @@ struct mobilenetv5_block { ggml_tensor * attn_norm_w = nullptr; }; +struct yasa2_block { + ggml_tensor * dw_w = nullptr; + ggml_tensor * dw_b = nullptr; + ggml_tensor * ln_w = nullptr; + ggml_tensor * ln_b = nullptr; + ggml_tensor * pw1_w = nullptr; + ggml_tensor * pw1_b = nullptr; + ggml_tensor * grn_w = nullptr; + ggml_tensor * grn_b = nullptr; + ggml_tensor * pw2_w = nullptr; + ggml_tensor * pw2_b = nullptr; +}; + +struct yasa2_stage { + ggml_tensor * down_ln_w = nullptr; + ggml_tensor * down_ln_b = nullptr; + ggml_tensor * down_conv_w = nullptr; + ggml_tensor * down_conv_b = nullptr; + std::vector blocks; +}; + struct clip_model { clip_modality modality = CLIP_MODALITY_VISION; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -402,6 +423,15 @@ struct clip_model { ggml_tensor * msfa_ffn_expand_bn = nullptr; ggml_tensor * msfa_ffn_project_bn = nullptr; + // yasa2 + ggml_tensor * yasa_patch_w = nullptr; + ggml_tensor * yasa_patch_b = nullptr; + ggml_tensor * yasa_patch_ln_w = nullptr; + ggml_tensor * yasa_patch_ln_b = nullptr; + ggml_tensor * yasa_backbone_ln_w = nullptr; + ggml_tensor * yasa_backbone_ln_b = nullptr; + ggml_tensor * yasa_vision_pos_embed = nullptr; + std::vector yasa_stages; // pixtral, glm4v ggml_tensor * token_embd_img_break = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f0e8786b66..540b0ea414 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -947,6 +947,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_YASA2: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1389,6 +1393,16 @@ struct clip_model_loader { hparams.set_limit_image_tokens(1, 62500); hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_YASA2: + { + hparams.ffn_op = FFN_GELU_ERF; + log_ffn_op = "gelu_erf"; + hparams.image_resize_algo = RESIZE_ALGO_BICUBIC; + + // reka model performs better when using resize_bicubic, which stretches + // the image to fit fixed square size + hparams.image_resize_pad = false; + } break; case PROJECTOR_TYPE_GLM4V: { hparams.rope_theta = 10000.0f; @@ -1839,6 +1853,55 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_YASA2: + { + // reuse tensors already loaded by the common section + // (TN_PATCH_EMBD and TN_PATCH_BIAS have the same tensor names) + GGML_ASSERT(model.patch_embeddings_0 && "yasa2 requires v.patch_embd.weight"); + model.yasa_patch_w = model.patch_embeddings_0; + model.yasa_patch_b = model.patch_bias; + model.yasa_patch_ln_w = get_tensor(TN_YASA_PATCH_LN_W, false); + model.yasa_patch_ln_b = get_tensor(TN_YASA_PATCH_LN_B, false); + model.yasa_backbone_ln_w = get_tensor(TN_YASA_BACKBONE_LN_W, false); + model.yasa_backbone_ln_b = get_tensor(TN_YASA_BACKBONE_LN_B, false); + model.yasa_vision_pos_embed = get_tensor(TN_YASA_POS_EMBD, false); + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + + model.yasa_stages.clear(); + for (int s = 0; ; ++s) { + yasa2_stage stage; + stage.down_ln_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "weight"), false); + stage.down_ln_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_LN, s, "bias"), false); + stage.down_conv_w = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "weight"), false); + stage.down_conv_b = get_tensor(string_format(TN_YASA_STAGE_DOWN_CONV, s, "bias"), false); + + for (int bi = 0; ; ++bi) { + yasa2_block blk; + blk.dw_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "weight"), false); + if (!blk.dw_w) { + break; + } + blk.dw_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "dw", "bias"), false); + blk.ln_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "weight"), false); + blk.ln_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "ln", "bias"), false); + blk.pw1_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "weight"), false); + blk.pw1_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw1", "bias"), false); + blk.grn_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "weight"), false); + blk.grn_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "grn", "bias"), false); + blk.pw2_w = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "weight"), false); + blk.pw2_b = get_tensor(string_format(TN_YASA_STAGE_BLK, s, bi, "pw2", "bias"), false); + stage.blocks.push_back(blk); + } + + if (!stage.down_conv_w && stage.blocks.empty()) { + break; + } + model.yasa_stages.push_back(std::move(stage)); + } + } break; case PROJECTOR_TYPE_GLM4V: { model.mm_fc_w = get_tensor(string_format(TN_MM_PROJECTOR, "weight")); @@ -2843,6 +2906,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { // do nothing } break; + case PROJECTOR_TYPE_YASA2: + { + n_patches = 64; // adaptive average pooling to 8x8 tokens + } break; case PROJECTOR_TYPE_LDP: case PROJECTOR_TYPE_LDPV2: case PROJECTOR_TYPE_GLM_EDGE: @@ -3463,6 +3530,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_PHI4: case PROJECTOR_TYPE_COGVLM: case PROJECTOR_TYPE_HUNYUANOCR: + case PROJECTOR_TYPE_YASA2: { // do nothing } break; @@ -3689,6 +3757,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_KIMIVL: case PROJECTOR_TYPE_PADDLEOCR: case PROJECTOR_TYPE_KIMIK25: + case PROJECTOR_TYPE_YASA2: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_HUNYUANOCR: return ctx->model.mm_model_proj->ne[1]; diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 03d99e15b0..c30d79133e 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -43,6 +43,14 @@ struct clip_graph_youtuvl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_yasa2 : clip_graph { + clip_graph_yasa2(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps = 1e-6f); + ggml_tensor * convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b); +}; + struct clip_graph_minicpmv : clip_graph { clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/yasa2.cpp b/tools/mtmd/models/yasa2.cpp new file mode 100644 index 0000000000..e8cd3dacbf --- /dev/null +++ b/tools/mtmd/models/yasa2.cpp @@ -0,0 +1,191 @@ +// ABOUTME: Yasa2 vision encoder graph builder for ConvNeXt-based architecture. +// ABOUTME: Implements patch embedding, ConvNeXt stages with GRN, and adaptive pooling. + +#include "models.h" + +static ggml_tensor * add_channel_bias( + ggml_context * ctx0, + ggml_tensor * x_whcb, + ggml_tensor * b_c) { + if (!b_c) { + return x_whcb; + } + ggml_tensor * b4 = ggml_reshape_4d(ctx0, b_c, 1, 1, b_c->ne[0], 1); + return ggml_add(ctx0, x_whcb, b4); +} + +static ggml_tensor * mul_channel_weight( + ggml_context * ctx0, + ggml_tensor * x_whcb, + ggml_tensor * w_c) { + if (!w_c) { + return x_whcb; + } + ggml_tensor * w4 = ggml_reshape_4d(ctx0, w_c, 1, 1, w_c->ne[0], 1); + return ggml_mul(ctx0, x_whcb, w4); +} + +ggml_tensor * clip_graph_yasa2::layer_norm_channels(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b, float eps) { + // Match HF ConvNextLayerNorm(channels_first): + // u = mean_c(x), s = mean_c((x-u)^2), x = (x-u)/sqrt(s+eps) + // cast back to input dtype before affine. + ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); // [W,H,C,B] -> [C,H,W,B] + cur = ggml_cont(ctx0, cur); + + ggml_tensor * u = ggml_mean(ctx0, cur); // [1,H,W,B] + ggml_tensor * xm = ggml_sub(ctx0, cur, u); // [C,H,W,B] + + ggml_tensor * s = ggml_mul(ctx0, xm, xm); // [C,H,W,B] + s = ggml_mean(ctx0, s); // [1,H,W,B] + s = ggml_clamp(ctx0, s, eps, 1e30f); // avoid div-by-zero in no-alloc warmup + s = ggml_sqrt(ctx0, s); // [1,H,W,B] + + ggml_tensor * xhat = ggml_div(ctx0, xm, s); // [C,H,W,B] + xhat = ggml_permute(ctx0, xhat, 2, 1, 0, 3); // [W,H,C,B] + xhat = ggml_cont(ctx0, xhat); + xhat = mul_channel_weight(ctx0, xhat, w); + xhat = add_channel_bias(ctx0, xhat, b); + return xhat; +} + +ggml_tensor * clip_graph_yasa2::convnext_grn(ggml_tensor * inp, ggml_tensor * w, ggml_tensor * b) { + // Exact ConvNeXtV2 GRN: + // Gx = ||x||_2 over spatial dims (W,H), Nx = Gx / (mean_c(Gx) + eps) + // y = w * (x * Nx) + b + x + const int64_t wdim = inp->ne[0]; + const int64_t hdim = inp->ne[1]; + const int64_t cdim = inp->ne[2]; + const int64_t bdim = inp->ne[3]; + + // Keep GRN math in fp32 for stability; fp16/bf16 accumulation can drift. + ggml_tensor * sq = ggml_mul(ctx0, inp, inp); + ggml_tensor * sq_flat = ggml_reshape_4d(ctx0, sq, wdim * hdim, cdim, 1, bdim); // [WH,C,1,B] + ggml_tensor * gx = ggml_sum_rows(ctx0, sq_flat); // [1,C,1,B] + gx = ggml_sqrt(ctx0, gx); // [1,C,1,B] + + ggml_tensor * gx_ch_first = ggml_permute(ctx0, gx, 1, 0, 2, 3); // [C,1,1,B] + gx_ch_first = ggml_cont(ctx0, gx_ch_first); + ggml_tensor * gx_mean = ggml_mean(ctx0, gx_ch_first); // [1,1,1,B] + + gx_mean = ggml_clamp(ctx0, gx_mean, 1e-6f, 1e30f); // approx +eps, warmup-safe + ggml_tensor * nx = ggml_div(ctx0, gx, gx_mean); // [1,C,1,B] + nx = ggml_permute(ctx0, nx, 0, 2, 1, 3); // [1,1,C,B] + nx = ggml_cont(ctx0, nx); + + ggml_tensor * xnx = ggml_mul(ctx0, inp, nx); + xnx = mul_channel_weight(ctx0, xnx, w); + xnx = add_channel_bias(ctx0, xnx, b); + return ggml_add(ctx0, inp, xnx); +} + +ggml_cgraph * clip_graph_yasa2::build() { + ggml_tensor * cur = build_inp_raw(); + + // Patch embedding Conv2d(kernel=4, stride=4) + cur = ggml_conv_2d(ctx0, model.yasa_patch_w, cur, patch_size, patch_size, 0, 0, 1, 1); + cur = add_channel_bias(ctx0, cur, model.yasa_patch_b); + ggml_set_name(cur, "yasa2_patch_conv_out"); + cb(cur, "yasa2_patch_conv_out", -1); + cur = layer_norm_channels(cur, model.yasa_patch_ln_w, model.yasa_patch_ln_b, eps); + ggml_set_name(cur, "yasa2_patch_ln_out"); + cb(cur, "yasa2_patch_ln_out", -1); + + // ConvNeXt stages + for (size_t s = 0; s < model.yasa_stages.size(); ++s) { + const auto & stage = model.yasa_stages[s]; + + if (stage.down_conv_w) { + cur = layer_norm_channels(cur, stage.down_ln_w, stage.down_ln_b, eps); + cur = ggml_conv_2d(ctx0, stage.down_conv_w, cur, 2, 2, 0, 0, 1, 1); + cur = add_channel_bias(ctx0, cur, stage.down_conv_b); + ggml_format_name(cur, "yasa2_stage%zu_down_out", s); + } + + for (size_t bi = 0; bi < stage.blocks.size(); ++bi) { + const auto & blk = stage.blocks[bi]; + ggml_tensor * res = cur; + + ggml_tensor * x = ggml_conv_2d_dw(ctx0, blk.dw_w, cur, 1, 1, 3, 3, 1, 1); + x = add_channel_bias(ctx0, x, blk.dw_b); + x = layer_norm_channels(x, blk.ln_w, blk.ln_b, eps); + + // pwconv1/pwconv2 are HF Linear layers over channels; implement via matmul on tokens. + const int64_t w = x->ne[0]; + const int64_t h = x->ne[1]; + const int64_t b = x->ne[3]; + + ggml_tensor * tok = ggml_reshape_3d(ctx0, x, w * h, x->ne[2], b); // [T,C,B] + tok = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [C,T,B] + tok = ggml_cont(ctx0, tok); + + tok = ggml_mul_mat(ctx0, blk.pw1_w, tok); // [4C,T,B] + if (blk.pw1_b) { + ggml_tensor * b1 = ggml_reshape_3d(ctx0, blk.pw1_b, blk.pw1_b->ne[0], 1, 1); // [4C,1,1] + tok = ggml_add(ctx0, tok, b1); + } + x = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [T,4C,B] + x = ggml_cont(ctx0, x); + x = ggml_reshape_4d(ctx0, x, w, h, tok->ne[0], b); // [W,H,4C,B] + x = ggml_gelu_erf(ctx0, x); + x = convnext_grn(x, blk.grn_w, blk.grn_b); + + tok = ggml_reshape_3d(ctx0, x, w * h, x->ne[2], b); // [T,4C,B] + tok = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [4C,T,B] + tok = ggml_cont(ctx0, tok); + + tok = ggml_mul_mat(ctx0, blk.pw2_w, tok); // [C,T,B] + if (blk.pw2_b) { + ggml_tensor * b2 = ggml_reshape_3d(ctx0, blk.pw2_b, blk.pw2_b->ne[0], 1, 1); // [C,1,1] + tok = ggml_add(ctx0, tok, b2); + } + x = ggml_permute(ctx0, tok, 1, 0, 2, 3); // [T,C,B] + x = ggml_cont(ctx0, x); + x = ggml_reshape_4d(ctx0, x, w, h, tok->ne[0], b); // [W,H,C,B] + + cur = ggml_add(ctx0, res, x); + ggml_format_name(cur, "yasa2_stage%zu_blk%zu_out", s, bi); + } + } + + // HF path adds vision position embeddings BEFORE adaptive pooling. + const int64_t pre_w = cur->ne[0]; + const int64_t pre_h = cur->ne[1]; + ggml_tensor * tokens_pre = ggml_reshape_3d(ctx0, cur, pre_w * pre_h, cur->ne[2], cur->ne[3]); // [T,C,B] + tokens_pre = ggml_permute(ctx0, tokens_pre, 1, 0, 2, 3); // [C,T,B] + tokens_pre = ggml_cont(ctx0, tokens_pre); + if (model.yasa_vision_pos_embed && tokens_pre->ne[1] == model.yasa_vision_pos_embed->ne[1]) { + const int64_t n_ch = model.yasa_vision_pos_embed->ne[0]; + const int64_t n_tokens = model.yasa_vision_pos_embed->ne[1]; + ggml_tensor * pos = ggml_reshape_3d(ctx0, model.yasa_vision_pos_embed, (int) n_ch, (int) n_tokens, 1); + tokens_pre = ggml_add(ctx0, tokens_pre, pos); + } + cur = ggml_permute(ctx0, tokens_pre, 1, 0, 2, 3); // [T,C,B] + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_4d(ctx0, cur, pre_w, pre_h, cur->ne[1], cur->ne[2]); // [W,H,C,B] + + // AdaptiveAvgPool2d target is 8x8 for real inputs, but warmup can use tiny images. + const int pooled_w = std::min(8, (int) cur->ne[0]); + const int pooled_h = std::min(8, (int) cur->ne[1]); + const int kw = std::max(1, (int) cur->ne[0] / pooled_w); + const int kh = std::max(1, (int) cur->ne[1] / pooled_h); + cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kw, kh, kw, kh, 0, 0); + + // [W,H,C,B] -> [C,T,B] + ggml_tensor * tokens = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2], cur->ne[3]); + tokens = ggml_permute(ctx0, tokens, 1, 0, 2, 3); + tokens = ggml_cont(ctx0, tokens); + cb(tokens, "yasa2_tokens", -1); + + GGML_ASSERT(model.mm_0_w && model.mm_2_w); + ggml_tensor * embeddings = build_ffn( + tokens, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU_ERF, + -1); + cb(embeddings, "yasa2_emb", -1); + + ggml_build_forward_expand(gf, embeddings); + return gf; +} diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 854ac81e0a..cc3de6a858 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -316,6 +316,19 @@ struct mtmd_context { img_end = "<|vision_end|>"; image_preproc = std::make_unique(ctx_v); } break; + case PROJECTOR_TYPE_YASA2: + { + img_beg = ""; + img_end = ""; + // Currently only supprots single-tile preprocessing: any input is downscaled + // to one image_size x image_size tile (64 output tokens via 8x8 adaptive avg + // pool). + // However, the model itself supports llava-uhd multi-tile tiling for high-res + // images. This will be implemented in a future PR (dispatch on has_pinpoints + // - see LDP/COGVLM branch above) and emit image_grid_pinpoints in the conversion + // script. + image_preproc = std::make_unique(ctx_v); + } break; case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_GEMMA3NV: {