model: mistral small 4 support (#20649)

* model: mistral small 4 support

* fix test

* fix test (2)

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* change newline

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Xuan-Son Nguyen
2026-03-17 00:31:14 +01:00
committed by GitHub
parent 45172df4d6
commit d34ff7eb5b
6 changed files with 133 additions and 42 deletions

View File

@@ -298,11 +298,16 @@ class ModelBase:
scale = scale.float() scale = scale.float()
if block_size is not None: if block_size is not None:
dim_offset = scale.ndim - len(block_size)
for i, size in enumerate(block_size): for i, size in enumerate(block_size):
scale = scale.repeat_interleave(size, i) scale = scale.repeat_interleave(size, dim_offset + i)
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size) # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
scale = scale[tuple(slice(0, size) for size in weight.shape)] scale = scale[tuple(slice(0, size) for size in weight.shape)]
# align scale dims to weight for correct broadcasting (e.g. [128] -> [128, 1, 1])
while scale.ndim < weight.ndim:
scale = scale.unsqueeze(-1)
return weight.float() * scale return weight.float() * scale
# ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476 # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
@@ -393,7 +398,7 @@ class ModelBase:
elif quant_method == "fp8": elif quant_method == "fp8":
block_size = quant_config.get("weight_block_size") block_size = quant_config.get("weight_block_size")
for name in self.model_tensors.keys(): for name in self.model_tensors.keys():
if name.endswith(".weight_scale_inv"): if name.endswith("_scale_inv"):
weight_name = name.removesuffix("_scale_inv") weight_name = name.removesuffix("_scale_inv")
w = self.model_tensors[weight_name] w = self.model_tensors[weight_name]
s = self.model_tensors[name] s = self.model_tensors[name]
@@ -401,6 +406,8 @@ class ModelBase:
tensors_to_remove.append(name) tensors_to_remove.append(name)
if name.endswith(".activation_scale"): # unused if name.endswith(".activation_scale"): # unused
tensors_to_remove.append(name) tensors_to_remove.append(name)
if name.endswith("_activation_scale"): # Mistral-Small-4-119B-2602, unused
tensors_to_remove.append(name)
# mistral format # mistral format
if name.endswith(".qscale_weight"): if name.endswith(".qscale_weight"):
weight_name = name.removesuffix("qscale_weight") + "weight" weight_name = name.removesuffix("qscale_weight") + "weight"
@@ -3031,10 +3038,16 @@ class LlavaVisionModel(MmprojModel):
def get_token_id(self, token: str) -> int:
    """Return the integer id of the added/special token ``token``.

    The token is looked up first in the ``added_tokens_decoder`` mapping of
    ``tokenizer_config.json`` and, if it is not listed there, in the
    ``added_tokens`` list of ``tokenizer.json``. Both files are read from
    ``self.dir_model``.

    Args:
        token: exact token text to search for (compared against ``content``).

    Returns:
        The token id converted to ``int``.

    Raises:
        ValueError: if neither file lists the token.
    """
    tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
    with open(tokenizer_config_file, "r", encoding="utf-8") as f:
        # the key may be absent or null; treat both as "no entries"
        added_tokens_decoder = json.load(f).get('added_tokens_decoder') or {}
    for id_, token_data in added_tokens_decoder.items():
        if token_data.get("content") == token:
            return int(id_)
    # fallthrough to tokenizer.json
    with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f:
        tokenizer_json = json.load(f)
    # same tolerance as above: a missing "added_tokens" list or an entry without
    # "content" must end in the ValueError below, not in a KeyError
    for token_data in tokenizer_json.get("added_tokens") or []:
        if token_data.get("content") == token:
            return int(token_data["id"])
    raise ValueError(f"Token '{token}' not found in tokenizer config.")
def set_gguf_parameters(self): def set_gguf_parameters(self):
@@ -3198,40 +3211,6 @@ class Llama4VisionModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid) yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register(
    "Mistral3ForConditionalGeneration",
    "Ministral3ForCausalLM",
)
class Mistral3Model(LlamaModel):
    """Llama-derived converter registered for the two Mistral3/Ministral3 architecture names.

    Behaviour branches on ``hparams["model_type"]``: only ``"ministral3"`` keeps
    the MISTRAL3 architecture and gets the extra rope/attention metadata.
    """

    model_arch = gguf.MODEL_ARCH.MISTRAL3

    def __init__(self, *args, **kwargs):
        """Construct the base converter, then switch to the LLAMA arch for non-ministral3 models."""
        super().__init__(*args, **kwargs)
        # for compatibility, we use LLAMA arch for older models
        # TODO: remove this once everyone has migrated to newer version of llama.cpp
        if self.hparams.get("model_type") != "ministral3":
            self.model_arch = gguf.MODEL_ARCH.LLAMA
            # the writer and tensor map are re-pointed at the new arch here
            # (presumably because the base __init__ set them up for MISTRAL3 - confirm)
            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
            self.gguf_writer.add_architecture()
            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_gguf_parameters(self):
        """Write the inherited parameters, plus YaRN log-mul and attention temperature scale for ministral3."""
        super().set_gguf_parameters()
        rope_params = self.rope_parameters
        if self.hparams.get("model_type") == "ministral3":
            assert rope_params, "ministral3 must have 'rope_parameters' config"
            assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
            self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
        """Strip the ``language_model.`` prefix, drop projector/vision tensors, and delegate the rest."""
        name = name.replace("language_model.", "")
        # tensors of the multimodal projector and vision tower are not emitted by this converter
        if "multi_modal_projector" in name or "vision_tower" in name:
            return

        yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("DeciLMForCausalLM") @ModelBase.register("DeciLMForCausalLM")
