mirror of https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 04:24:17 +00:00
convert : apply Q/K RoPE permutation in NVFP4 repack path (#22611)
Llama-architecture q_proj/k_proj weights need an axis-0 row permutation to match GGML's RoPE convention. The BF16 path applies this in LlamaModel.modify_tensors via LlamaModel.permute, but the NVFP4 path bypasses modify_tensors and writes weights directly through ModelBase._repack_nvfp4. Without the permutation, attention heads end up scrambled at inference and the model produces gibberish. This change overrides _repack_nvfp4 on LlamaModel and applies the same permutation to both the nibble-packed weight and the per-block scale before delegating to ModelBase._repack_nvfp4 via super(). Reuses the existing LlamaModel.permute static helper and respects the existing undo_permute flag, so subclasses (Mistral, Granite, Llama4, etc.) inherit the fix automatically. Verified on TinyLlama-1.1B reproducer: perplexity drops from 4419 (gibberish) to 43.9, matching the BF16-dequantized baseline (44.0). Also verified end-to-end on ALIA-40b-instruct-2601 (BSC, Llama architecture) with multilingual generation in Spanish/Catalan/Basque/ Galician all coherent with the fix applied. Co-authored-by: Chema <chema@montevive.ai>
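The row permutation the commit message refers to converts HF's interleaved RoPE pair layout into GGML's split (first-half/second-half) layout. A minimal NumPy sketch of that reshape/swapaxes trick, mirroring the tail of LlamaModel.permute visible in the diff context below (the toy shapes and the `w` example are illustrative, not from the source):

```python
import numpy as np

def permute(weights, n_head, n_head_kv=None):
    # Reshape rows into (heads, 2, head_dim // 2, ...) and swap the
    # middle two axes: interleaved RoPE pairs -> split-halves layout.
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv  # k_proj has fewer heads under GQA
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

# Toy q_proj: 2 heads, head_dim 4, hidden size 4 -> an 8x4 weight.
w = np.arange(32).reshape(8, 4)
out = permute(w, 2, 2)
# Within each head the rows are reordered 0,2,1,3: rows 0,2 become the
# "first half" of the head's dims and rows 1,3 the "second half".
```

Because this is a pure row permutation along axis 0, it applies equally to any tensor whose axis 0 tracks the weight's output rows, which is why the fix can reuse it for the per-block scale.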
@@ -2889,6 +2889,20 @@ class LlamaModel(TextModel):
                 .swapaxes(1, 2)
                 .reshape(weights.shape))
 
+    def _repack_nvfp4(self, name: str, weight: Tensor, scale: Tensor, scale2: Tensor, input_scale: Tensor):
+        # Mirror the BF16 Q/K RoPE permutation site in modify_tensors; the NVFP4 path bypasses it.
+        if self.undo_permute:
+            n_head = self.find_hparam(["n_heads", "num_attention_heads"], optional=True)
+            n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"], optional=True)
+            if n_head is not None:
+                if name.endswith("q_proj.weight"):
+                    weight = LlamaModel.permute(weight, n_head, n_head)
+                    scale = LlamaModel.permute(scale, n_head, n_head)
+                elif name.endswith("k_proj.weight"):
+                    weight = LlamaModel.permute(weight, n_head, n_kv_head)
+                    scale = LlamaModel.permute(scale, n_head, n_kv_head)
+        super()._repack_nvfp4(name, weight, scale, scale2, input_scale)
+
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
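Why the per-block scale must be permuted together with the packed weight: NVFP4 dequantization multiplies each row's packed values by that row's block scales, so a row permutation commutes with dequantization only if weight and scale rows move in lockstep. A minimal NumPy sketch; the block size of 16, the `dequant` helper, and the hard-coded permutation are illustrative assumptions, not the repo's actual unpacking code:

```python
import numpy as np

BLOCK = 16  # assumed NVFP4 block size along the input dimension, for illustration

rng = np.random.default_rng(0)
rows, cols = 8, 32
codes = rng.integers(0, 8, size=(rows, cols)).astype(np.float32)  # stand-in for unpacked FP4 values
scale = rng.random((rows, cols // BLOCK)).astype(np.float32)      # one scale per 16-value block, per row

def dequant(codes, scale):
    # Broadcast each block scale across its 16 values in the row.
    return codes * np.repeat(scale, BLOCK, axis=1)

# Stand-in for the Q/K row permutation on a toy 2-head, head_dim-4 tensor.
perm = np.array([0, 2, 1, 3, 4, 6, 5, 7])

# Permuting both tensors' rows, then dequantizing, equals dequantizing
# first and permuting the result -- the invariant the fix restores.
ok = np.allclose(dequant(codes[perm], scale[perm]), dequant(codes, scale)[perm])
```

Permuting only the weight (as the pre-fix NVFP4 path effectively did, by skipping the permutation entirely) breaks this invariant and scrambles the attention heads.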