mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Fixes Issue #24: 'tinker checkpoint delete' now accepts paths with a run-id suffix (e.g., tinker://run-id:train:0/sampler_weights/final). The previous implementation assumed a strict 3-part path format, which failed when the run-id contained a suffix like ':train:0'. This change updates ParsedCheckpointTinkerPath.from_tinker_path() to: - Handle run-ids with arbitrary suffixes - Provide better error messages for invalid formats Good day, Thank you for your work on this excellent library! Warmly, RoomWithOutRoof
258 lines
9.8 KiB
Python
258 lines
9.8 KiB
Python
"""Tests for bulk checkpoint deletion: CLI flags, date parsing, and tinker path parsing."""
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytest
|
|
from click.testing import CliRunner
|
|
|
|
from tinker.cli.commands.checkpoint import _filter_checkpoints, _parse_date
|
|
|
|
|
|
class TestParseDate:
    """Checks ``_parse_date`` against ISO 8601 strings and malformed input."""

    def test_date_only(self) -> None:
        """A bare ``YYYY-MM-DD`` value parses and comes back timezone-aware."""
        parsed = _parse_date("2024-01-15")
        assert (parsed.year, parsed.month, parsed.day) == (2024, 1, 15)
        assert parsed.tzinfo is not None

    def test_datetime_with_z(self) -> None:
        """A timestamp carrying the ``Z`` designator is accepted."""
        parsed = _parse_date("2024-06-01T12:00:00Z")
        assert (parsed.year, parsed.month, parsed.hour) == (2024, 6, 12)
        assert parsed.tzinfo is not None

    def test_datetime_with_offset(self) -> None:
        """A timestamp carrying an explicit ``+00:00`` offset is accepted."""
        parsed = _parse_date("2024-06-01T12:00:00+00:00")
        assert parsed.year == 2024
        assert parsed.tzinfo is not None

    def test_datetime_naive_gets_utc(self) -> None:
        """A timestamp without any zone information still yields an aware value."""
        parsed = _parse_date("2024-06-01T12:00:00")
        # NOTE(review): this only proves the result is timezone-aware; it does
        # not pin the attached zone to UTC despite the test name.
        assert parsed.tzinfo is not None

    def test_whitespace_stripped(self) -> None:
        """Padding around the value does not prevent parsing."""
        padded = " 2024-01-01 "
        assert _parse_date(padded).year == 2024

    def test_invalid_raises(self) -> None:
        """Text that is not a date at all surfaces as ``TinkerCliError``."""
        from tinker.cli.exceptions import TinkerCliError

        garbage = "not-a-date"
        with pytest.raises(TinkerCliError):
            _parse_date(garbage)

    def test_invalid_format_raises(self) -> None:
        """A US-style ``MM/DD/YYYY`` date surfaces as ``TinkerCliError``."""
        from tinker.cli.exceptions import TinkerCliError

        us_style = "01/15/2024"
        with pytest.raises(TinkerCliError):
            _parse_date(us_style)
