From 78fbbc2c0788efc8857a2c0dc9802ec689fa12c1 Mon Sep 17 00:00:00 2001
From: Jesus Talavera <145992175+jesus-talavera-ibm@users.noreply.github.com>
Date: Tue, 12 May 2026 07:17:04 +0200
Subject: [PATCH] convert : add split() to LoraTorchTensor in LoRA converter
 (#22832)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* convert : add split() method to LoraTorchTensor

* Fix python type-check

* Fix flake8 Lint

* fix: handle positional dim arg in torch.split dispatch

* Fix type-check again

* Fix type-checks

* Remove unit test per reviewers feedback

* work around ty deficiency

---------

Co-authored-by: Sigbjørn Skjæret
---
 convert_lora_to_gguf.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py
index d583342056..ad4751bb96 100755
--- a/convert_lora_to_gguf.py
+++ b/convert_lora_to_gguf.py
@@ -188,6 +188,24 @@ class LoraTorchTensor:
     def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
         return self.transpose(axis0, axis1)
 
+    def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]:
+        shape = self.shape
+        ndim = len(shape)
+        if dim < 0:
+            dim += ndim
+        if dim == ndim - 1:
+            A_chunks = self._lora_A.split(split_size, dim=-1)
+            return tuple(LoraTorchTensor(a, self._lora_B) for a in A_chunks)
+        elif dim == ndim - 2:
+            B_chunks = self._lora_B.split(split_size, dim=-2)
+            return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks)
+        else:
+            B_chunks = self._lora_B.split(split_size, dim=dim)
+            if self._lora_A.shape[dim] == 1:
+                return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks)
+            A_chunks = self._lora_A.split(split_size, dim=dim)
+            return tuple(LoraTorchTensor(a, b) for a, b in zip(A_chunks, B_chunks))
+
     def to(self, *args, **kwargs):
         return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
 
@@ -230,6 +248,11 @@ class LoraTorchTensor:
                 )
             else:
                 raise NotImplementedError
+        elif func is torch.split:
+            assert len(args) and len(args) >= 2
+            tensor, split_size = args[0], args[1]
+            dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
+            return tensor.split(split_size, dim=dim)
         else:
             raise NotImplementedError
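
--
Editor's note (not part of the patch): a minimal sketch of the identity the
new split() method relies on, using plain torch tensors in place of
LoraTorchTensor. The shapes and names below are illustrative assumptions,
not taken from the converter. The composed weight delta is lora_B @ lora_A
(out x rank times rank x in), so splitting it along the input (last) dim
splits only lora_A, and splitting it along the output (second-to-last) dim
splits only lora_B, which is exactly what the two specialized branches do:

    import torch

    rank, n_in, n_out = 4, 8, 6
    A = torch.randn(rank, n_in)   # stands in for _lora_A
    B = torch.randn(n_out, rank)  # stands in for _lora_B
    W = B @ A                     # the composed LoRA weight delta

    # Splitting W on the input dim matches splitting A alone (B is shared),
    # mirroring the dim == ndim - 1 branch of the patch.
    for w, a in zip(W.split(4, dim=-1), A.split(4, dim=-1)):
        assert torch.allclose(w, B @ a)

    # Splitting W on the output dim matches splitting B alone (A is shared),
    # mirroring the dim == ndim - 2 branch.
    for w, b in zip(W.split(3, dim=-2), B.split(3, dim=-2)):
        assert torch.allclose(w, b @ A)

The __torch_function__ branch then lets callers write torch.split(t, n, dim)
on a LoraTorchTensor and have it routed to this method, with dim accepted
either positionally or as a keyword.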