class DeciModel(TextModel): class DeciModel(TextModel):
model_arch = gguf.MODEL_ARCH.DECI model_arch = gguf.MODEL_ARCH.DECI
@@ -8271,6 +8250,8 @@ class DeepseekV2Model(TextModel):
# TODO @ngxson : remove this when we support MTP for deepseek models # TODO @ngxson : remove this when we support MTP for deepseek models
skip_mtp = True skip_mtp = True
merge_expert = True
def set_vocab(self): def set_vocab(self):
try: try:
self._set_vocab_gpt2() self._set_vocab_gpt2()
@@ -8409,7 +8390,7 @@ class DeepseekV2Model(TextModel):
return return
# process the experts separately # process the experts separately
if name.find("mlp.experts") != -1: if self.merge_expert and name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"] n_experts = self.hparams["n_routed_experts"]
assert bid is not None assert bid is not None
@@ -8468,6 +8449,69 @@ class DeepseekV2Model(TextModel):
raise ValueError(f"Unprocessed experts: {experts}") raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register(
    "Mistral3ForConditionalGeneration",
    "Ministral3ForCausalLM",
)
class Mistral3Model(TextModel):
    """Dispatching converter for the Mistral3/Ministral3 architecture names.

    ``__init__`` builds one of the two nested implementations, chosen by
    ``hparams["model_type"]``, and stores it in ``impl``. The methods defined
    below on this class forward to that instance.
    """

    class Ministral3Model(LlamaModel):
        """Llama-derived implementation, used when ``model_type`` is not ``"mistral4"``."""

        model_arch = gguf.MODEL_ARCH.MISTRAL3

        def set_gguf_parameters(self):
            """Write the inherited parameters, plus YaRN log-mul and attention temperature scale for ministral3."""
            super().set_gguf_parameters()
            rope_params = self.rope_parameters
            if self.hparams.get("model_type") == "ministral3":
                assert rope_params, "ministral3 must have 'rope_parameters' config"
                assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
                self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
                self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])

        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
            """Strip the ``language_model.`` prefix, drop projector/vision tensors, and delegate the rest."""
            name = name.replace("language_model.", "")
            if "multi_modal_projector" in name or "vision_tower" in name:
                return

            yield from super().modify_tensors(data_torch, name, bid)

    class Mistral4Model(DeepseekV2Model):
        """DeepseekV2-derived implementation, used when ``model_type`` is ``"mistral4"``."""

        model_arch = gguf.MODEL_ARCH.MISTRAL4
        skip_mtp = False # model contains no MTP layers, so no need to skip
        merge_expert = False # experts are already stacked as 3D

        def modify_tensors(self, data_torch, name, bid):
            """Append ``.weight`` to names ending in ``.down_proj``/``.gate_up_proj``, then delegate."""
            # these names arrive without a ".weight" suffix
            # (presumably the tensor name map expects one - confirm)
            if name.endswith(".down_proj") or name.endswith(".gate_up_proj"):
                name = name + ".weight"
            yield from super().modify_tensors(data_torch, name, bid)

    model_arch = gguf.MODEL_ARCH.MISTRAL3 # unused
    # the concrete converter selected in __init__
    impl: TextModel

    def __init__(self, *args, **kwargs):
        """Initialise the wrapper, then build the implementation with the same arguments."""
        super().__init__(*args, **kwargs)
        if self.hparams.get("model_type") == "mistral4":
            self.impl = Mistral3Model.Mistral4Model(*args, **kwargs)
        else:
            self.impl = Mistral3Model.Ministral3Model(*args, **kwargs)

    def set_vocab(self):
        """Forward to ``impl.set_vocab()``."""
        self.impl.set_vocab()

    def set_gguf_parameters(self):
        """Forward to ``impl.set_gguf_parameters()``."""
        self.impl.set_gguf_parameters()

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
        """Yield whatever ``impl.modify_tensors`` yields for this tensor."""
        yield from self.impl.modify_tensors(data_torch, name, bid)

    def prepare_tensors(self):
        """Forward to ``impl.prepare_tensors()``."""
        self.impl.prepare_tensors()

    def write_vocab(self):
        """Forward to ``impl.write_vocab()``."""
        self.impl.write_vocab()

    def write(self):
        """Forward to ``impl.write()``."""
        self.impl.write()
