fix(checkpoint): support tinker paths with run-id suffix

Fixes Issue #24: 'tinker checkpoint delete' now accepts paths with run-id
suffix (e.g., tinker://run-id🚋0/sampler_weights/final).

The previous implementation assumed a strict 3-part path format, which
failed when the run-id contained a suffix like '🚋0'.

This change updates ParsedCheckpointTinkerPath.from_tinker_path() to:
- Handle run-ids with arbitrary suffixes
- Provide better error messages for invalid formats

Good day,

Thank you for your work on this excellent library!

Warmly,
RoomWithOutRoof
This commit is contained in:
RoomWithOutRoof 2026-04-16 01:37:56 +08:00
parent 30517b667f
commit be00ee1695
2 changed files with 98 additions and 9 deletions

View file

@ -46,18 +46,45 @@ class ParsedCheckpointTinkerPath(BaseModel):
@classmethod @classmethod
def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath": def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
"""Parse a tinker path to an instance of ParsedCheckpointTinkerPath""" """Parse a tinker path to an instance of ParsedCheckpointTinkerPath.
Supports two formats:
- Standard: tinker://run-id/weights/0001
- With suffix: tinker://run-id:suffix/weights/0001
(e.g., tinker://run-id:train:0/weights/0001)
"""
if not tinker_path.startswith("tinker://"): if not tinker_path.startswith("tinker://"):
raise ValueError(f"Invalid tinker path: {tinker_path}") raise ValueError(f"Invalid tinker path: {tinker_path}")
parts = tinker_path[9:].split("/")
if len(parts) != 3: # Remove the tinker:// prefix
raise ValueError(f"Invalid tinker path: {tinker_path}") path_parts = tinker_path[9:]
if parts[1] not in ["weights", "sampler_weights"]:
raise ValueError(f"Invalid tinker path: {tinker_path}") # Split into segments
checkpoint_type = "training" if parts[1] == "weights" else "sampler" # Format: run_id_with_type/checkpoint_type/checkpoint_id
segments = path_parts.split("/")
if len(segments) != 3:
raise ValueError(
f"Invalid tinker path: {tinker_path}. "
f"Expected: tinker://run-id/weights/0001 or tinker://run-id:train:0/weights/0001"
)
run_id_with_type = segments[0]
checkpoint_type_segment = segments[1]
checkpoint_id = segments[2]
# Validate checkpoint type
if checkpoint_type_segment not in ["weights", "sampler_weights"]:
raise ValueError(
f"Invalid checkpoint type: {checkpoint_type_segment}. "
f"Expected: weights or sampler_weights"
)
checkpoint_type = "training" if checkpoint_type_segment == "weights" else "sampler"
return cls( return cls(
tinker_path=tinker_path, tinker_path=tinker_path,
training_run_id=parts[0], training_run_id=run_id_with_type,
checkpoint_type=checkpoint_type, checkpoint_type=checkpoint_type,
checkpoint_id="/".join(parts[1:]), checkpoint_id="/".join([checkpoint_type_segment, checkpoint_id]),
) )

View file

@ -194,3 +194,65 @@ class TestDeleteCLIValidation:
) )
assert result.exit_code != 0 assert result.exit_code != 0
assert "--run-id" in self._get_error_message(result) assert "--run-id" in self._get_error_message(result)
class TestParsedCheckpointTinkerPath:
"""Tests for ParsedCheckpointTinkerPath.from_tinker_path()."""
def test_standard_format(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
result = ParsedCheckpointTinkerPath.from_tinker_path(
"tinker://run-id/weights/0001"
)
assert result.training_run_id == "run-id"
assert result.checkpoint_type == "training"
assert result.checkpoint_id == "weights/0001"
def test_sampler_format(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
result = ParsedCheckpointTinkerPath.from_tinker_path(
"tinker://run-id/sampler_weights/0001"
)
assert result.training_run_id == "run-id"
assert result.checkpoint_type == "sampler"
assert result.checkpoint_id == "sampler_weights/0001"
def test_with_train_suffix(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
result = ParsedCheckpointTinkerPath.from_tinker_path(
"tinker://5f2d7413-3980-502a-b012-9b7e122b3305:train:0/sampler_weights/final"
)
assert result.training_run_id == "5f2d7413-3980-502a-b012-9b7e122b3305:train:0"
assert result.checkpoint_type == "sampler"
assert result.checkpoint_id == "sampler_weights/final"
def test_with_sampler_suffix(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
result = ParsedCheckpointTinkerPath.from_tinker_path(
"tinker://run-id:sampler/weights/0001"
)
assert result.training_run_id == "run-id:sampler"
assert result.checkpoint_type == "training"
assert result.checkpoint_id == "weights/0001"
def test_invalid_missing_prefix(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("run-id/weights/0001")
def test_invalid_wrong_checkpoint_type(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
with pytest.raises(ValueError, match="Invalid checkpoint type"):
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/invalid/0001")
def test_invalid_not_enough_parts(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/weights")