mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-05-13 20:44:09 +00:00
fix parts validation
This commit is contained in:
@@ -210,6 +210,7 @@ class ModelBase:
|
||||
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
|
||||
tensor_names_from_index: set[str] = set()
|
||||
tensor_names_from_parts: set[str] = set()
|
||||
|
||||
if not self.is_mistral_format:
|
||||
index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
|
||||
@@ -243,6 +244,7 @@ class ModelBase:
|
||||
assert model_part is not None
|
||||
|
||||
for name in model_part.keys():
|
||||
tensor_names_from_parts.add(name)
|
||||
if is_safetensors:
|
||||
data: gguf.utility.LocalTensor = model_part[name]
|
||||
if self.lazy:
|
||||
@@ -262,7 +264,6 @@ class ModelBase:
|
||||
|
||||
# verify tensor name presence and identify potentially missing files
|
||||
if len(tensor_names_from_index) > 0:
|
||||
tensor_names_from_parts = set(tensors.keys())
|
||||
if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
|
||||
missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
|
||||
extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
|
||||
|
||||
Reference in New Issue
Block a user