@ModelBase.register("MiniMaxM2ForCausalLM") @ModelBase.register("MiniMaxM2ForCausalLM")
class MiniMaxM2Model(TextModel): class MiniMaxM2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MINIMAXM2 model_arch = gguf.MODEL_ARCH.MINIMAXM2

View File

@@ -478,6 +478,7 @@ class MODEL_ARCH(IntEnum):
RND1 = auto() RND1 = auto()
PANGU_EMBED = auto() PANGU_EMBED = auto()
MISTRAL3 = auto() MISTRAL3 = auto()
MISTRAL4 = auto()
PADDLEOCR = auto() PADDLEOCR = auto()
MIMO2 = auto() MIMO2 = auto()
STEP35 = auto() STEP35 = auto()
@@ -924,6 +925,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.RND1: "rnd1",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3", MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.MISTRAL4: "mistral4",
MODEL_ARCH.PADDLEOCR: "paddleocr", MODEL_ARCH.PADDLEOCR: "paddleocr",
MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.MIMO2: "mimo2",
MODEL_ARCH.STEP35: "step35", MODEL_ARCH.STEP35: "step35",
@@ -3538,6 +3540,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_UP_EXP,
], ],
MODEL_ARCH.MISTRAL4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.MIMO2: [ MODEL_ARCH.MIMO2: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,

View File

@@ -123,6 +123,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_MISTRAL4, "mistral4" },
{ LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_PADDLEOCR, "paddleocr" },
{ LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_MIMO2, "mimo2" },
{ LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_STEP35, "step35" },
@@ -1589,6 +1590,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_UP_SHEXP, LLM_TENSOR_FFN_UP_SHEXP,
}; };
case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_MISTRAL4:
return { return {
LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_OUTPUT_NORM,

View File

@@ -127,6 +127,7 @@ enum llm_arch {
LLM_ARCH_RND1, LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED, LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3, LLM_ARCH_MISTRAL3,
LLM_ARCH_MISTRAL4,
LLM_ARCH_PADDLEOCR, LLM_ARCH_PADDLEOCR,
LLM_ARCH_MIMO2, LLM_ARCH_MIMO2,
LLM_ARCH_STEP35, LLM_ARCH_STEP35,

View File

@@ -1587,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} }
} break; } break;
case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_MISTRAL4:
{ {
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256));
@@ -4883,6 +4884,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} }
} break; } break;
case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_MISTRAL4:
{ {
const bool is_mla = hparams.is_mla(); const bool is_mla = hparams.is_mla();
@@ -7850,7 +7852,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
} }
if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@@ -8428,6 +8430,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break; } break;
case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_GLM_DSA: case LLM_ARCH_GLM_DSA:
case LLM_ARCH_MISTRAL4:
{ {
llm = std::make_unique<llm_build_deepseek2>(*this, params); llm = std::make_unique<llm_build_deepseek2>(*this, params);
} break; } break;
@@ -8839,6 +8842,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_MISTRAL3: case LLM_ARCH_MISTRAL3:
case LLM_ARCH_MISTRAL4:
case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_LLAMA_EMBED:
case LLM_ARCH_MAINCODER: case LLM_ARCH_MAINCODER:
case LLM_ARCH_GLM_DSA: case LLM_ARCH_GLM_DSA:

View File

@@ -90,7 +90,10 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
n_embd = 64; n_embd = 64;
n_head = 1; n_head = 1;
n_ff = 96; n_ff = 96;
} else if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR) { } else if (arch == LLM_ARCH_DEEPSEEK2
|| arch == LLM_ARCH_GLM_DSA
|| arch == LLM_ARCH_KIMI_LINEAR
|| arch == LLM_ARCH_MISTRAL4) {
n_embd = 128; n_embd = 128;
n_head = 1; n_head = 1;
n_ff = 192; n_ff = 192;
@@ -145,7 +148,10 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
} }
ms.add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, 8.0f); ms.add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, 8.0f);
if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR) { if (arch == LLM_ARCH_DEEPSEEK2
|| arch == LLM_ARCH_GLM_DSA
|| arch == LLM_ARCH_KIMI_LINEAR
|| arch == LLM_ARCH_MISTRAL4) {
ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH, uint32_t(576)); ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH, uint32_t(576));
ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, uint32_t(512)); ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, uint32_t(512));
ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT, uint32_t(64)); ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT, uint32_t(64));
@@ -319,6 +325,7 @@ static bool moe_mandatory(const llm_arch arch) {
case LLM_ARCH_MIMO2: case LLM_ARCH_MIMO2:
case LLM_ARCH_KIMI_LINEAR: case LLM_ARCH_KIMI_LINEAR:
case LLM_ARCH_STEP35: case LLM_ARCH_STEP35:
case LLM_ARCH_MISTRAL4:
return true; return true;
default: default:
return false; return false;