convert : add split() to LoraTorchTensor in LoRA converter (#22832)

* convert : add split() method to LoraTorchTensor

* Fix Python type-check

* Fix flake8 lint

* fix: handle positional dim arg in torch.split dispatch

* Fix type-check again

* Fix type-checks

* Remove unit test per reviewers' feedback

* Work around ty deficiency

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
commit 78fbbc2c07
parent da44953329
Author: Jesus Talavera
Date: 2026-05-12 07:17:04 +02:00
Committed by: GitHub


@@ -188,6 +188,24 @@ class LoraTorchTensor:
     def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
         return self.transpose(axis0, axis1)
 
+    def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]:
+        shape = self.shape
+        ndim = len(shape)
+        if dim < 0:
+            dim += ndim
+        if dim == ndim - 1:
+            A_chunks = self._lora_A.split(split_size, dim=-1)
+            return tuple(LoraTorchTensor(a, self._lora_B) for a in A_chunks)
+        elif dim == ndim - 2:
+            B_chunks = self._lora_B.split(split_size, dim=-2)
+            return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks)
+        else:
+            B_chunks = self._lora_B.split(split_size, dim=dim)
+            if self._lora_A.shape[dim] == 1:
+                return tuple(LoraTorchTensor(self._lora_A, b) for b in B_chunks)
+            A_chunks = self._lora_A.split(split_size, dim=dim)
+            return tuple(LoraTorchTensor(a, b) for a, b in zip(A_chunks, B_chunks))
+
     def to(self, *args, **kwargs):
         return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
 
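Why the split touches only one factor per branch: the materialized tensor is lora_B @ lora_A, so its last dim is contributed by lora_A and its second-to-last dim by lora_B. A minimal standalone sketch of that identity (the names lora_A/lora_B and the shapes here are illustrative, not taken from the converter):

import torch

# (rank, in_features) and (out_features, rank) factors of a LoRA delta
lora_A = torch.randn(2, 6)
lora_B = torch.randn(4, 2)
W = lora_B @ lora_A  # materialized delta, shape (4, 6)

# Splitting W on its last dim is the same as splitting lora_A's last dim,
# which is what the dim == ndim - 1 branch does: B @ [A0 | A1] = [B@A0 | B@A1]
A0, A1 = lora_A.split(3, dim=-1)
assert torch.allclose(torch.cat([lora_B @ A0, lora_B @ A1], dim=-1), W)

# Splitting W on dim -2 is the same as splitting lora_B on dim -2,
# matching the dim == ndim - 2 branch: [B0; B1] @ A = [B0@A; B1@A]
B0, B1 = lora_B.split(2, dim=-2)
assert torch.allclose(torch.cat([B0 @ lora_A, B1 @ lora_A], dim=-2), W)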
@@ -230,6 +248,11 @@ class LoraTorchTensor:
                 )
             else:
                 raise NotImplementedError
+        elif func is torch.split:
+            assert len(args) >= 2
+            tensor, split_size = args[0], args[1]
+            dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
+            return tensor.split(split_size, dim=dim)
         else:
             raise NotImplementedError
 
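With the dispatch hook in place, both the keyword and the positional spellings of the dim argument to torch.split route through __torch_function__ to the new method. A rough usage sketch; the import path and positional constructor order (A factor, then B factor) are assumptions about the converter script, not shown in this diff:

import torch
from convert_lora_to_gguf import LoraTorchTensor  # assumed import path

t = LoraTorchTensor(torch.randn(2, 6), torch.randn(4, 2))  # A, B factors

# Depending on the torch version, dim may arrive positionally in args[2] or
# as a kwarg; the handler above covers both, per the "positional dim" fix.
parts_kw = torch.split(t, 3, dim=-1)
parts_pos = torch.split(t, 3, -1)
assert len(parts_kw) == len(parts_pos) == 2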