mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
fix(checkpoint): support tinker paths with run-id suffix
Fixes Issue #24: 'tinker checkpoint delete' now accepts paths with run-id suffix (e.g., tinker://run-id🚋0/sampler_weights/final). The previous implementation assumed a strict 3-part path format, which failed when the run-id contained a suffix like '🚋0'. This change updates ParsedCheckpointTinkerPath.from_tinker_path() to: - Handle run-ids with arbitrary suffixes - Provide better error messages for invalid formats Good day, Thank you for your work on this excellent library! Warmly, RoomWithOutRoof
This commit is contained in:
parent
30517b667f
commit
be00ee1695
2 changed files with 98 additions and 9 deletions
|
|
@ -46,18 +46,45 @@ class ParsedCheckpointTinkerPath(BaseModel):
|
|||
|
||||
@classmethod
|
||||
def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
|
||||
"""Parse a tinker path to an instance of ParsedCheckpointTinkerPath"""
|
||||
"""Parse a tinker path to an instance of ParsedCheckpointTinkerPath.
|
||||
|
||||
Supports two formats:
|
||||
- Standard: tinker://run-id/weights/0001
|
||||
- With suffix: tinker://run-id:suffix/weights/0001
|
||||
(e.g., tinker://run-id:train:0/weights/0001)
|
||||
"""
|
||||
if not tinker_path.startswith("tinker://"):
|
||||
raise ValueError(f"Invalid tinker path: {tinker_path}")
|
||||
parts = tinker_path[9:].split("/")
|
||||
if len(parts) != 3:
|
||||
raise ValueError(f"Invalid tinker path: {tinker_path}")
|
||||
if parts[1] not in ["weights", "sampler_weights"]:
|
||||
raise ValueError(f"Invalid tinker path: {tinker_path}")
|
||||
checkpoint_type = "training" if parts[1] == "weights" else "sampler"
|
||||
|
||||
# Remove the tinker:// prefix
|
||||
path_parts = tinker_path[9:]
|
||||
|
||||
# Split into segments
|
||||
# Format: run_id_with_type/checkpoint_type/checkpoint_id
|
||||
segments = path_parts.split("/")
|
||||
if len(segments) != 3:
|
||||
raise ValueError(
|
||||
f"Invalid tinker path: {tinker_path}. "
|
||||
f"Expected: tinker://run-id/weights/0001 or tinker://run-id:train:0/weights/0001"
|
||||
)
|
||||
|
||||
run_id_with_type = segments[0]
|
||||
checkpoint_type_segment = segments[1]
|
||||
checkpoint_id = segments[2]
|
||||
|
||||
# Validate checkpoint type
|
||||
if checkpoint_type_segment not in ["weights", "sampler_weights"]:
|
||||
raise ValueError(
|
||||
f"Invalid checkpoint type: {checkpoint_type_segment}. "
|
||||
f"Expected: weights or sampler_weights"
|
||||
)
|
||||
|
||||
|
||||
checkpoint_type = "training" if checkpoint_type_segment == "weights" else "sampler"
|
||||
|
||||
return cls(
|
||||
tinker_path=tinker_path,
|
||||
training_run_id=parts[0],
|
||||
training_run_id=run_id_with_type,
|
||||
checkpoint_type=checkpoint_type,
|
||||
checkpoint_id="/".join(parts[1:]),
|
||||
checkpoint_id="/".join([checkpoint_type_segment, checkpoint_id]),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -194,3 +194,65 @@ class TestDeleteCLIValidation:
|
|||
)
|
||||
assert result.exit_code != 0
|
||||
assert "--run-id" in self._get_error_message(result)
|
||||
|
||||
|
||||
class TestParsedCheckpointTinkerPath:
|
||||
"""Tests for ParsedCheckpointTinkerPath.from_tinker_path()."""
|
||||
|
||||
def test_standard_format(self) -> None:
|
||||
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
|
||||
|
||||
result = ParsedCheckpointTinkerPath.from_tinker_path(
|
||||
"tinker://run-id/weights/0001"
|
||||
)
|
||||
assert result.training_run_id == "run-id"
|
||||
assert result.checkpoint_type == "training"
|
||||
assert result.checkpoint_id == "weights/0001"
|
||||
|
||||
def test_sampler_format(self) -> None:
|
||||
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
|
||||
|
||||
result = ParsedCheckpointTinkerPath.from_tinker_path(
|
||||
"tinker://run-id/sampler_weights/0001"
|
||||
)
|
||||
assert result.training_run_id == "run-id"
|
||||
assert result.checkpoint_type == "sampler"
|
||||
assert result.checkpoint_id == "sampler_weights/0001"
|
||||
|
||||
def test_with_train_suffix(self) -> None:
|
||||
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
|
||||
|
||||
result = ParsedCheckpointTinkerPath.from_tinker_path(
|
||||
"tinker://5f2d7413-3980-502a-b012-9b7e122b3305:train:0/sampler_weights/final"
|
||||
)
|
||||
assert result.training_run_id == "5f2d7413-3980-502a-b012-9b7e122b3305:train:0"
|
||||
assert result.checkpoint_type == "sampler"
|
||||
assert result.checkpoint_id == "sampler_weights/final"
|
||||
|
||||
def test_with_sampler_suffix(self) -> None:
|
||||
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
|
||||
|
||||
result = ParsedCheckpointTinkerPath.from_tinker_path(
|
||||
"tinker://run-id:sampler/weights/0001"
|
||||
)
|
||||
assert result.training_run_id == "run-id:sampler"
|
||||
assert result.checkpoint_type == "training"
|
||||
assert result.checkpoint_id == "weights/0001"
|
||||
|
||||
def test_invalid_missing_prefix(self) -> None:
|
||||
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid tinker path"):
|
||||
ParsedCheckpointTinkerPath.from_tinker_path("run-id/weights/0001")
|
||||
|
||||
def test_invalid_wrong_checkpoint_type(self) -> None:
|
||||
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid checkpoint type"):
|
||||
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/invalid/0001")
|
||||
|
||||
def test_invalid_not_enough_parts(self) -> None:
|
||||
from tinker.types.checkpoint import ParsedCheckpointTinkerPath
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid tinker path"):
|
||||
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/weights")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue