tinker/tests/test_checkpoint_delete.py
RoomWithOutRoof be00ee1695 fix(checkpoint): support tinker paths with run-id suffix
Fixes Issue #24: 'tinker checkpoint delete' now accepts paths with a run-id
suffix (e.g., tinker://run-id:train:0/sampler_weights/final).

The previous implementation assumed a strict 3-part path format, which
failed when the run-id contained a suffix like ':train:0'.

This change updates ParsedCheckpointTinkerPath.from_tinker_path() to:
- Handle run-ids with arbitrary suffixes
- Provide better error messages for invalid formats

Good day,

Thank you for your work on this excellent library!

Warmly,
RoomWithOutRoof
2026-04-16 01:37:56 +08:00

258 lines
9.8 KiB
Python

"""Tests for bulk checkpoint deletion: CLI flags and date parsing."""
from datetime import UTC, datetime, timedelta
import pytest
from click.testing import CliRunner
from tinker.cli.commands.checkpoint import _filter_checkpoints, _parse_date
class TestParseDate:
    """Tests for the _parse_date ISO 8601 parser."""

    def test_date_only(self) -> None:
        """A bare calendar date keeps its fields and comes back with tzinfo set."""
        dt = _parse_date("2024-01-15")
        assert dt.year == 2024
        assert dt.month == 1
        assert dt.day == 15
        assert dt.tzinfo is not None

    def test_datetime_with_z(self) -> None:
        """A timestamp with a trailing ``Z`` parses and carries tzinfo."""
        dt = _parse_date("2024-06-01T12:00:00Z")
        assert dt.year == 2024
        assert dt.month == 6
        assert dt.hour == 12
        assert dt.tzinfo is not None

    def test_datetime_with_offset(self) -> None:
        """A timestamp with an explicit ``+00:00`` offset parses and carries tzinfo."""
        dt = _parse_date("2024-06-01T12:00:00+00:00")
        assert dt.year == 2024
        assert dt.tzinfo is not None

    def test_datetime_naive_gets_utc(self) -> None:
        """A timestamp without any offset is expected to be tagged as UTC."""
        dt = _parse_date("2024-06-01T12:00:00")
        assert dt.tzinfo is not None
        # Any timezone satisfies the check above; pin the offset to zero so
        # the test actually verifies the UTC default its name promises.
        assert dt.utcoffset() == timedelta(0)

    def test_whitespace_stripped(self) -> None:
        """Leading and trailing spaces around the date are tolerated."""
        dt = _parse_date(" 2024-01-01 ")
        assert dt.year == 2024

    def test_invalid_raises(self) -> None:
        """Text that is not a date raises TinkerCliError."""
        from tinker.cli.exceptions import TinkerCliError

        with pytest.raises(TinkerCliError):
            _parse_date("not-a-date")

    def test_invalid_format_raises(self) -> None:
        """A slash-separated date raises TinkerCliError."""
        from tinker.cli.exceptions import TinkerCliError

        with pytest.raises(TinkerCliError):
            _parse_date("01/15/2024")
class TestFilterCheckpoints:
    """Exercise _filter_checkpoints with type and date criteria."""

    @pytest.fixture()
    def sample_checkpoints(self):
        """Build four checkpoints for run-1: three training, one sampler.

        Ages relative to "now" are 10 days, 3 days, 10 days and 1 hour.
        """
        from tinker.types.checkpoint import Checkpoint

        now = datetime.now(UTC)
        # (checkpoint_id, checkpoint_type, age, tinker_path, size_bytes)
        specs = [
            (
                "weights/0001",
                "training",
                timedelta(days=10),
                "tinker://run-1/weights/0001",
                1000,
            ),
            (
                "weights/0002",
                "training",
                timedelta(days=3),
                "tinker://run-1/weights/0002",
                2000,
            ),
            (
                "sampler_weights/0001",
                "sampler",
                timedelta(days=10),
                "tinker://run-1/sampler_weights/0001",
                500,
            ),
            (
                "weights/0003",
                "training",
                timedelta(hours=1),
                "tinker://run-1/weights/0003",
                3000,
            ),
        ]
        return [
            Checkpoint(
                checkpoint_id=ckpt_id,
                checkpoint_type=ckpt_type,
                time=now - age,
                tinker_path=path,
                size_bytes=size,
            )
            for ckpt_id, ckpt_type, age, path, size in specs
        ]

    def test_no_filters(self, sample_checkpoints) -> None:
        """With every criterion set to None, all four checkpoints are kept."""
        kept = _filter_checkpoints(sample_checkpoints, None, None, None)
        assert len(kept) == 4

    def test_filter_by_weights_type(self, sample_checkpoints) -> None:
        """The "weights" type keeps the three training checkpoints."""
        kept = _filter_checkpoints(sample_checkpoints, "weights", None, None)
        assert len(kept) == 3
        for ckpt in kept:
            assert ckpt.checkpoint_type == "training"

    def test_filter_by_sampler_weights_type(self, sample_checkpoints) -> None:
        """The "sampler_weights" type keeps the single sampler checkpoint."""
        kept = _filter_checkpoints(sample_checkpoints, "sampler_weights", None, None)
        assert len(kept) == 1
        (only,) = kept
        assert only.checkpoint_type == "sampler"

    def test_filter_before(self, sample_checkpoints) -> None:
        """A 7-day "before" cutoff keeps only the two 10-day-old checkpoints."""
        seven_days_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, None, seven_days_ago, None)
        assert len(kept) == 2
        for ckpt in kept:
            assert "0001" in ckpt.checkpoint_id

    def test_filter_after(self, sample_checkpoints) -> None:
        """A 7-day "after" cutoff keeps the 3-day-old and 1-hour-old checkpoints."""
        seven_days_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, None, None, seven_days_ago)
        assert len(kept) == 2
        kept_paths = {ckpt.tinker_path for ckpt in kept}
        assert "tinker://run-1/weights/0002" in kept_paths
        assert "tinker://run-1/weights/0003" in kept_paths

    def test_filter_date_range(self, sample_checkpoints) -> None:
        """A window from 5 to 2 days ago keeps only the 3-day-old checkpoint."""
        window_start = datetime.now(UTC) - timedelta(days=5)
        window_end = datetime.now(UTC) - timedelta(days=2)
        kept = _filter_checkpoints(sample_checkpoints, None, window_end, window_start)
        assert len(kept) == 1
        (only,) = kept
        assert only.tinker_path == "tinker://run-1/weights/0002"

    def test_filter_by_type_and_before(self, sample_checkpoints) -> None:
        """Type and "before" combine: only the 10-day-old training checkpoint remains."""
        seven_days_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, "weights", seven_days_ago, None)
        assert len(kept) == 1
        (only,) = kept
        assert only.tinker_path == "tinker://run-1/weights/0001"

    def test_invalid_type_raises(self, sample_checkpoints) -> None:
        """An unrecognised type string raises TinkerCliError."""
        from tinker.cli.exceptions import TinkerCliError

        with pytest.raises(TinkerCliError):
            _filter_checkpoints(sample_checkpoints, "invalid_type", None, None)