|
|
|
|
|
|
class TestFilterCheckpoints:
    """Checks ``_filter_checkpoints`` with type, before and after filters.

    Every call passes its arguments positionally as
    ``(checkpoints, type, before, after)``; that ordering is how these tests
    use the function, not something its signature is shown to guarantee here.
    """

    @pytest.fixture()
    def sample_checkpoints(self):
        """Build four checkpoints for ``run-1`` with ages relative to "now".

        Three are training checkpoints (10 days, 3 days and 1 hour old) and
        one is a sampler checkpoint (10 days old).
        """
        from tinker.types.checkpoint import Checkpoint

        reference = datetime.now(UTC)
        # (checkpoint_id, checkpoint_type, age, size_bytes)
        specs = (
            ("weights/0001", "training", timedelta(days=10), 1000),
            ("weights/0002", "training", timedelta(days=3), 2000),
            ("sampler_weights/0001", "sampler", timedelta(days=10), 500),
            ("weights/0003", "training", timedelta(hours=1), 3000),
        )
        return [
            Checkpoint(
                checkpoint_id=ckpt_id,
                checkpoint_type=kind,
                time=reference - age,
                tinker_path=f"tinker://run-1/{ckpt_id}",
                size_bytes=size,
            )
            for ckpt_id, kind, age, size in specs
        ]

    def test_no_filters(self, sample_checkpoints) -> None:
        """With every filter unset, all four checkpoints come back."""
        kept = _filter_checkpoints(sample_checkpoints, None, None, None)
        assert len(kept) == 4

    def test_filter_by_weights_type(self, sample_checkpoints) -> None:
        """The ``weights`` type selects the three training checkpoints."""
        kept = _filter_checkpoints(sample_checkpoints, "weights", None, None)
        assert len(kept) == 3
        for ckpt in kept:
            assert ckpt.checkpoint_type == "training"

    def test_filter_by_sampler_weights_type(self, sample_checkpoints) -> None:
        """The ``sampler_weights`` type selects the single sampler checkpoint."""
        kept = _filter_checkpoints(sample_checkpoints, "sampler_weights", None, None)
        assert len(kept) == 1
        (only,) = kept
        assert only.checkpoint_type == "sampler"

    def test_filter_before(self, sample_checkpoints) -> None:
        """A 7-day-old ``before`` cutoff keeps just the two 10-day-old entries."""
        week_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, None, week_ago, None)
        assert len(kept) == 2
        for ckpt in kept:
            assert "0001" in ckpt.checkpoint_id

    def test_filter_after(self, sample_checkpoints) -> None:
        """A 7-day-old ``after`` cutoff keeps the 3-day-old and 1-hour-old entries."""
        week_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, None, None, week_ago)
        assert len(kept) == 2
        kept_paths = {ckpt.tinker_path for ckpt in kept}
        assert "tinker://run-1/weights/0002" in kept_paths
        assert "tinker://run-1/weights/0003" in kept_paths

    def test_filter_date_range(self, sample_checkpoints) -> None:
        """A window from 5 to 2 days ago keeps only the 3-day-old checkpoint."""
        window_start = datetime.now(UTC) - timedelta(days=5)
        window_end = datetime.now(UTC) - timedelta(days=2)
        kept = _filter_checkpoints(sample_checkpoints, None, window_end, window_start)
        assert len(kept) == 1
        (only,) = kept
        assert only.tinker_path == "tinker://run-1/weights/0002"

    def test_filter_by_type_and_before(self, sample_checkpoints) -> None:
        """Type and ``before`` combine: only the old training checkpoint remains."""
        week_ago = datetime.now(UTC) - timedelta(days=7)
        kept = _filter_checkpoints(sample_checkpoints, "weights", week_ago, None)
        assert len(kept) == 1
        (only,) = kept
        assert only.tinker_path == "tinker://run-1/weights/0001"

    def test_invalid_type_raises(self, sample_checkpoints) -> None:
        """An unrecognised type name surfaces as ``TinkerCliError``."""
        from tinker.cli.exceptions import TinkerCliError

        bogus_type = "invalid_type"
        with pytest.raises(TinkerCliError):
            _filter_checkpoints(sample_checkpoints, bogus_type, None, None)
