mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 12:34:05 +00:00
convert : add split() to LoraTorchTensor in LoRA converter (#22832)
* convert : add split() method to LoraTorchTensor * Fix python type-check * Fix flake8 Lint * fix: handle positional dim arg in torch.split dispatch * Fix type-check again * Fix type-checks * Remove unit test per reviewers' feedback * work around ty deficiency --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
@@ -188,6 +188,24 @@ class LoraTorchTensor:
|
||||
def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
    # torch.Tensor API compatibility alias: swapping two axes is exactly a
    # two-axis transpose, so delegate to the existing transpose implementation.
    return self.transpose(axis0, axis1)
def split(self, split_size: int | Sequence[int], dim: int = 0) -> tuple[LoraTorchTensor, ...]:
    """Split this LoRA tensor along *dim*, mirroring ``torch.Tensor.split``.

    The logical tensor is factored as A @ B, so a split maps onto whichever
    factor owns the requested axis:
      - last axis        -> owned by lora_A
      - second-to-last   -> owned by lora_B
      - any batch axis   -> lora_B always splits; lora_A splits too unless
                            it is broadcast (size 1) along that axis.
    Returns a tuple of LoraTorchTensor chunks.
    """
    rank = len(self.shape)
    # normalize a negative axis the same way torch does
    axis = dim + rank if dim < 0 else dim

    if axis == rank - 1:
        # columns live on lora_A only
        a_parts = self._lora_A.split(split_size, dim=-1)
        return tuple(LoraTorchTensor(a, self._lora_B) for a in a_parts)

    if axis == rank - 2:
        # rows live on lora_B only
        b_parts = self._lora_B.split(split_size, dim=-2)
        return tuple(LoraTorchTensor(self._lora_A, b) for b in b_parts)

    # batch axis: B always carries it; A may be broadcast along it
    b_parts = self._lora_B.split(split_size, dim=axis)
    if self._lora_A.shape[axis] == 1:
        # A is broadcast on this axis — share it across every chunk
        return tuple(LoraTorchTensor(self._lora_A, b) for b in b_parts)
    a_parts = self._lora_A.split(split_size, dim=axis)
    return tuple(LoraTorchTensor(a, b) for a, b in zip(a_parts, b_parts))
def to(self, *args, **kwargs):
    """Apply ``torch.Tensor.to`` to both low-rank factors and rewrap them."""
    moved_a = self._lora_A.to(*args, **kwargs)
    moved_b = self._lora_B.to(*args, **kwargs)
    return LoraTorchTensor(moved_a, moved_b)
@@ -230,6 +248,11 @@ class LoraTorchTensor:
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
elif func is torch.split:
|
||||
assert len(args) and len(args) >= 2
|
||||
tensor, split_size = args[0], args[1]
|
||||
dim = args[2] if len(args) > 2 else kwargs.get("dim", 0)
|
||||
return tensor.split(split_size, dim=dim)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user