check for checkpoint_complete

This commit is contained in:
Kashif Rasul 2026-01-27 20:04:22 +01:00
parent d6aa21005b
commit 054abb5d4e

View file

@ -456,9 +456,13 @@ class RestClient(TelemetryProvider):
types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
def _safe_extract(tar: tarfile.TarFile, path: Path) -> None:
base = path.resolve()
for member in tar.getmembers():
member_path = path / member.name
if not str(member_path.resolve()).startswith(str(path.resolve())):
# Reject symlinks/hardlinks to avoid traversal via link targets
if member.issym() or member.islnk():
raise ValueError("Unsafe symlink or hardlink in tar archive")
member_path = (path / member.name).resolve()
if not str(member_path).startswith(str(base)):
raise ValueError("Unsafe path in tar archive")
tar.extractall(path=path)
@ -502,11 +506,17 @@ class RestClient(TelemetryProvider):
adapter_config = extract_dir / "adapter_config.json"
adapter_safetensors = extract_dir / "adapter_model.safetensors"
adapter_bin = extract_dir / "adapter_model.bin"
checkpoint_complete = extract_dir / "checkpoint_complete"
if not adapter_config.exists() or not (adapter_safetensors.exists() or adapter_bin.exists()):
raise ValueError(
"Checkpoint archive does not contain a PEFT adapter. "
"Expected adapter_config.json and adapter_model.safetensors (or adapter_model.bin)."
)
if not checkpoint_complete.exists():
raise ValueError(
"Checkpoint archive is missing 'checkpoint_complete'. "
"The adapter files may be incomplete."
)
base_model = "unknown"
lora_rank = None