mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
ad03d44978
commit
ca40e08bb4
12 changed files with 358 additions and 27 deletions
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "e99a92629bc1b7df95b2ffc0b60a60c8eecddee1",
|
||||
"last_sync_time": "2026-01-21T23:13:02.930293"
|
||||
"last_synced_sha": "9ee57ec63f1159d18309fabc51bead0105d52917",
|
||||
"last_sync_time": "2026-01-25T05:52:42.500531"
|
||||
}
|
||||
|
|
@ -110,3 +110,14 @@ async def compute_logprobs_async(
|
|||
```
|
||||
|
||||
Async version of compute_logprobs.
|
||||
|
||||
#### `get_tokenizer`
|
||||
|
||||
```python
|
||||
def get_tokenizer() -> PreTrainedTokenizer
|
||||
```
|
||||
|
||||
Get the tokenizer for the current model.
|
||||
|
||||
Returns:
|
||||
- `PreTrainedTokenizer` compatible with the model
|
||||
|
|
|
|||
|
|
@ -176,6 +176,11 @@ def optim_step(
|
|||
|
||||
Update model parameters using Adam optimizer.
|
||||
|
||||
The Adam optimizer used by tinker is identical
|
||||
to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
|
||||
|
||||
|
||||
Args:
|
||||
- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
|
||||
|
||||
|
|
@ -212,13 +217,17 @@ Async version of optim_step.
|
|||
#### `save_state`
|
||||
|
||||
```python
|
||||
def save_state(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
def save_state(
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
Save model weights to persistent storage.
|
||||
|
||||
Args:
|
||||
- `name`: Name for the saved checkpoint
|
||||
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the save response with checkpoint path
|
||||
|
|
@ -234,7 +243,10 @@ print(f"Saved to: {result.path}")
|
|||
#### `save_state_async`
|
||||
|
||||
```python
|
||||
async def save_state_async(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
async def save_state_async(
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
Async version of save_state.
|
||||
|
|
@ -310,13 +322,16 @@ Async version of load_state_with_optimizer.
|
|||
|
||||
```python
|
||||
def save_weights_for_sampler(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
||||
Save model weights for use with a SamplingClient.
|
||||
|
||||
Args:
|
||||
- `name`: Name for the saved sampler weights
|
||||
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the save response with sampler path
|
||||
|
|
@ -338,7 +353,9 @@ sampling_client = service_client.create_sampling_client(
|
|||
|
||||
```python
|
||||
async def save_weights_for_sampler_async(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
||||
Async version of save_weights_for_sampler.
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ Weight decay for the optimizer. Uses decoupled weight decay.
|
|||
|
||||
#### `grad_clip_norm`
|
||||
|
||||
Gradient clip norm for the optimizer. 0.0 means no clipping.
|
||||
Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping.
|
||||
|
||||
## `SupportedModel` Objects
|
||||
|
||||
|
|
@ -159,6 +159,10 @@ The size of the checkpoint in bytes
|
|||
|
||||
Whether the checkpoint is publicly accessible
|
||||
|
||||
#### `expires_at`
|
||||
|
||||
When this checkpoint expires (None = never expires)
|
||||
|
||||
## `ParsedCheckpointTinkerPath` Objects
|
||||
|
||||
```python
|
||||
|
|
@ -725,6 +729,10 @@ class SaveWeightsRequest(StrictBase)
|
|||
|
||||
A file/directory name for the weights
|
||||
|
||||
#### `ttl_seconds`
|
||||
|
||||
TTL in seconds for this checkpoint (None = never expires)
|
||||
|
||||
## `LoraConfig` Objects
|
||||
|
||||
```python
|
||||
|
|
@ -834,6 +842,10 @@ class SaveWeightsForSamplerRequest(StrictBase)
|
|||
|
||||
A file/directory name for the weights
|
||||
|
||||
#### `ttl_seconds`
|
||||
|
||||
TTL in seconds for this checkpoint (None = never expires)
|
||||
|
||||
## `SamplingParams` Objects
|
||||
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.8.1"
|
||||
version = "0.9.0"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ class CheckpointListOutput(OutputBase):
|
|||
|
||||
def get_table_columns(self) -> List[str]:
|
||||
"""Return column headers for table output."""
|
||||
return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Path"]
|
||||
return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Expires", "Path"]
|
||||
|
||||
def get_table_rows(self) -> List[List[str]]:
|
||||
"""Return rows for table output."""
|
||||
|
|
@ -103,6 +103,7 @@ class CheckpointListOutput(OutputBase):
|
|||
else "N/A",
|
||||
format_bool(ckpt.public),
|
||||
format_timestamp(ckpt.time),
|
||||
format_timestamp(ckpt.expires_at) if ckpt.expires_at else "Never",
|
||||
ckpt.tinker_path,
|
||||
]
|
||||
)
|
||||
|
|
@ -161,6 +162,12 @@ class CheckpointInfoOutput(OutputBase):
|
|||
# Creation time
|
||||
rows.append(["Created", format_timestamp(self.checkpoint.time)])
|
||||
|
||||
# Expiration time
|
||||
if self.checkpoint.expires_at:
|
||||
rows.append(["Expires", format_timestamp(self.checkpoint.expires_at)])
|
||||
else:
|
||||
rows.append(["Expires", "Never"])
|
||||
|
||||
# Parse training run ID from path
|
||||
if self.checkpoint.tinker_path.startswith("tinker://"):
|
||||
parts = self.checkpoint.tinker_path.replace("tinker://", "").split("/")
|
||||
|
|
|
|||
|
|
@ -177,22 +177,33 @@ def format_timestamp(dt: Union[datetime, str, None]) -> str:
|
|||
# Calculate time difference
|
||||
delta = now - dt
|
||||
|
||||
# Format based on age
|
||||
# Handle future times (negative delta means dt is in the future)
|
||||
is_future = delta.total_seconds() < 0
|
||||
if is_future:
|
||||
delta = dt - now
|
||||
suffix = ""
|
||||
prefix = "in "
|
||||
else:
|
||||
suffix = " ago"
|
||||
prefix = ""
|
||||
|
||||
# Format based on magnitude
|
||||
total_seconds = int(delta.total_seconds())
|
||||
if delta.days > 30:
|
||||
return dt.strftime("%Y-%m-%d")
|
||||
elif delta.days > 7:
|
||||
weeks = delta.days // 7
|
||||
return f"{weeks} week{'s' if weeks > 1 else ''} ago"
|
||||
return f"{prefix}{weeks} week{'s' if weeks > 1 else ''}{suffix}"
|
||||
elif delta.days > 0:
|
||||
return f"{delta.days} day{'s' if delta.days > 1 else ''} ago"
|
||||
elif delta.seconds > 3600:
|
||||
hours = delta.seconds // 3600
|
||||
return f"{hours} hour{'s' if hours > 1 else ''} ago"
|
||||
elif delta.seconds > 60:
|
||||
minutes = delta.seconds // 60
|
||||
return f"{minutes} minute{'s' if minutes > 1 else ''} ago"
|
||||
return f"{prefix}{delta.days} day{'s' if delta.days > 1 else ''}{suffix}"
|
||||
elif total_seconds > 3600:
|
||||
hours = total_seconds // 3600
|
||||
return f"{prefix}{hours} hour{'s' if hours > 1 else ''}{suffix}"
|
||||
elif total_seconds > 60:
|
||||
minutes = total_seconds // 60
|
||||
return f"{prefix}{minutes} minute{'s' if minutes > 1 else ''}{suffix}"
|
||||
else:
|
||||
return "just now"
|
||||
return "just now" if not is_future else "in less than a minute"
|
||||
|
||||
except Exception:
|
||||
# If any error occurs, just return string representation
|
||||
|
|
|
|||
|
|
@ -486,11 +486,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
return self.optim_step(adam_params)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_state(self, name: str) -> APIFuture[types.SaveWeightsResponse]:
|
||||
def save_state(
|
||||
self, name: str, ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsResponse]:
|
||||
"""Save model weights to persistent storage.
|
||||
|
||||
Args:
|
||||
- `name`: Name for the saved checkpoint
|
||||
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the save response with checkpoint path
|
||||
|
|
@ -514,6 +517,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
model_id=self._guaranteed_model_id(),
|
||||
path=name,
|
||||
seq_id=request_id + 1,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.save(
|
||||
|
|
@ -533,9 +537,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
|
||||
return self.holder.run_coroutine_threadsafe(_save_state_async())
|
||||
|
||||
async def save_state_async(self, name: str) -> APIFuture[types.SaveWeightsResponse]:
|
||||
async def save_state_async(
|
||||
self, name: str, ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsResponse]:
|
||||
"""Async version of save_state."""
|
||||
return self.save_state(name)
|
||||
return self.save_state(name, ttl_seconds=ttl_seconds)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _load_state_impl(
|
||||
|
|
@ -625,7 +631,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _save_weights_for_sampler_impl(
|
||||
self, request_id: int, name: str | None
|
||||
self, request_id: int, name: str | None, ttl_seconds: int | None = None
|
||||
) -> types.SaveWeightsForSamplerResponseInternal:
|
||||
assert asyncio.get_event_loop() == self.holder.get_loop()
|
||||
start_time = time.time()
|
||||
|
|
@ -636,6 +642,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
model_id=self._guaranteed_model_id(),
|
||||
path=name,
|
||||
seq_id=request_id + 1,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
else:
|
||||
sampling_session_seq_id = self.holder._sampling_client_counter
|
||||
|
|
@ -644,6 +651,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
model_id=self._guaranteed_model_id(),
|
||||
seq_id=request_id + 1,
|
||||
sampling_session_seq_id=sampling_session_seq_id,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.save_for_sampler(
|
||||
|
|
@ -662,11 +670,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_weights_for_sampler(self, name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
||||
def save_weights_for_sampler(
|
||||
self, name: str, ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
||||
"""Save model weights for use with a SamplingClient.
|
||||
|
||||
Args:
|
||||
- `name`: Name for the saved sampler weights
|
||||
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the save response with sampler path
|
||||
|
|
@ -687,17 +698,17 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
|||
request_id = self._get_request_id()
|
||||
|
||||
async def _save_weights_for_sampler_async():
|
||||
result = await self._save_weights_for_sampler_impl(request_id, name)
|
||||
result = await self._save_weights_for_sampler_impl(request_id, name, ttl_seconds)
|
||||
assert result.path is not None
|
||||
return types.SaveWeightsForSamplerResponse(path=result.path)
|
||||
|
||||
return self.holder.run_coroutine_threadsafe(_save_weights_for_sampler_async())
|
||||
|
||||
async def save_weights_for_sampler_async(
|
||||
self, name: str
|
||||
self, name: str, ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
||||
"""Async version of save_weights_for_sampler."""
|
||||
return self.save_weights_for_sampler(name)
|
||||
return self.save_weights_for_sampler(name, ttl_seconds=ttl_seconds)
|
||||
|
||||
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
|
||||
async def _get_info_async():
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ class Checkpoint(BaseModel):
|
|||
public: bool = False
|
||||
"""Whether the checkpoint is publicly accessible"""
|
||||
|
||||
expires_at: datetime | None = None
|
||||
"""When this checkpoint expires (None = never expires)"""
|
||||
|
||||
|
||||
class ParsedCheckpointTinkerPath(BaseModel):
|
||||
tinker_path: str
|
||||
|
|
|
|||
|
|
@ -19,6 +19,9 @@ class SaveWeightsForSamplerRequest(StrictBase):
|
|||
|
||||
seq_id: Optional[int] = None
|
||||
|
||||
ttl_seconds: Optional[int] = None
|
||||
"""TTL in seconds for this checkpoint (None = never expires)"""
|
||||
|
||||
type: Literal["save_weights_for_sampler"] = "save_weights_for_sampler"
|
||||
|
||||
if PYDANTIC_V2:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,9 @@ class SaveWeightsRequest(StrictBase):
|
|||
|
||||
seq_id: Optional[int] = None
|
||||
|
||||
ttl_seconds: Optional[int] = None
|
||||
"""TTL in seconds for this checkpoint (None = never expires)"""
|
||||
|
||||
type: Literal["save_weights"] = "save_weights"
|
||||
|
||||
if PYDANTIC_V2:
|
||||
|
|
|
|||
253
tests/test_cli_output.py
Normal file
253
tests/test_cli_output.py
Normal file
|
|
@ -0,0 +1,253 @@
|
|||
"""Tests for CLI output formatting utilities."""
|
||||
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from tinker.cli.output import format_bool, format_size, format_timestamp
|
||||
|
||||
|
||||
class TestFormatTimestamp:
|
||||
"""Tests for the format_timestamp function."""
|
||||
|
||||
def test_none_returns_na(self) -> None:
|
||||
assert format_timestamp(None) == "N/A"
|
||||
|
||||
def test_empty_string_returns_na(self) -> None:
|
||||
assert format_timestamp("") == "N/A"
|
||||
|
||||
def test_just_now_past(self) -> None:
|
||||
"""Times within the last minute should show 'just now'."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(seconds=30)
|
||||
assert format_timestamp(dt) == "just now"
|
||||
|
||||
def test_just_now_future(self) -> None:
|
||||
"""Times within the next minute should show 'in less than a minute'."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(seconds=30)
|
||||
assert format_timestamp(dt) == "in less than a minute"
|
||||
|
||||
def test_minutes_ago(self) -> None:
|
||||
"""Times a few minutes in the past."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(minutes=5, seconds=30)
|
||||
result = format_timestamp(dt)
|
||||
# Allow for slight timing variations (4-5 minutes)
|
||||
assert re.match(r"[45] minutes ago", result), (
|
||||
f"Expected '4 minutes ago' or '5 minutes ago', got '{result}'"
|
||||
)
|
||||
|
||||
def test_minutes_future(self) -> None:
|
||||
"""Times a few minutes in the future."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(minutes=5, seconds=30)
|
||||
result = format_timestamp(dt)
|
||||
# Allow for slight timing variations (4-5 minutes)
|
||||
assert re.match(r"in [45] minutes", result), (
|
||||
f"Expected 'in 4 minutes' or 'in 5 minutes', got '{result}'"
|
||||
)
|
||||
|
||||
def test_one_minute_ago(self) -> None:
|
||||
"""Singular 'minute' for exactly 1 minute."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(minutes=1, seconds=30)
|
||||
assert format_timestamp(dt) == "1 minute ago"
|
||||
|
||||
def test_one_minute_future(self) -> None:
|
||||
"""Singular 'minute' for exactly 1 minute in the future."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(minutes=1, seconds=30)
|
||||
assert format_timestamp(dt) == "in 1 minute"
|
||||
|
||||
def test_hours_ago(self) -> None:
|
||||
"""Times a few hours in the past."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(hours=3, minutes=30)
|
||||
result = format_timestamp(dt)
|
||||
# Allow for slight timing variations (2-3 hours)
|
||||
assert re.match(r"[23] hours ago", result), (
|
||||
f"Expected '2 hours ago' or '3 hours ago', got '{result}'"
|
||||
)
|
||||
|
||||
def test_hours_future(self) -> None:
|
||||
"""Times a few hours in the future."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(hours=3, minutes=30)
|
||||
result = format_timestamp(dt)
|
||||
# Allow for slight timing variations (2-3 hours)
|
||||
assert re.match(r"in [23] hours", result), (
|
||||
f"Expected 'in 2 hours' or 'in 3 hours', got '{result}'"
|
||||
)
|
||||
|
||||
def test_one_hour_ago(self) -> None:
|
||||
"""Singular 'hour' for exactly 1 hour."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(hours=1, minutes=30)
|
||||
assert format_timestamp(dt) == "1 hour ago"
|
||||
|
||||
def test_one_hour_future(self) -> None:
|
||||
"""Singular 'hour' for exactly 1 hour in the future."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(hours=1, minutes=30)
|
||||
assert format_timestamp(dt) == "in 1 hour"
|
||||
|
||||
def test_days_ago(self) -> None:
|
||||
"""Times a few days in the past."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(days=3, hours=12)
|
||||
result = format_timestamp(dt)
|
||||
# Allow for slight timing variations (2-3 days)
|
||||
assert re.match(r"[23] days ago", result), (
|
||||
f"Expected '2 days ago' or '3 days ago', got '{result}'"
|
||||
)
|
||||
|
||||
def test_days_future(self) -> None:
|
||||
"""Times a few days in the future."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(days=3, hours=12)
|
||||
result = format_timestamp(dt)
|
||||
# Allow for slight timing variations (2-3 days)
|
||||
assert re.match(r"in [23] days", result), (
|
||||
f"Expected 'in 2 days' or 'in 3 days', got '{result}'"
|
||||
)
|
||||
|
||||
def test_one_day_ago(self) -> None:
|
||||
"""Singular 'day' for exactly 1 day."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(days=1, hours=12)
|
||||
assert format_timestamp(dt) == "1 day ago"
|
||||
|
||||
def test_one_day_future(self) -> None:
|
||||
"""Singular 'day' for exactly 1 day in the future."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(days=1, hours=12)
|
||||
assert format_timestamp(dt) == "in 1 day"
|
||||
|
||||
def test_weeks_ago(self) -> None:
|
||||
"""Times a few weeks in the past."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(weeks=2, days=3)
|
||||
result = format_timestamp(dt)
|
||||
# Allow for slight timing variations (1-2 weeks)
|
||||
assert re.match(r"[12] weeks? ago", result), (
|
||||
f"Expected '1 week ago' or '2 weeks ago', got '{result}'"
|
||||
)
|
||||
|
||||
def test_weeks_future(self) -> None:
|
||||
"""Times a few weeks in the future."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(weeks=2, days=3)
|
||||
result = format_timestamp(dt)
|
||||
# Allow for slight timing variations (1-2 weeks)
|
||||
assert re.match(r"in [12] weeks?", result), (
|
||||
f"Expected 'in 1 week' or 'in 2 weeks', got '{result}'"
|
||||
)
|
||||
|
||||
def test_one_week_ago(self) -> None:
|
||||
"""Singular 'week' for exactly 1 week."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(weeks=1, days=3)
|
||||
assert format_timestamp(dt) == "1 week ago"
|
||||
|
||||
def test_one_week_future(self) -> None:
|
||||
"""Singular 'week' for exactly 1 week in the future."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(weeks=1, days=3)
|
||||
assert format_timestamp(dt) == "in 1 week"
|
||||
|
||||
def test_old_date_shows_absolute(self) -> None:
|
||||
"""Dates more than 30 days ago show absolute date."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(days=45)
|
||||
result = format_timestamp(dt)
|
||||
# Should be in YYYY-MM-DD format
|
||||
assert result == dt.strftime("%Y-%m-%d")
|
||||
|
||||
def test_far_future_date_shows_absolute(self) -> None:
|
||||
"""Dates more than 30 days in future show absolute date."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now + timedelta(days=45)
|
||||
result = format_timestamp(dt)
|
||||
# Should be in YYYY-MM-DD format
|
||||
assert result == dt.strftime("%Y-%m-%d")
|
||||
|
||||
def test_iso_string_input(self) -> None:
|
||||
"""ISO format strings are parsed correctly."""
|
||||
# Create a time 5 minutes ago
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(minutes=5)
|
||||
iso_str = dt.isoformat()
|
||||
result = format_timestamp(iso_str)
|
||||
assert "minute" in result
|
||||
|
||||
def test_iso_string_with_z_suffix(self) -> None:
|
||||
"""ISO strings with Z suffix are parsed correctly."""
|
||||
now = datetime.now(timezone.utc)
|
||||
dt = now - timedelta(hours=2)
|
||||
# Replace +00:00 with Z
|
||||
iso_str = dt.strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"
|
||||
result = format_timestamp(iso_str)
|
||||
assert "hour" in result
|
||||
|
||||
def test_naive_datetime_treated_as_utc(self) -> None:
|
||||
"""Naive datetimes (no timezone) are treated as UTC."""
|
||||
now = datetime.now(timezone.utc)
|
||||
# Create naive datetime
|
||||
naive_dt = (now - timedelta(minutes=10)).replace(tzinfo=None)
|
||||
result = format_timestamp(naive_dt)
|
||||
assert "minute" in result
|
||||
|
||||
def test_non_utc_timezone(self) -> None:
|
||||
"""Datetimes with non-UTC timezone are converted properly."""
|
||||
# Create a timezone +5 hours from UTC
|
||||
tz_plus5 = timezone(timedelta(hours=5))
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
# Create time 2 hours ago in UTC, but expressed in +5 timezone
|
||||
dt = (now_utc - timedelta(hours=2)).astimezone(tz_plus5)
|
||||
result = format_timestamp(dt)
|
||||
assert "hour" in result
|
||||
|
||||
def test_invalid_string_returns_string(self) -> None:
|
||||
"""Invalid datetime strings are returned as-is."""
|
||||
result = format_timestamp("not a date")
|
||||
assert result == "not a date"
|
||||
|
||||
def test_non_datetime_object_returns_string(self) -> None:
|
||||
"""Non-datetime objects are converted to string."""
|
||||
result = format_timestamp(12345) # type: ignore
|
||||
assert result == "12345"
|
||||
|
||||
|
||||
class TestFormatSize:
|
||||
"""Tests for the format_size function."""
|
||||
|
||||
def test_bytes(self) -> None:
|
||||
assert format_size(500) == "500 B"
|
||||
|
||||
def test_kilobytes(self) -> None:
|
||||
assert format_size(1536) == "1.5 KB"
|
||||
|
||||
def test_megabytes(self) -> None:
|
||||
assert format_size(1572864) == "1.5 MB"
|
||||
|
||||
def test_gigabytes(self) -> None:
|
||||
assert format_size(1610612736) == "1.5 GB"
|
||||
|
||||
def test_terabytes(self) -> None:
|
||||
assert format_size(1649267441664) == "1.5 TB"
|
||||
|
||||
def test_zero_bytes(self) -> None:
|
||||
assert format_size(0) == "0 B"
|
||||
|
||||
def test_negative_returns_na(self) -> None:
|
||||
assert format_size(-1) == "N/A"
|
||||
|
||||
|
||||
class TestFormatBool:
|
||||
"""Tests for the format_bool function."""
|
||||
|
||||
def test_true(self) -> None:
|
||||
assert format_bool(True) == "Yes"
|
||||
|
||||
def test_false(self) -> None:
|
||||
assert format_bool(False) == "No"
|
||||
Loading…
Add table
Add a link
Reference in a new issue