Patch summary: ParsedCheckpointTinkerPath.from_tinker_path now names the three
path segments, reports an unexpected middle segment with a dedicated
"Invalid checkpoint type" error, and stores the first segment (including any
":suffix" part) verbatim as training_run_id. Parser tests are added.

diff --git a/src/tinker/types/checkpoint.py b/src/tinker/types/checkpoint.py
index b6a5630..bb32251 100644
--- a/src/tinker/types/checkpoint.py
+++ b/src/tinker/types/checkpoint.py
@@ -46,18 +46,45 @@ class ParsedCheckpointTinkerPath(BaseModel):
 
     @classmethod
     def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
-        """Parse a tinker path to an instance of ParsedCheckpointTinkerPath"""
+        """Parse a tinker path to an instance of ParsedCheckpointTinkerPath.
+
+        Supports two formats:
+        - Standard: tinker://run-id/weights/0001
+        - With suffix: tinker://run-id:suffix/weights/0001
+          (e.g., tinker://run-id:train:0/weights/0001)
+        """
         if not tinker_path.startswith("tinker://"):
             raise ValueError(f"Invalid tinker path: {tinker_path}")
-        parts = tinker_path[9:].split("/")
-        if len(parts) != 3:
-            raise ValueError(f"Invalid tinker path: {tinker_path}")
-        if parts[1] not in ["weights", "sampler_weights"]:
-            raise ValueError(f"Invalid tinker path: {tinker_path}")
-        checkpoint_type = "training" if parts[1] == "weights" else "sampler"
+
+        # Remove the tinker:// prefix
+        path_parts = tinker_path[9:]
+
+        # Split into segments
+        # Format: run_id_with_type/checkpoint_type/checkpoint_id
+        segments = path_parts.split("/")
+        if len(segments) != 3:
+            raise ValueError(
+                f"Invalid tinker path: {tinker_path}. "
+                f"Expected: tinker://run-id/weights/0001 or tinker://run-id:train:0/weights/0001"
+            )
+
+        run_id_with_type = segments[0]
+        checkpoint_type_segment = segments[1]
+        checkpoint_id = segments[2]
+
+        # Validate checkpoint type
+        if checkpoint_type_segment not in ["weights", "sampler_weights"]:
+            raise ValueError(
+                f"Invalid checkpoint type: {checkpoint_type_segment}. "
+                f"Expected: weights or sampler_weights"
+            )
+
+
+        checkpoint_type = "training" if checkpoint_type_segment == "weights" else "sampler"
+
         return cls(
             tinker_path=tinker_path,
-            training_run_id=parts[0],
+            training_run_id=run_id_with_type,
             checkpoint_type=checkpoint_type,
-            checkpoint_id="/".join(parts[1:]),
+            checkpoint_id="/".join([checkpoint_type_segment, checkpoint_id]),
         )
diff --git a/tests/test_checkpoint_delete.py b/tests/test_checkpoint_delete.py
index 65f4c9f..04678ec 100644
--- a/tests/test_checkpoint_delete.py
+++ b/tests/test_checkpoint_delete.py
@@ -194,3 +194,65 @@ class TestDeleteCLIValidation:
         )
         assert result.exit_code != 0
         assert "--run-id" in self._get_error_message(result)
+
+
+class TestParsedCheckpointTinkerPath:
+    """Tests for ParsedCheckpointTinkerPath.from_tinker_path()."""
+
+    def test_standard_format(self) -> None:
+        from tinker.types.checkpoint import ParsedCheckpointTinkerPath
+
+        result = ParsedCheckpointTinkerPath.from_tinker_path(
+            "tinker://run-id/weights/0001"
+        )
+        assert result.training_run_id == "run-id"
+        assert result.checkpoint_type == "training"
+        assert result.checkpoint_id == "weights/0001"
+
+    def test_sampler_format(self) -> None:
+        from tinker.types.checkpoint import ParsedCheckpointTinkerPath
+
+        result = ParsedCheckpointTinkerPath.from_tinker_path(
+            "tinker://run-id/sampler_weights/0001"
+        )
+        assert result.training_run_id == "run-id"
+        assert result.checkpoint_type == "sampler"
+        assert result.checkpoint_id == "sampler_weights/0001"
+
+    def test_with_train_suffix(self) -> None:
+        from tinker.types.checkpoint import ParsedCheckpointTinkerPath
+
+        result = ParsedCheckpointTinkerPath.from_tinker_path(
+            "tinker://5f2d7413-3980-502a-b012-9b7e122b3305:train:0/sampler_weights/final"
+        )
+        assert result.training_run_id == "5f2d7413-3980-502a-b012-9b7e122b3305:train:0"
+        assert result.checkpoint_type == "sampler"
+        assert result.checkpoint_id == "sampler_weights/final"
+
+    def test_with_sampler_suffix(self) -> None:
+        from tinker.types.checkpoint import ParsedCheckpointTinkerPath
+
+        result = ParsedCheckpointTinkerPath.from_tinker_path(
+            "tinker://run-id:sampler/weights/0001"
+        )
+        assert result.training_run_id == "run-id:sampler"
+        assert result.checkpoint_type == "training"
+        assert result.checkpoint_id == "weights/0001"
+
+    def test_invalid_missing_prefix(self) -> None:
+        from tinker.types.checkpoint import ParsedCheckpointTinkerPath
+
+        with pytest.raises(ValueError, match="Invalid tinker path"):
+            ParsedCheckpointTinkerPath.from_tinker_path("run-id/weights/0001")
+
+    def test_invalid_wrong_checkpoint_type(self) -> None:
+        from tinker.types.checkpoint import ParsedCheckpointTinkerPath
+
+        with pytest.raises(ValueError, match="Invalid checkpoint type"):
+            ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/invalid/0001")
+
+    def test_invalid_not_enough_parts(self) -> None:
+        from tinker.types.checkpoint import ParsedCheckpointTinkerPath
+
+        with pytest.raises(ValueError, match="Invalid tinker path"):
+            ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/weights")