From d34ff7eb5ba289bd61f659c8f3f48c983c3ce4f8 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 17 Mar 2026 00:31:14 +0100 Subject: [PATCH] model: mistral small 4 support (#20649) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * model: mistral small 4 support * fix test * fix test (2) * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * change newline --------- Co-authored-by: Sigbjørn Skjæret --- convert_hf_to_gguf.py | 122 +++++++++++++++++++++++++------------ gguf-py/gguf/constants.py | 33 ++++++++++ src/llama-arch.cpp | 2 + src/llama-arch.h | 1 + src/llama-model.cpp | 6 +- tests/test-llama-archs.cpp | 11 +++- 6 files changed, 133 insertions(+), 42 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b4ff8dd959..46469c8620 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -298,11 +298,16 @@ class ModelBase: scale = scale.float() if block_size is not None: + dim_offset = scale.ndim - len(block_size) for i, size in enumerate(block_size): - scale = scale.repeat_interleave(size, i) + scale = scale.repeat_interleave(size, dim_offset + i) # unpad the scale (e.g. when the tensor size isn't a multiple of the block size) scale = scale[tuple(slice(0, size) for size in weight.shape)] + # align scale dims to weight for correct broadcasting (e.g. [128] -> [128, 1, 1]) + while scale.ndim < weight.ndim: + scale = scale.unsqueeze(-1) + return weight.float() * scale # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476 @@ -393,7 +398,7 @@ class ModelBase: elif quant_method == "fp8": block_size = quant_config.get("weight_block_size") for name in self.model_tensors.keys(): - if name.endswith(".weight_scale_inv"): + if name.endswith("_scale_inv"): weight_name = name.removesuffix("_scale_inv") w = self.model_tensors[weight_name] s = self.model_tensors[name] @@ -401,6 +406,8 @@ class ModelBase: tensors_to_remove.append(name) if name.endswith(".activation_scale"): # unused tensors_to_remove.append(name) + if name.endswith("_activation_scale"): # Mistral-Small-4-119B-2602, unused + tensors_to_remove.append(name) # mistral format if name.endswith(".qscale_weight"): weight_name = name.removesuffix("qscale_weight") + "weight" @@ -3031,10 +3038,16 @@ class LlavaVisionModel(MmprojModel): def get_token_id(self, token: str) -> int: tokenizer_config_file = self.dir_model / 'tokenizer_config.json' with open(tokenizer_config_file, "r", encoding="utf-8") as f: - added_tokens_decoder = json.load(f)['added_tokens_decoder'] + added_tokens_decoder = json.load(f).get('added_tokens_decoder') or {} for id_, token_data in added_tokens_decoder.items(): - if token_data["content"] == token: + if token_data.get("content") == token: return int(id_) + # fallthrough to tokenizer.json + with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f: + tokenizer_json = json.load(f) + for token_data in tokenizer_json["added_tokens"]: + if token_data["content"] == token: + return int(token_data["id"]) raise ValueError(f"Token '{token}' not found in tokenizer config.") def set_gguf_parameters(self): @@ -3198,40 +3211,6 @@ class Llama4VisionModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register( - "Mistral3ForConditionalGeneration", - "Ministral3ForCausalLM", -) -class Mistral3Model(LlamaModel): - model_arch = gguf.MODEL_ARCH.MISTRAL3 - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # for compatibility, we use LLAMA arch for older models - # TODO: remove this once everyone has migrated to newer version of llama.cpp - if self.hparams.get("model_type") != "ministral3": - self.model_arch = gguf.MODEL_ARCH.LLAMA - self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] - self.gguf_writer.add_architecture() - self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) - - def set_gguf_parameters(self): - super().set_gguf_parameters() - rope_params = self.rope_parameters - if self.hparams.get("model_type") == "ministral3": - assert rope_params, "ministral3 must have 'rope_parameters' config" - assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'" - self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"]) - self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"]) - - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - name = name.replace("language_model.", "") - if "multi_modal_projector" in name or "vision_tower" in name: - return - - yield from super().modify_tensors(data_torch, name, bid) - - @ModelBase.register("DeciLMForCausalLM") class DeciModel(TextModel): model_arch = gguf.MODEL_ARCH.DECI @@ -8271,6 +8250,8 @@ class DeepseekV2Model(TextModel): # TODO @ngxson : remove this when we support MTP for deepseek models skip_mtp = True + merge_expert = True + def set_vocab(self): try: self._set_vocab_gpt2() @@ -8409,7 +8390,7 @@ class DeepseekV2Model(TextModel): return # process the experts separately - if name.find("mlp.experts") != -1: + if self.merge_expert and name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] assert bid is not None @@ -8468,6 +8449,69 @@ class DeepseekV2Model(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register( + "Mistral3ForConditionalGeneration", + "Ministral3ForCausalLM", +) +class Mistral3Model(TextModel): + class Ministral3Model(LlamaModel): + model_arch = gguf.MODEL_ARCH.MISTRAL3 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + rope_params = self.rope_parameters + if self.hparams.get("model_type") == "ministral3": + assert rope_params, "ministral3 must have 'rope_parameters' config" + assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'" + self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"]) + self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + name = name.replace("language_model.", "") + if "multi_modal_projector" in name or "vision_tower" in name: + return + + yield from super().modify_tensors(data_torch, name, bid) + + class Mistral4Model(DeepseekV2Model): + model_arch = gguf.MODEL_ARCH.MISTRAL4 + skip_mtp = False # model contains no MTP layers, so no need to skip + merge_expert = False # experts are already stacked as 3D + + def modify_tensors(self, data_torch, name, bid): + if name.endswith(".down_proj") or name.endswith(".gate_up_proj"): + name = name + ".weight" + yield from super().modify_tensors(data_torch, name, bid) + + model_arch = gguf.MODEL_ARCH.MISTRAL3 # unused + impl: TextModel + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.hparams.get("model_type") == "mistral4": + self.impl = Mistral3Model.Mistral4Model(*args, **kwargs) + else: + self.impl = Mistral3Model.Ministral3Model(*args, **kwargs) + + def set_vocab(self): + self.impl.set_vocab() + + def set_gguf_parameters(self): + self.impl.set_gguf_parameters() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + yield from self.impl.modify_tensors(data_torch, name, bid) + + def prepare_tensors(self): + self.impl.prepare_tensors() + + def write_vocab(self): + self.impl.write_vocab() + + def write(self): + self.impl.write() + + @ModelBase.register("MiniMaxM2ForCausalLM") class MiniMaxM2Model(TextModel): model_arch = gguf.MODEL_ARCH.MINIMAXM2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index bf617382d0..0a032e9039 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -478,6 +478,7 @@ class MODEL_ARCH(IntEnum): RND1 = auto() PANGU_EMBED = auto() MISTRAL3 = auto() + MISTRAL4 = auto() PADDLEOCR = auto() MIMO2 = auto() STEP35 = auto() @@ -924,6 +925,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.MISTRAL3: "mistral3", + MODEL_ARCH.MISTRAL4: "mistral4", MODEL_ARCH.PADDLEOCR: "paddleocr", MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.STEP35: "step35", @@ -3538,6 +3540,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.MISTRAL4: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.MIMO2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 799d16167b..84dc6d8f1b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -123,6 +123,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_STEP35, "step35" }, @@ -1589,6 +1590,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_SHEXP, }; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: return { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, diff --git a/src/llama-arch.h b/src/llama-arch.h index b1b1dcf188..9b9eec2f5c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -127,6 +127,7 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_MISTRAL4, LLM_ARCH_PADDLEOCR, LLM_ARCH_MIMO2, LLM_ARCH_STEP35, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bae02e32b1..85db938a7a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1587,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); @@ -4883,6 +4884,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: { const bool is_mla = hparams.is_mla(); @@ -7850,7 +7852,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -8428,6 +8430,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_GLM_DSA: + case LLM_ARCH_MISTRAL4: { llm = std::make_unique(*this, params); } break; @@ -8839,6 +8842,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: case LLM_ARCH_GLM_DSA: diff --git a/tests/test-llama-archs.cpp b/tests/test-llama-archs.cpp index 014b3f2b14..d51c09e99f 100644 --- a/tests/test-llama-archs.cpp +++ b/tests/test-llama-archs.cpp @@ -90,7 +90,10 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) { n_embd = 64; n_head = 1; n_ff = 96; - } else if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR) { + } else if (arch == LLM_ARCH_DEEPSEEK2 + || arch == LLM_ARCH_GLM_DSA + || arch == LLM_ARCH_KIMI_LINEAR + || arch == LLM_ARCH_MISTRAL4) { n_embd = 128; n_head = 1; n_ff = 192; @@ -145,7 +148,10 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) { } ms.add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, 8.0f); - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR) { + if (arch == LLM_ARCH_DEEPSEEK2 + || arch == LLM_ARCH_GLM_DSA + || arch == LLM_ARCH_KIMI_LINEAR + || arch == LLM_ARCH_MISTRAL4) { ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH, uint32_t(576)); ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, uint32_t(512)); ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT, uint32_t(64)); @@ -319,6 +325,7 @@ static bool moe_mandatory(const llm_arch arch) { case LLM_ARCH_MIMO2: case LLM_ARCH_KIMI_LINEAR: case LLM_ARCH_STEP35: + case LLM_ARCH_MISTRAL4: return true; default: return false;