class TestDeleteCLIValidation:
    """Check how the CLI ``delete`` command validates its arguments."""

    def _get_error_message(self, result) -> str:
        """Return the captured output, else the exception text, else ""."""
        if result.output:
            return result.output
        return str(result.exception) if result.exception else ""

    def _invoke_delete(self, *extra_args):
        """Run ``delete`` with ``extra_args`` through a fresh CliRunner."""
        from tinker.cli.commands.checkpoint import cli

        return CliRunner().invoke(cli, ["delete", *extra_args])

    def test_no_args_shows_error(self) -> None:
        """``delete`` with no arguments exits with a non-zero code."""
        outcome = self._invoke_delete()
        assert outcome.exit_code != 0

    def test_paths_and_run_id_conflict(self) -> None:
        """Giving both a path and --run-id fails with "Cannot specify both"."""
        outcome = self._invoke_delete(
            "tinker://run-1/weights/0001", "--run-id", "run-1"
        )
        assert outcome.exit_code != 0
        assert "Cannot specify both" in self._get_error_message(outcome)

    def test_type_without_run_id_error(self) -> None:
        """--type with a path but no --run-id fails and mentions --run-id."""
        outcome = self._invoke_delete(
            "tinker://run-1/weights/0001", "--type", "weights"
        )
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

    def test_before_without_run_id_error(self) -> None:
        """--before with a path but no --run-id fails and mentions --run-id."""
        outcome = self._invoke_delete(
            "tinker://run-1/weights/0001", "--before", "2024-01-01"
        )
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

    def test_after_without_run_id_error(self) -> None:
        """--after with a path but no --run-id fails and mentions --run-id."""
        outcome = self._invoke_delete(
            "tinker://run-1/weights/0001", "--after", "2024-01-01"
        )
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)
class TestParsedCheckpointTinkerPath:
    """Check ParsedCheckpointTinkerPath.from_tinker_path() on valid and invalid paths."""

    @staticmethod
    def _parser():
        """Import lazily and return the ``from_tinker_path`` constructor."""
        from tinker.types.checkpoint import ParsedCheckpointTinkerPath

        return ParsedCheckpointTinkerPath.from_tinker_path

    def test_standard_format(self) -> None:
        """A plain run id with a ``weights`` segment parses as a training checkpoint."""
        parse = self._parser()
        parsed = parse("tinker://run-id/weights/0001")
        assert parsed.training_run_id == "run-id"
        assert parsed.checkpoint_type == "training"
        assert parsed.checkpoint_id == "weights/0001"

    def test_sampler_format(self) -> None:
        """A ``sampler_weights`` segment parses as a sampler checkpoint."""
        parse = self._parser()
        parsed = parse("tinker://run-id/sampler_weights/0001")
        assert parsed.training_run_id == "run-id"
        assert parsed.checkpoint_type == "sampler"
        assert parsed.checkpoint_id == "sampler_weights/0001"

    def test_with_train_suffix(self) -> None:
        """A run id ending in ``:train:0`` is kept whole as the training run id."""
        parse = self._parser()
        path = (
            "tinker://5f2d7413-3980-502a-b012-9b7e122b3305:train:0/sampler_weights/final"
        )
        parsed = parse(path)
        assert parsed.training_run_id == "5f2d7413-3980-502a-b012-9b7e122b3305:train:0"
        assert parsed.checkpoint_type == "sampler"
        assert parsed.checkpoint_id == "sampler_weights/final"

    def test_with_sampler_suffix(self) -> None:
        """A ``:sampler`` suffix stays in the run id; the type comes from ``weights``."""
        parse = self._parser()
        parsed = parse("tinker://run-id:sampler/weights/0001")
        assert parsed.training_run_id == "run-id:sampler"
        assert parsed.checkpoint_type == "training"
        assert parsed.checkpoint_id == "weights/0001"

    def test_invalid_missing_prefix(self) -> None:
        """A path without the ``tinker://`` prefix raises ValueError."""
        parse = self._parser()
        with pytest.raises(ValueError, match="Invalid tinker path"):
            parse("run-id/weights/0001")

    def test_invalid_wrong_checkpoint_type(self) -> None:
        """An unknown checkpoint-type segment raises ValueError."""
        parse = self._parser()
        with pytest.raises(ValueError, match="Invalid checkpoint type"):
            parse("tinker://run-id/invalid/0001")

    def test_invalid_not_enough_parts(self) -> None:
        """A path with no checkpoint name after the type raises ValueError."""
        parse = self._parser()
        with pytest.raises(ValueError, match="Invalid tinker path"):
            parse("tinker://run-id/weights")