diff --git a/.sync_state b/.sync_state index 847d303..0712e7f 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "e99a92629bc1b7df95b2ffc0b60a60c8eecddee1", - "last_sync_time": "2026-01-21T23:13:02.930293" + "last_synced_sha": "9ee57ec63f1159d18309fabc51bead0105d52917", + "last_sync_time": "2026-01-25T05:52:42.500531" } \ No newline at end of file diff --git a/docs/api/samplingclient.md b/docs/api/samplingclient.md index 5cf9d71..080fd1a 100644 --- a/docs/api/samplingclient.md +++ b/docs/api/samplingclient.md @@ -110,3 +110,14 @@ async def compute_logprobs_async( ``` Async version of compute_logprobs. + +#### `get_tokenizer` + +```python +def get_tokenizer() -> PreTrainedTokenizer +``` + +Get the tokenizer for the current model. + +Returns: +- `PreTrainedTokenizer` compatible with the model diff --git a/docs/api/trainingclient.md b/docs/api/trainingclient.md index 487f904..564f113 100644 --- a/docs/api/trainingclient.md +++ b/docs/api/trainingclient.md @@ -176,6 +176,11 @@ def optim_step( Update model parameters using Adam optimizer. +The Adam optimizer used by tinker is identical +to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html). +Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay). + + Args: - `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay) @@ -212,13 +217,17 @@ Async version of optim_step. #### `save_state` ```python -def save_state(name: str) -> APIFuture[types.SaveWeightsResponse] +def save_state( + name: str, + ttl_seconds: int | None = None +) -> APIFuture[types.SaveWeightsResponse] ``` Save model weights to persistent storage. Args: - `name`: Name for the saved checkpoint +- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires) Returns: - `APIFuture` containing the save response with checkpoint path @@ -234,7 +243,10 @@ print(f"Saved to: {result.path}") #### `save_state_async` ```python -async def save_state_async(name: str) -> APIFuture[types.SaveWeightsResponse] +async def save_state_async( + name: str, + ttl_seconds: int | None = None +) -> APIFuture[types.SaveWeightsResponse] ``` Async version of save_state. @@ -310,13 +322,16 @@ Async version of load_state_with_optimizer. ```python def save_weights_for_sampler( - name: str) -> APIFuture[types.SaveWeightsForSamplerResponse] + name: str, + ttl_seconds: int | None = None +) -> APIFuture[types.SaveWeightsForSamplerResponse] ``` Save model weights for use with a SamplingClient. Args: - `name`: Name for the saved sampler weights +- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires) Returns: - `APIFuture` containing the save response with sampler path @@ -338,7 +353,9 @@ sampling_client = service_client.create_sampling_client( ```python async def save_weights_for_sampler_async( - name: str) -> APIFuture[types.SaveWeightsForSamplerResponse] + name: str, + ttl_seconds: int | None = None +) -> APIFuture[types.SaveWeightsForSamplerResponse] ``` Async version of save_weights_for_sampler. diff --git a/docs/api/types.md b/docs/api/types.md index 61b631d..5e2aa5e 100644 --- a/docs/api/types.md +++ b/docs/api/types.md @@ -26,7 +26,7 @@ Weight decay for the optimizer. Uses decoupled weight decay. #### `grad_clip_norm` -Gradient clip norm for the optimizer. 0.0 means no clipping. +Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping. ## `SupportedModel` Objects @@ -159,6 +159,10 @@ The size of the checkpoint in bytes Whether the checkpoint is publicly accessible +#### `expires_at` + +When this checkpoint expires (None = never expires) + ## `ParsedCheckpointTinkerPath` Objects ```python @@ -725,6 +729,10 @@ class SaveWeightsRequest(StrictBase) A file/directory name for the weights +#### `ttl_seconds` + +TTL in seconds for this checkpoint (None = never expires) + ## `LoraConfig` Objects ```python @@ -834,6 +842,10 @@ class SaveWeightsForSamplerRequest(StrictBase) A file/directory name for the weights +#### `ttl_seconds` + +TTL in seconds for this checkpoint (None = never expires) + ## `SamplingParams` Objects ```python diff --git a/pyproject.toml b/pyproject.toml index 6708dcc..a2c2f9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinker" -version = "0.8.1" +version = "0.9.0" description = "The official Python SDK for the tinker API" readme = "README.md" license = "Apache-2.0" diff --git a/src/tinker/cli/commands/checkpoint.py b/src/tinker/cli/commands/checkpoint.py index a57fe58..6ed948a 100644 --- a/src/tinker/cli/commands/checkpoint.py +++ b/src/tinker/cli/commands/checkpoint.py @@ -88,7 +88,7 @@ class CheckpointListOutput(OutputBase): def get_table_columns(self) -> List[str]: """Return column headers for table output.""" - return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Path"] + return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Expires", "Path"] def get_table_rows(self) -> List[List[str]]: """Return rows for table output.""" @@ -103,6 +103,7 @@ class CheckpointListOutput(OutputBase): else "N/A", format_bool(ckpt.public), format_timestamp(ckpt.time), + format_timestamp(ckpt.expires_at) if ckpt.expires_at else "Never", ckpt.tinker_path, ] ) @@ -161,6 +162,12 @@ class CheckpointInfoOutput(OutputBase): # Creation time rows.append(["Created", format_timestamp(self.checkpoint.time)]) + # Expiration time + if self.checkpoint.expires_at: + rows.append(["Expires", format_timestamp(self.checkpoint.expires_at)]) + else: + rows.append(["Expires", "Never"]) + # Parse training run ID from path if self.checkpoint.tinker_path.startswith("tinker://"): parts = self.checkpoint.tinker_path.replace("tinker://", "").split("/") diff --git a/src/tinker/cli/output.py b/src/tinker/cli/output.py index 9f63795..54c22bb 100644 --- a/src/tinker/cli/output.py +++ b/src/tinker/cli/output.py @@ -177,22 +177,33 @@ def format_timestamp(dt: Union[datetime, str, None]) -> str: # Calculate time difference delta = now - dt - # Format based on age + # Handle future times (negative delta means dt is in the future) + is_future = delta.total_seconds() < 0 + if is_future: + delta = dt - now + suffix = "" + prefix = "in " + else: + suffix = " ago" + prefix = "" + + # Format based on magnitude + total_seconds = int(delta.total_seconds()) if delta.days > 30: return dt.strftime("%Y-%m-%d") elif delta.days > 7: weeks = delta.days // 7 - return f"{weeks} week{'s' if weeks > 1 else ''} ago" + return f"{prefix}{weeks} week{'s' if weeks > 1 else ''}{suffix}" elif delta.days > 0: - return f"{delta.days} day{'s' if delta.days > 1 else ''} ago" - elif delta.seconds > 3600: - hours = delta.seconds // 3600 - return f"{hours} hour{'s' if hours > 1 else ''} ago" - elif delta.seconds > 60: - minutes = delta.seconds // 60 - return f"{minutes} minute{'s' if minutes > 1 else ''} ago" + return f"{prefix}{delta.days} day{'s' if delta.days > 1 else ''}{suffix}" + elif total_seconds > 3600: + hours = total_seconds // 3600 + return f"{prefix}{hours} hour{'s' if hours > 1 else ''}{suffix}" + elif total_seconds > 60: + minutes = total_seconds // 60 + return f"{prefix}{minutes} minute{'s' if minutes > 1 else ''}{suffix}" else: - return "just now" + return "just now" if not is_future else "in less than a minute" except Exception: # If any error occurs, just return string representation diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index d640f92..d25101e 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -486,11 +486,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): return self.optim_step(adam_params) @capture_exceptions(fatal=True) - def save_state(self, name: str) -> APIFuture[types.SaveWeightsResponse]: + def save_state( + self, name: str, ttl_seconds: int | None = None + ) -> APIFuture[types.SaveWeightsResponse]: """Save model weights to persistent storage. Args: - `name`: Name for the saved checkpoint + - `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires) Returns: - `APIFuture` containing the save response with checkpoint path @@ -514,6 +517,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): model_id=self._guaranteed_model_id(), path=name, seq_id=request_id + 1, + ttl_seconds=ttl_seconds, ) with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: return await client.weights.save( @@ -533,9 +537,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): return self.holder.run_coroutine_threadsafe(_save_state_async()) - async def save_state_async(self, name: str) -> APIFuture[types.SaveWeightsResponse]: + async def save_state_async( + self, name: str, ttl_seconds: int | None = None + ) -> APIFuture[types.SaveWeightsResponse]: """Async version of save_state.""" - return self.save_state(name) + return self.save_state(name, ttl_seconds=ttl_seconds) @capture_exceptions(fatal=True) async def _load_state_impl( @@ -625,7 +631,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): @capture_exceptions(fatal=True) async def _save_weights_for_sampler_impl( - self, request_id: int, name: str | None + self, request_id: int, name: str | None, ttl_seconds: int | None = None ) -> types.SaveWeightsForSamplerResponseInternal: assert asyncio.get_event_loop() == self.holder.get_loop() start_time = time.time() @@ -636,6 +642,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): model_id=self._guaranteed_model_id(), path=name, seq_id=request_id + 1, + ttl_seconds=ttl_seconds, ) else: sampling_session_seq_id = self.holder._sampling_client_counter @@ -644,6 +651,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): model_id=self._guaranteed_model_id(), seq_id=request_id + 1, sampling_session_seq_id=sampling_session_seq_id, + ttl_seconds=ttl_seconds, ) with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: return await client.weights.save_for_sampler( @@ -662,11 +670,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): ) @capture_exceptions(fatal=True) - def save_weights_for_sampler(self, name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]: + def save_weights_for_sampler( + self, name: str, ttl_seconds: int | None = None + ) -> APIFuture[types.SaveWeightsForSamplerResponse]: """Save model weights for use with a SamplingClient. Args: - `name`: Name for the saved sampler weights + - `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires) Returns: - `APIFuture` containing the save response with sampler path @@ -687,17 +698,17 @@ class TrainingClient(TelemetryProvider, QueueStateObserver): request_id = self._get_request_id() async def _save_weights_for_sampler_async(): - result = await self._save_weights_for_sampler_impl(request_id, name) + result = await self._save_weights_for_sampler_impl(request_id, name, ttl_seconds) assert result.path is not None return types.SaveWeightsForSamplerResponse(path=result.path) return self.holder.run_coroutine_threadsafe(_save_weights_for_sampler_async()) async def save_weights_for_sampler_async( - self, name: str + self, name: str, ttl_seconds: int | None = None ) -> APIFuture[types.SaveWeightsForSamplerResponse]: """Async version of save_weights_for_sampler.""" - return self.save_weights_for_sampler(name) + return self.save_weights_for_sampler(name, ttl_seconds=ttl_seconds) def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]: async def _get_info_async(): diff --git a/src/tinker/types/checkpoint.py b/src/tinker/types/checkpoint.py index 70e3973..b6a5630 100644 --- a/src/tinker/types/checkpoint.py +++ b/src/tinker/types/checkpoint.py @@ -27,6 +27,9 @@ class Checkpoint(BaseModel): public: bool = False """Whether the checkpoint is publicly accessible""" + expires_at: datetime | None = None + """When this checkpoint expires (None = never expires)""" + class ParsedCheckpointTinkerPath(BaseModel): tinker_path: str diff --git a/src/tinker/types/save_weights_for_sampler_request.py b/src/tinker/types/save_weights_for_sampler_request.py index 7a27bf1..0df4abd 100644 --- a/src/tinker/types/save_weights_for_sampler_request.py +++ b/src/tinker/types/save_weights_for_sampler_request.py @@ -19,6 +19,9 @@ class SaveWeightsForSamplerRequest(StrictBase): seq_id: Optional[int] = None + ttl_seconds: Optional[int] = None + """TTL in seconds for this checkpoint (None = never expires)""" + type: Literal["save_weights_for_sampler"] = "save_weights_for_sampler" if PYDANTIC_V2: diff --git a/src/tinker/types/save_weights_request.py b/src/tinker/types/save_weights_request.py index b86c51a..107d3e4 100644 --- a/src/tinker/types/save_weights_request.py +++ b/src/tinker/types/save_weights_request.py @@ -17,6 +17,9 @@ class SaveWeightsRequest(StrictBase): seq_id: Optional[int] = None + ttl_seconds: Optional[int] = None + """TTL in seconds for this checkpoint (None = never expires)""" + type: Literal["save_weights"] = "save_weights" if PYDANTIC_V2: diff --git a/tests/test_cli_output.py b/tests/test_cli_output.py new file mode 100644 index 0000000..21c13d5 --- /dev/null +++ b/tests/test_cli_output.py @@ -0,0 +1,253 @@ +"""Tests for CLI output formatting utilities.""" + +import re +from datetime import datetime, timedelta, timezone + +from tinker.cli.output import format_bool, format_size, format_timestamp + + +class TestFormatTimestamp: + """Tests for the format_timestamp function.""" + + def test_none_returns_na(self) -> None: + assert format_timestamp(None) == "N/A" + + def test_empty_string_returns_na(self) -> None: + assert format_timestamp("") == "N/A" + + def test_just_now_past(self) -> None: + """Times within the last minute should show 'just now'.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(seconds=30) + assert format_timestamp(dt) == "just now" + + def test_just_now_future(self) -> None: + """Times within the next minute should show 'in less than a minute'.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(seconds=30) + assert format_timestamp(dt) == "in less than a minute" + + def test_minutes_ago(self) -> None: + """Times a few minutes in the past.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(minutes=5, seconds=30) + result = format_timestamp(dt) + # Allow for slight timing variations (4-5 minutes) + assert re.match(r"[45] minutes ago", result), ( + f"Expected '4 minutes ago' or '5 minutes ago', got '{result}'" + ) + + def test_minutes_future(self) -> None: + """Times a few minutes in the future.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(minutes=5, seconds=30) + result = format_timestamp(dt) + # Allow for slight timing variations (4-5 minutes) + assert re.match(r"in [45] minutes", result), ( + f"Expected 'in 4 minutes' or 'in 5 minutes', got '{result}'" + ) + + def test_one_minute_ago(self) -> None: + """Singular 'minute' for exactly 1 minute.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(minutes=1, seconds=30) + assert format_timestamp(dt) == "1 minute ago" + + def test_one_minute_future(self) -> None: + """Singular 'minute' for exactly 1 minute in the future.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(minutes=1, seconds=30) + assert format_timestamp(dt) == "in 1 minute" + + def test_hours_ago(self) -> None: + """Times a few hours in the past.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(hours=3, minutes=30) + result = format_timestamp(dt) + # Allow for slight timing variations (2-3 hours) + assert re.match(r"[23] hours ago", result), ( + f"Expected '2 hours ago' or '3 hours ago', got '{result}'" + ) + + def test_hours_future(self) -> None: + """Times a few hours in the future.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(hours=3, minutes=30) + result = format_timestamp(dt) + # Allow for slight timing variations (2-3 hours) + assert re.match(r"in [23] hours", result), ( + f"Expected 'in 2 hours' or 'in 3 hours', got '{result}'" + ) + + def test_one_hour_ago(self) -> None: + """Singular 'hour' for exactly 1 hour.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(hours=1, minutes=30) + assert format_timestamp(dt) == "1 hour ago" + + def test_one_hour_future(self) -> None: + """Singular 'hour' for exactly 1 hour in the future.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(hours=1, minutes=30) + assert format_timestamp(dt) == "in 1 hour" + + def test_days_ago(self) -> None: + """Times a few days in the past.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(days=3, hours=12) + result = format_timestamp(dt) + # Allow for slight timing variations (2-3 days) + assert re.match(r"[23] days ago", result), ( + f"Expected '2 days ago' or '3 days ago', got '{result}'" + ) + + def test_days_future(self) -> None: + """Times a few days in the future.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(days=3, hours=12) + result = format_timestamp(dt) + # Allow for slight timing variations (2-3 days) + assert re.match(r"in [23] days", result), ( + f"Expected 'in 2 days' or 'in 3 days', got '{result}'" + ) + + def test_one_day_ago(self) -> None: + """Singular 'day' for exactly 1 day.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(days=1, hours=12) + assert format_timestamp(dt) == "1 day ago" + + def test_one_day_future(self) -> None: + """Singular 'day' for exactly 1 day in the future.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(days=1, hours=12) + assert format_timestamp(dt) == "in 1 day" + + def test_weeks_ago(self) -> None: + """Times a few weeks in the past.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(weeks=2, days=3) + result = format_timestamp(dt) + # Allow for slight timing variations (1-2 weeks) + assert re.match(r"[12] weeks? ago", result), ( + f"Expected '1 week ago' or '2 weeks ago', got '{result}'" + ) + + def test_weeks_future(self) -> None: + """Times a few weeks in the future.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(weeks=2, days=3) + result = format_timestamp(dt) + # Allow for slight timing variations (1-2 weeks) + assert re.match(r"in [12] weeks?", result), ( + f"Expected 'in 1 week' or 'in 2 weeks', got '{result}'" + ) + + def test_one_week_ago(self) -> None: + """Singular 'week' for exactly 1 week.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(weeks=1, days=3) + assert format_timestamp(dt) == "1 week ago" + + def test_one_week_future(self) -> None: + """Singular 'week' for exactly 1 week in the future.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(weeks=1, days=3) + assert format_timestamp(dt) == "in 1 week" + + def test_old_date_shows_absolute(self) -> None: + """Dates more than 30 days ago show absolute date.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(days=45) + result = format_timestamp(dt) + # Should be in YYYY-MM-DD format + assert result == dt.strftime("%Y-%m-%d") + + def test_far_future_date_shows_absolute(self) -> None: + """Dates more than 30 days in future show absolute date.""" + now = datetime.now(timezone.utc) + dt = now + timedelta(days=45) + result = format_timestamp(dt) + # Should be in YYYY-MM-DD format + assert result == dt.strftime("%Y-%m-%d") + + def test_iso_string_input(self) -> None: + """ISO format strings are parsed correctly.""" + # Create a time 5 minutes ago + now = datetime.now(timezone.utc) + dt = now - timedelta(minutes=5) + iso_str = dt.isoformat() + result = format_timestamp(iso_str) + assert "minute" in result + + def test_iso_string_with_z_suffix(self) -> None: + """ISO strings with Z suffix are parsed correctly.""" + now = datetime.now(timezone.utc) + dt = now - timedelta(hours=2) + # Replace +00:00 with Z + iso_str = dt.strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z" + result = format_timestamp(iso_str) + assert "hour" in result + + def test_naive_datetime_treated_as_utc(self) -> None: + """Naive datetimes (no timezone) are treated as UTC.""" + now = datetime.now(timezone.utc) + # Create naive datetime + naive_dt = (now - timedelta(minutes=10)).replace(tzinfo=None) + result = format_timestamp(naive_dt) + assert "minute" in result + + def test_non_utc_timezone(self) -> None: + """Datetimes with non-UTC timezone are converted properly.""" + # Create a timezone +5 hours from UTC + tz_plus5 = timezone(timedelta(hours=5)) + now_utc = datetime.now(timezone.utc) + # Create time 2 hours ago in UTC, but expressed in +5 timezone + dt = (now_utc - timedelta(hours=2)).astimezone(tz_plus5) + result = format_timestamp(dt) + assert "hour" in result + + def test_invalid_string_returns_string(self) -> None: + """Invalid datetime strings are returned as-is.""" + result = format_timestamp("not a date") + assert result == "not a date" + + def test_non_datetime_object_returns_string(self) -> None: + """Non-datetime objects are converted to string.""" + result = format_timestamp(12345) # type: ignore + assert result == "12345" + + +class TestFormatSize: + """Tests for the format_size function.""" + + def test_bytes(self) -> None: + assert format_size(500) == "500 B" + + def test_kilobytes(self) -> None: + assert format_size(1536) == "1.5 KB" + + def test_megabytes(self) -> None: + assert format_size(1572864) == "1.5 MB" + + def test_gigabytes(self) -> None: + assert format_size(1610612736) == "1.5 GB" + + def test_terabytes(self) -> None: + assert format_size(1649267441664) == "1.5 TB" + + def test_zero_bytes(self) -> None: + assert format_size(0) == "0 B" + + def test_negative_returns_na(self) -> None: + assert format_size(-1) == "N/A" + + +class TestFormatBool: + """Tests for the format_bool function.""" + + def test_true(self) -> None: + assert format_bool(True) == "Yes" + + def test_false(self) -> None: + assert format_bool(False) == "No"