|
|
|
|
|
|
class TestDeleteCLIValidation:
    """Checks that ``delete`` rejects missing or conflicting arguments."""

    def _get_error_message(self, result) -> str:
        """Return the captured output, else the stringified exception, else ``""``."""
        if result.output:
            return result.output
        return str(result.exception) if result.exception else ""

    def _invoke_delete(self, *extra_args: str):
        """Run ``delete`` with ``extra_args`` through a fresh ``CliRunner``."""
        # Imported lazily, as each test did, so collecting this module does not
        # require the CLI module to import.
        from tinker.cli.commands.checkpoint import cli

        return CliRunner().invoke(cli, ["delete", *extra_args])

    def test_no_args_shows_error(self) -> None:
        """Calling ``delete`` with nothing to delete exits non-zero."""
        outcome = self._invoke_delete()
        assert outcome.exit_code != 0

    def test_paths_and_run_id_conflict(self) -> None:
        """An explicit path together with ``--run-id`` is rejected."""
        outcome = self._invoke_delete("tinker://run-1/weights/0001", "--run-id", "run-1")
        assert outcome.exit_code != 0
        assert "Cannot specify both" in self._get_error_message(outcome)

    def test_type_without_run_id_error(self) -> None:
        """``--type`` without ``--run-id`` fails with a message naming ``--run-id``."""
        outcome = self._invoke_delete("tinker://run-1/weights/0001", "--type", "weights")
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

    def test_before_without_run_id_error(self) -> None:
        """``--before`` without ``--run-id`` fails with a message naming ``--run-id``."""
        outcome = self._invoke_delete(
            "tinker://run-1/weights/0001", "--before", "2024-01-01"
        )
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

    def test_after_without_run_id_error(self) -> None:
        """``--after`` without ``--run-id`` fails with a message naming ``--run-id``."""
        outcome = self._invoke_delete(
            "tinker://run-1/weights/0001", "--after", "2024-01-01"
        )
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)
|
|
|
|
|
|
class TestParsedCheckpointTinkerPath:
    """Checks ``ParsedCheckpointTinkerPath.from_tinker_path()`` on good and bad paths."""

    @staticmethod
    def _parser():
        """Return the ``from_tinker_path`` constructor under test."""
        # Imported lazily, as each test did, and fetched outside any
        # ``pytest.raises`` block so an import failure cannot be mistaken for
        # the expected ``ValueError``.
        from tinker.types.checkpoint import ParsedCheckpointTinkerPath

        return ParsedCheckpointTinkerPath.from_tinker_path

    def test_standard_format(self) -> None:
        """A plain ``run-id/weights/<name>`` path maps to a training checkpoint."""
        parse = self._parser()
        parsed = parse("tinker://run-id/weights/0001")
        assert parsed.training_run_id == "run-id"
        assert parsed.checkpoint_type == "training"
        assert parsed.checkpoint_id == "weights/0001"

    def test_sampler_format(self) -> None:
        """A ``sampler_weights`` path maps to a sampler checkpoint."""
        parse = self._parser()
        parsed = parse("tinker://run-id/sampler_weights/0001")
        assert parsed.training_run_id == "run-id"
        assert parsed.checkpoint_type == "sampler"
        assert parsed.checkpoint_id == "sampler_weights/0001"

    def test_with_train_suffix(self) -> None:
        """A run id ending in ``:train:0`` is kept intact as the run id."""
        parse = self._parser()
        run_id = "5f2d7413-3980-502a-b012-9b7e122b3305:train:0"
        parsed = parse(
            "tinker://5f2d7413-3980-502a-b012-9b7e122b3305:train:0/sampler_weights/final"
        )
        assert parsed.training_run_id == run_id
        assert parsed.checkpoint_type == "sampler"
        assert parsed.checkpoint_id == "sampler_weights/final"

    def test_with_sampler_suffix(self) -> None:
        """A run id ending in ``:sampler`` is kept intact as the run id."""
        parse = self._parser()
        parsed = parse("tinker://run-id:sampler/weights/0001")
        assert parsed.training_run_id == "run-id:sampler"
        assert parsed.checkpoint_type == "training"
        assert parsed.checkpoint_id == "weights/0001"

    def test_invalid_missing_prefix(self) -> None:
        """A path lacking the ``tinker://`` scheme is rejected."""
        parse = self._parser()
        with pytest.raises(ValueError, match="Invalid tinker path"):
            parse("run-id/weights/0001")

    def test_invalid_wrong_checkpoint_type(self) -> None:
        """An unknown checkpoint-type segment is rejected."""
        parse = self._parser()
        with pytest.raises(ValueError, match="Invalid checkpoint type"):
            parse("tinker://run-id/invalid/0001")

    def test_invalid_not_enough_parts(self) -> None:
        """A path with no checkpoint name after the type segment is rejected."""
        parse = self._parser()
        with pytest.raises(ValueError, match="Invalid tinker path"):
            parse("tinker://run-id/weights")
|