mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
ad03d44978
commit
ca40e08bb4
12 changed files with 358 additions and 27 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
"last_synced_sha": "e99a92629bc1b7df95b2ffc0b60a60c8eecddee1",
|
"last_synced_sha": "9ee57ec63f1159d18309fabc51bead0105d52917",
|
||||||
"last_sync_time": "2026-01-21T23:13:02.930293"
|
"last_sync_time": "2026-01-25T05:52:42.500531"
|
||||||
}
|
}
|
||||||
|
|
@ -110,3 +110,14 @@ async def compute_logprobs_async(
|
||||||
```
|
```
|
||||||
|
|
||||||
Async version of compute_logprobs.
|
Async version of compute_logprobs.
|
||||||
|
|
||||||
|
#### `get_tokenizer`
|
||||||
|
|
||||||
|
```python
|
||||||
|
def get_tokenizer() -> PreTrainedTokenizer
|
||||||
|
```
|
||||||
|
|
||||||
|
Get the tokenizer for the current model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- `PreTrainedTokenizer` compatible with the model
|
||||||
|
|
|
||||||
|
|
@ -176,6 +176,11 @@ def optim_step(
|
||||||
|
|
||||||
Update model parameters using Adam optimizer.
|
Update model parameters using Adam optimizer.
|
||||||
|
|
||||||
|
The Adam optimizer used by tinker is identical
|
||||||
|
to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||||
|
Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
|
- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
|
||||||
|
|
||||||
|
|
@ -212,13 +217,17 @@ Async version of optim_step.
|
||||||
#### `save_state`
|
#### `save_state`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def save_state(name: str) -> APIFuture[types.SaveWeightsResponse]
|
def save_state(
|
||||||
|
name: str,
|
||||||
|
ttl_seconds: int | None = None
|
||||||
|
) -> APIFuture[types.SaveWeightsResponse]
|
||||||
```
|
```
|
||||||
|
|
||||||
Save model weights to persistent storage.
|
Save model weights to persistent storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- `name`: Name for the saved checkpoint
|
- `name`: Name for the saved checkpoint
|
||||||
|
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- `APIFuture` containing the save response with checkpoint path
|
- `APIFuture` containing the save response with checkpoint path
|
||||||
|
|
@ -234,7 +243,10 @@ print(f"Saved to: {result.path}")
|
||||||
#### `save_state_async`
|
#### `save_state_async`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
async def save_state_async(name: str) -> APIFuture[types.SaveWeightsResponse]
|
async def save_state_async(
|
||||||
|
name: str,
|
||||||
|
ttl_seconds: int | None = None
|
||||||
|
) -> APIFuture[types.SaveWeightsResponse]
|
||||||
```
|
```
|
||||||
|
|
||||||
Async version of save_state.
|
Async version of save_state.
|
||||||
|
|
@ -310,13 +322,16 @@ Async version of load_state_with_optimizer.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
def save_weights_for_sampler(
|
def save_weights_for_sampler(
|
||||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
name: str,
|
||||||
|
ttl_seconds: int | None = None
|
||||||
|
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||||
```
|
```
|
||||||
|
|
||||||
Save model weights for use with a SamplingClient.
|
Save model weights for use with a SamplingClient.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- `name`: Name for the saved sampler weights
|
- `name`: Name for the saved sampler weights
|
||||||
|
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- `APIFuture` containing the save response with sampler path
|
- `APIFuture` containing the save response with sampler path
|
||||||
|
|
@ -338,7 +353,9 @@ sampling_client = service_client.create_sampling_client(
|
||||||
|
|
||||||
```python
|
```python
|
||||||
async def save_weights_for_sampler_async(
|
async def save_weights_for_sampler_async(
|
||||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
name: str,
|
||||||
|
ttl_seconds: int | None = None
|
||||||
|
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||||
```
|
```
|
||||||
|
|
||||||
Async version of save_weights_for_sampler.
|
Async version of save_weights_for_sampler.
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ Weight decay for the optimizer. Uses decoupled weight decay.
|
||||||
|
|
||||||
#### `grad_clip_norm`
|
#### `grad_clip_norm`
|
||||||
|
|
||||||
Gradient clip norm for the optimizer. 0.0 means no clipping.
|
Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping.
|
||||||
|
|
||||||
## `SupportedModel` Objects
|
## `SupportedModel` Objects
|
||||||
|
|
||||||
|
|
@ -159,6 +159,10 @@ The size of the checkpoint in bytes
|
||||||
|
|
||||||
Whether the checkpoint is publicly accessible
|
Whether the checkpoint is publicly accessible
|
||||||
|
|
||||||
|
#### `expires_at`
|
||||||
|
|
||||||
|
When this checkpoint expires (None = never expires)
|
||||||
|
|
||||||
## `ParsedCheckpointTinkerPath` Objects
|
## `ParsedCheckpointTinkerPath` Objects
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
@ -725,6 +729,10 @@ class SaveWeightsRequest(StrictBase)
|
||||||
|
|
||||||
A file/directory name for the weights
|
A file/directory name for the weights
|
||||||
|
|
||||||
|
#### `ttl_seconds`
|
||||||
|
|
||||||
|
TTL in seconds for this checkpoint (None = never expires)
|
||||||
|
|
||||||
## `LoraConfig` Objects
|
## `LoraConfig` Objects
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
@ -834,6 +842,10 @@ class SaveWeightsForSamplerRequest(StrictBase)
|
||||||
|
|
||||||
A file/directory name for the weights
|
A file/directory name for the weights
|
||||||
|
|
||||||
|
#### `ttl_seconds`
|
||||||
|
|
||||||
|
TTL in seconds for this checkpoint (None = never expires)
|
||||||
|
|
||||||
## `SamplingParams` Objects
|
## `SamplingParams` Objects
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "tinker"
|
name = "tinker"
|
||||||
version = "0.8.1"
|
version = "0.9.0"
|
||||||
description = "The official Python SDK for the tinker API"
|
description = "The official Python SDK for the tinker API"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ class CheckpointListOutput(OutputBase):
|
||||||
|
|
||||||
def get_table_columns(self) -> List[str]:
|
def get_table_columns(self) -> List[str]:
|
||||||
"""Return column headers for table output."""
|
"""Return column headers for table output."""
|
||||||
return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Path"]
|
return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Expires", "Path"]
|
||||||
|
|
||||||
def get_table_rows(self) -> List[List[str]]:
|
def get_table_rows(self) -> List[List[str]]:
|
||||||
"""Return rows for table output."""
|
"""Return rows for table output."""
|
||||||
|
|
@ -103,6 +103,7 @@ class CheckpointListOutput(OutputBase):
|
||||||
else "N/A",
|
else "N/A",
|
||||||
format_bool(ckpt.public),
|
format_bool(ckpt.public),
|
||||||
format_timestamp(ckpt.time),
|
format_timestamp(ckpt.time),
|
||||||
|
format_timestamp(ckpt.expires_at) if ckpt.expires_at else "Never",
|
||||||
ckpt.tinker_path,
|
ckpt.tinker_path,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -161,6 +162,12 @@ class CheckpointInfoOutput(OutputBase):
|
||||||
# Creation time
|
# Creation time
|
||||||
rows.append(["Created", format_timestamp(self.checkpoint.time)])
|
rows.append(["Created", format_timestamp(self.checkpoint.time)])
|
||||||
|
|
||||||
|
# Expiration time
|
||||||
|
if self.checkpoint.expires_at:
|
||||||
|
rows.append(["Expires", format_timestamp(self.checkpoint.expires_at)])
|
||||||
|
else:
|
||||||
|
rows.append(["Expires", "Never"])
|
||||||
|
|
||||||
# Parse training run ID from path
|
# Parse training run ID from path
|
||||||
if self.checkpoint.tinker_path.startswith("tinker://"):
|
if self.checkpoint.tinker_path.startswith("tinker://"):
|
||||||
parts = self.checkpoint.tinker_path.replace("tinker://", "").split("/")
|
parts = self.checkpoint.tinker_path.replace("tinker://", "").split("/")
|
||||||
|
|
|
||||||
|
|
@ -177,22 +177,33 @@ def format_timestamp(dt: Union[datetime, str, None]) -> str:
|
||||||
# Calculate time difference
|
# Calculate time difference
|
||||||
delta = now - dt
|
delta = now - dt
|
||||||
|
|
||||||
# Format based on age
|
# Handle future times (negative delta means dt is in the future)
|
||||||
|
is_future = delta.total_seconds() < 0
|
||||||
|
if is_future:
|
||||||
|
delta = dt - now
|
||||||
|
suffix = ""
|
||||||
|
prefix = "in "
|
||||||
|
else:
|
||||||
|
suffix = " ago"
|
||||||
|
prefix = ""
|
||||||
|
|
||||||
|
# Format based on magnitude
|
||||||
|
total_seconds = int(delta.total_seconds())
|
||||||
if delta.days > 30:
|
if delta.days > 30:
|
||||||
return dt.strftime("%Y-%m-%d")
|
return dt.strftime("%Y-%m-%d")
|
||||||
elif delta.days > 7:
|
elif delta.days > 7:
|
||||||
weeks = delta.days // 7
|
weeks = delta.days // 7
|
||||||
return f"{weeks} week{'s' if weeks > 1 else ''} ago"
|
return f"{prefix}{weeks} week{'s' if weeks > 1 else ''}{suffix}"
|
||||||
elif delta.days > 0:
|
elif delta.days > 0:
|
||||||
return f"{delta.days} day{'s' if delta.days > 1 else ''} ago"
|
return f"{prefix}{delta.days} day{'s' if delta.days > 1 else ''}{suffix}"
|
||||||
elif delta.seconds > 3600:
|
elif total_seconds > 3600:
|
||||||
hours = delta.seconds // 3600
|
hours = total_seconds // 3600
|
||||||
return f"{hours} hour{'s' if hours > 1 else ''} ago"
|
return f"{prefix}{hours} hour{'s' if hours > 1 else ''}{suffix}"
|
||||||
elif delta.seconds > 60:
|
elif total_seconds > 60:
|
||||||
minutes = delta.seconds // 60
|
minutes = total_seconds // 60
|
||||||
return f"{minutes} minute{'s' if minutes > 1 else ''} ago"
|
return f"{prefix}{minutes} minute{'s' if minutes > 1 else ''}{suffix}"
|
||||||
else:
|
else:
|
||||||
return "just now"
|
return "just now" if not is_future else "in less than a minute"
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# If any error occurs, just return string representation
|
# If any error occurs, just return string representation
|
||||||
|
|
|
||||||
|
|
@ -486,11 +486,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
return self.optim_step(adam_params)
|
return self.optim_step(adam_params)
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
def save_state(self, name: str) -> APIFuture[types.SaveWeightsResponse]:
|
def save_state(
|
||||||
|
self, name: str, ttl_seconds: int | None = None
|
||||||
|
) -> APIFuture[types.SaveWeightsResponse]:
|
||||||
"""Save model weights to persistent storage.
|
"""Save model weights to persistent storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- `name`: Name for the saved checkpoint
|
- `name`: Name for the saved checkpoint
|
||||||
|
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- `APIFuture` containing the save response with checkpoint path
|
- `APIFuture` containing the save response with checkpoint path
|
||||||
|
|
@ -514,6 +517,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
model_id=self._guaranteed_model_id(),
|
model_id=self._guaranteed_model_id(),
|
||||||
path=name,
|
path=name,
|
||||||
seq_id=request_id + 1,
|
seq_id=request_id + 1,
|
||||||
|
ttl_seconds=ttl_seconds,
|
||||||
)
|
)
|
||||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||||
return await client.weights.save(
|
return await client.weights.save(
|
||||||
|
|
@ -533,9 +537,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
|
|
||||||
return self.holder.run_coroutine_threadsafe(_save_state_async())
|
return self.holder.run_coroutine_threadsafe(_save_state_async())
|
||||||
|
|
||||||
async def save_state_async(self, name: str) -> APIFuture[types.SaveWeightsResponse]:
|
async def save_state_async(
|
||||||
|
self, name: str, ttl_seconds: int | None = None
|
||||||
|
) -> APIFuture[types.SaveWeightsResponse]:
|
||||||
"""Async version of save_state."""
|
"""Async version of save_state."""
|
||||||
return self.save_state(name)
|
return self.save_state(name, ttl_seconds=ttl_seconds)
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
async def _load_state_impl(
|
async def _load_state_impl(
|
||||||
|
|
@ -625,7 +631,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
async def _save_weights_for_sampler_impl(
|
async def _save_weights_for_sampler_impl(
|
||||||
self, request_id: int, name: str | None
|
self, request_id: int, name: str | None, ttl_seconds: int | None = None
|
||||||
) -> types.SaveWeightsForSamplerResponseInternal:
|
) -> types.SaveWeightsForSamplerResponseInternal:
|
||||||
assert asyncio.get_event_loop() == self.holder.get_loop()
|
assert asyncio.get_event_loop() == self.holder.get_loop()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
@ -636,6 +642,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
model_id=self._guaranteed_model_id(),
|
model_id=self._guaranteed_model_id(),
|
||||||
path=name,
|
path=name,
|
||||||
seq_id=request_id + 1,
|
seq_id=request_id + 1,
|
||||||
|
ttl_seconds=ttl_seconds,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sampling_session_seq_id = self.holder._sampling_client_counter
|
sampling_session_seq_id = self.holder._sampling_client_counter
|
||||||
|
|
@ -644,6 +651,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
model_id=self._guaranteed_model_id(),
|
model_id=self._guaranteed_model_id(),
|
||||||
seq_id=request_id + 1,
|
seq_id=request_id + 1,
|
||||||
sampling_session_seq_id=sampling_session_seq_id,
|
sampling_session_seq_id=sampling_session_seq_id,
|
||||||
|
ttl_seconds=ttl_seconds,
|
||||||
)
|
)
|
||||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||||
return await client.weights.save_for_sampler(
|
return await client.weights.save_for_sampler(
|
||||||
|
|
@ -662,11 +670,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
)
|
)
|
||||||
|
|
||||||
@capture_exceptions(fatal=True)
|
@capture_exceptions(fatal=True)
|
||||||
def save_weights_for_sampler(self, name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
def save_weights_for_sampler(
|
||||||
|
self, name: str, ttl_seconds: int | None = None
|
||||||
|
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
||||||
"""Save model weights for use with a SamplingClient.
|
"""Save model weights for use with a SamplingClient.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- `name`: Name for the saved sampler weights
|
- `name`: Name for the saved sampler weights
|
||||||
|
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- `APIFuture` containing the save response with sampler path
|
- `APIFuture` containing the save response with sampler path
|
||||||
|
|
@ -687,17 +698,17 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
|
||||||
request_id = self._get_request_id()
|
request_id = self._get_request_id()
|
||||||
|
|
||||||
async def _save_weights_for_sampler_async():
|
async def _save_weights_for_sampler_async():
|
||||||
result = await self._save_weights_for_sampler_impl(request_id, name)
|
result = await self._save_weights_for_sampler_impl(request_id, name, ttl_seconds)
|
||||||
assert result.path is not None
|
assert result.path is not None
|
||||||
return types.SaveWeightsForSamplerResponse(path=result.path)
|
return types.SaveWeightsForSamplerResponse(path=result.path)
|
||||||
|
|
||||||
return self.holder.run_coroutine_threadsafe(_save_weights_for_sampler_async())
|
return self.holder.run_coroutine_threadsafe(_save_weights_for_sampler_async())
|
||||||
|
|
||||||
async def save_weights_for_sampler_async(
|
async def save_weights_for_sampler_async(
|
||||||
self, name: str
|
self, name: str, ttl_seconds: int | None = None
|
||||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
|
||||||
"""Async version of save_weights_for_sampler."""
|
"""Async version of save_weights_for_sampler."""
|
||||||
return self.save_weights_for_sampler(name)
|
return self.save_weights_for_sampler(name, ttl_seconds=ttl_seconds)
|
||||||
|
|
||||||
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
|
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
|
||||||
async def _get_info_async():
|
async def _get_info_async():
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,9 @@ class Checkpoint(BaseModel):
|
||||||
public: bool = False
|
public: bool = False
|
||||||
"""Whether the checkpoint is publicly accessible"""
|
"""Whether the checkpoint is publicly accessible"""
|
||||||
|
|
||||||
|
expires_at: datetime | None = None
|
||||||
|
"""When this checkpoint expires (None = never expires)"""
|
||||||
|
|
||||||
|
|
||||||
class ParsedCheckpointTinkerPath(BaseModel):
|
class ParsedCheckpointTinkerPath(BaseModel):
|
||||||
tinker_path: str
|
tinker_path: str
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,9 @@ class SaveWeightsForSamplerRequest(StrictBase):
|
||||||
|
|
||||||
seq_id: Optional[int] = None
|
seq_id: Optional[int] = None
|
||||||
|
|
||||||
|
ttl_seconds: Optional[int] = None
|
||||||
|
"""TTL in seconds for this checkpoint (None = never expires)"""
|
||||||
|
|
||||||
type: Literal["save_weights_for_sampler"] = "save_weights_for_sampler"
|
type: Literal["save_weights_for_sampler"] = "save_weights_for_sampler"
|
||||||
|
|
||||||
if PYDANTIC_V2:
|
if PYDANTIC_V2:
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,9 @@ class SaveWeightsRequest(StrictBase):
|
||||||
|
|
||||||
seq_id: Optional[int] = None
|
seq_id: Optional[int] = None
|
||||||
|
|
||||||
|
ttl_seconds: Optional[int] = None
|
||||||
|
"""TTL in seconds for this checkpoint (None = never expires)"""
|
||||||
|
|
||||||
type: Literal["save_weights"] = "save_weights"
|
type: Literal["save_weights"] = "save_weights"
|
||||||
|
|
||||||
if PYDANTIC_V2:
|
if PYDANTIC_V2:
|
||||||
|
|
|
||||||
253
tests/test_cli_output.py
Normal file
253
tests/test_cli_output.py
Normal file
|
|
@ -0,0 +1,253 @@
|
||||||
|
"""Tests for CLI output formatting utilities."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from tinker.cli.output import format_bool, format_size, format_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatTimestamp:
|
||||||
|
"""Tests for the format_timestamp function."""
|
||||||
|
|
||||||
|
def test_none_returns_na(self) -> None:
|
||||||
|
assert format_timestamp(None) == "N/A"
|
||||||
|
|
||||||
|
def test_empty_string_returns_na(self) -> None:
|
||||||
|
assert format_timestamp("") == "N/A"
|
||||||
|
|
||||||
|
def test_just_now_past(self) -> None:
|
||||||
|
"""Times within the last minute should show 'just now'."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(seconds=30)
|
||||||
|
assert format_timestamp(dt) == "just now"
|
||||||
|
|
||||||
|
def test_just_now_future(self) -> None:
|
||||||
|
"""Times within the next minute should show 'in less than a minute'."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(seconds=30)
|
||||||
|
assert format_timestamp(dt) == "in less than a minute"
|
||||||
|
|
||||||
|
def test_minutes_ago(self) -> None:
|
||||||
|
"""Times a few minutes in the past."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(minutes=5, seconds=30)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Allow for slight timing variations (4-5 minutes)
|
||||||
|
assert re.match(r"[45] minutes ago", result), (
|
||||||
|
f"Expected '4 minutes ago' or '5 minutes ago', got '{result}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_minutes_future(self) -> None:
|
||||||
|
"""Times a few minutes in the future."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(minutes=5, seconds=30)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Allow for slight timing variations (4-5 minutes)
|
||||||
|
assert re.match(r"in [45] minutes", result), (
|
||||||
|
f"Expected 'in 4 minutes' or 'in 5 minutes', got '{result}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_one_minute_ago(self) -> None:
|
||||||
|
"""Singular 'minute' for exactly 1 minute."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(minutes=1, seconds=30)
|
||||||
|
assert format_timestamp(dt) == "1 minute ago"
|
||||||
|
|
||||||
|
def test_one_minute_future(self) -> None:
|
||||||
|
"""Singular 'minute' for exactly 1 minute in the future."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(minutes=1, seconds=30)
|
||||||
|
assert format_timestamp(dt) == "in 1 minute"
|
||||||
|
|
||||||
|
def test_hours_ago(self) -> None:
|
||||||
|
"""Times a few hours in the past."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(hours=3, minutes=30)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Allow for slight timing variations (2-3 hours)
|
||||||
|
assert re.match(r"[23] hours ago", result), (
|
||||||
|
f"Expected '2 hours ago' or '3 hours ago', got '{result}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_hours_future(self) -> None:
|
||||||
|
"""Times a few hours in the future."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(hours=3, minutes=30)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Allow for slight timing variations (2-3 hours)
|
||||||
|
assert re.match(r"in [23] hours", result), (
|
||||||
|
f"Expected 'in 2 hours' or 'in 3 hours', got '{result}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_one_hour_ago(self) -> None:
|
||||||
|
"""Singular 'hour' for exactly 1 hour."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(hours=1, minutes=30)
|
||||||
|
assert format_timestamp(dt) == "1 hour ago"
|
||||||
|
|
||||||
|
def test_one_hour_future(self) -> None:
|
||||||
|
"""Singular 'hour' for exactly 1 hour in the future."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(hours=1, minutes=30)
|
||||||
|
assert format_timestamp(dt) == "in 1 hour"
|
||||||
|
|
||||||
|
def test_days_ago(self) -> None:
|
||||||
|
"""Times a few days in the past."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(days=3, hours=12)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Allow for slight timing variations (2-3 days)
|
||||||
|
assert re.match(r"[23] days ago", result), (
|
||||||
|
f"Expected '2 days ago' or '3 days ago', got '{result}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_days_future(self) -> None:
|
||||||
|
"""Times a few days in the future."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(days=3, hours=12)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Allow for slight timing variations (2-3 days)
|
||||||
|
assert re.match(r"in [23] days", result), (
|
||||||
|
f"Expected 'in 2 days' or 'in 3 days', got '{result}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_one_day_ago(self) -> None:
|
||||||
|
"""Singular 'day' for exactly 1 day."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(days=1, hours=12)
|
||||||
|
assert format_timestamp(dt) == "1 day ago"
|
||||||
|
|
||||||
|
def test_one_day_future(self) -> None:
|
||||||
|
"""Singular 'day' for exactly 1 day in the future."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(days=1, hours=12)
|
||||||
|
assert format_timestamp(dt) == "in 1 day"
|
||||||
|
|
||||||
|
def test_weeks_ago(self) -> None:
|
||||||
|
"""Times a few weeks in the past."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(weeks=2, days=3)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Allow for slight timing variations (1-2 weeks)
|
||||||
|
assert re.match(r"[12] weeks? ago", result), (
|
||||||
|
f"Expected '1 week ago' or '2 weeks ago', got '{result}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_weeks_future(self) -> None:
|
||||||
|
"""Times a few weeks in the future."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(weeks=2, days=3)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Allow for slight timing variations (1-2 weeks)
|
||||||
|
assert re.match(r"in [12] weeks?", result), (
|
||||||
|
f"Expected 'in 1 week' or 'in 2 weeks', got '{result}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_one_week_ago(self) -> None:
|
||||||
|
"""Singular 'week' for exactly 1 week."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(weeks=1, days=3)
|
||||||
|
assert format_timestamp(dt) == "1 week ago"
|
||||||
|
|
||||||
|
def test_one_week_future(self) -> None:
|
||||||
|
"""Singular 'week' for exactly 1 week in the future."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(weeks=1, days=3)
|
||||||
|
assert format_timestamp(dt) == "in 1 week"
|
||||||
|
|
||||||
|
def test_old_date_shows_absolute(self) -> None:
|
||||||
|
"""Dates more than 30 days ago show absolute date."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(days=45)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Should be in YYYY-MM-DD format
|
||||||
|
assert result == dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
def test_far_future_date_shows_absolute(self) -> None:
|
||||||
|
"""Dates more than 30 days in future show absolute date."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now + timedelta(days=45)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
# Should be in YYYY-MM-DD format
|
||||||
|
assert result == dt.strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
def test_iso_string_input(self) -> None:
|
||||||
|
"""ISO format strings are parsed correctly."""
|
||||||
|
# Create a time 5 minutes ago
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(minutes=5)
|
||||||
|
iso_str = dt.isoformat()
|
||||||
|
result = format_timestamp(iso_str)
|
||||||
|
assert "minute" in result
|
||||||
|
|
||||||
|
def test_iso_string_with_z_suffix(self) -> None:
|
||||||
|
"""ISO strings with Z suffix are parsed correctly."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
dt = now - timedelta(hours=2)
|
||||||
|
# Replace +00:00 with Z
|
||||||
|
iso_str = dt.strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"
|
||||||
|
result = format_timestamp(iso_str)
|
||||||
|
assert "hour" in result
|
||||||
|
|
||||||
|
def test_naive_datetime_treated_as_utc(self) -> None:
|
||||||
|
"""Naive datetimes (no timezone) are treated as UTC."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
# Create naive datetime
|
||||||
|
naive_dt = (now - timedelta(minutes=10)).replace(tzinfo=None)
|
||||||
|
result = format_timestamp(naive_dt)
|
||||||
|
assert "minute" in result
|
||||||
|
|
||||||
|
def test_non_utc_timezone(self) -> None:
|
||||||
|
"""Datetimes with non-UTC timezone are converted properly."""
|
||||||
|
# Create a timezone +5 hours from UTC
|
||||||
|
tz_plus5 = timezone(timedelta(hours=5))
|
||||||
|
now_utc = datetime.now(timezone.utc)
|
||||||
|
# Create time 2 hours ago in UTC, but expressed in +5 timezone
|
||||||
|
dt = (now_utc - timedelta(hours=2)).astimezone(tz_plus5)
|
||||||
|
result = format_timestamp(dt)
|
||||||
|
assert "hour" in result
|
||||||
|
|
||||||
|
def test_invalid_string_returns_string(self) -> None:
|
||||||
|
"""Invalid datetime strings are returned as-is."""
|
||||||
|
result = format_timestamp("not a date")
|
||||||
|
assert result == "not a date"
|
||||||
|
|
||||||
|
def test_non_datetime_object_returns_string(self) -> None:
|
||||||
|
"""Non-datetime objects are converted to string."""
|
||||||
|
result = format_timestamp(12345) # type: ignore
|
||||||
|
assert result == "12345"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatSize:
|
||||||
|
"""Tests for the format_size function."""
|
||||||
|
|
||||||
|
def test_bytes(self) -> None:
|
||||||
|
assert format_size(500) == "500 B"
|
||||||
|
|
||||||
|
def test_kilobytes(self) -> None:
|
||||||
|
assert format_size(1536) == "1.5 KB"
|
||||||
|
|
||||||
|
def test_megabytes(self) -> None:
|
||||||
|
assert format_size(1572864) == "1.5 MB"
|
||||||
|
|
||||||
|
def test_gigabytes(self) -> None:
|
||||||
|
assert format_size(1610612736) == "1.5 GB"
|
||||||
|
|
||||||
|
def test_terabytes(self) -> None:
|
||||||
|
assert format_size(1649267441664) == "1.5 TB"
|
||||||
|
|
||||||
|
def test_zero_bytes(self) -> None:
|
||||||
|
assert format_size(0) == "0 B"
|
||||||
|
|
||||||
|
def test_negative_returns_na(self) -> None:
|
||||||
|
assert format_size(-1) == "N/A"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatBool:
|
||||||
|
"""Tests for the format_bool function."""
|
||||||
|
|
||||||
|
def test_true(self) -> None:
|
||||||
|
assert format_bool(True) == "Yes"
|
||||||
|
|
||||||
|
def test_false(self) -> None:
|
||||||
|
assert format_bool(False) == "No"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue