Sync contents

This commit is contained in:
Daniel Xu 2026-01-25 05:52:42 +00:00
parent ad03d44978
commit ca40e08bb4
12 changed files with 358 additions and 27 deletions

View file

@ -1,4 +1,4 @@
{ {
"last_synced_sha": "e99a92629bc1b7df95b2ffc0b60a60c8eecddee1", "last_synced_sha": "9ee57ec63f1159d18309fabc51bead0105d52917",
"last_sync_time": "2026-01-21T23:13:02.930293" "last_sync_time": "2026-01-25T05:52:42.500531"
} }

View file

@ -110,3 +110,14 @@ async def compute_logprobs_async(
``` ```
Async version of compute_logprobs. Async version of compute_logprobs.
#### `get_tokenizer`
```python
def get_tokenizer() -> PreTrainedTokenizer
```
Get the tokenizer for the current model.
Returns:
- `PreTrainedTokenizer` compatible with the model

View file

@ -176,6 +176,11 @@ def optim_step(
Update model parameters using Adam optimizer. Update model parameters using Adam optimizer.
The Adam optimizer used by tinker is identical
to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
Args: Args:
- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay) - `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
@ -212,13 +217,17 @@ Async version of optim_step.
#### `save_state` #### `save_state`
```python ```python
def save_state(name: str) -> APIFuture[types.SaveWeightsResponse] def save_state(
name: str,
ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsResponse]
``` ```
Save model weights to persistent storage. Save model weights to persistent storage.
Args: Args:
- `name`: Name for the saved checkpoint - `name`: Name for the saved checkpoint
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
Returns: Returns:
- `APIFuture` containing the save response with checkpoint path - `APIFuture` containing the save response with checkpoint path
@ -234,7 +243,10 @@ print(f"Saved to: {result.path}")
#### `save_state_async` #### `save_state_async`
```python ```python
async def save_state_async(name: str) -> APIFuture[types.SaveWeightsResponse] async def save_state_async(
name: str,
ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsResponse]
``` ```
Async version of save_state. Async version of save_state.
@ -310,13 +322,16 @@ Async version of load_state_with_optimizer.
```python ```python
def save_weights_for_sampler( def save_weights_for_sampler(
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse] name: str,
ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsForSamplerResponse]
``` ```
Save model weights for use with a SamplingClient. Save model weights for use with a SamplingClient.
Args: Args:
- `name`: Name for the saved sampler weights - `name`: Name for the saved sampler weights
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
Returns: Returns:
- `APIFuture` containing the save response with sampler path - `APIFuture` containing the save response with sampler path
@ -338,7 +353,9 @@ sampling_client = service_client.create_sampling_client(
```python ```python
async def save_weights_for_sampler_async( async def save_weights_for_sampler_async(
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse] name: str,
ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsForSamplerResponse]
``` ```
Async version of save_weights_for_sampler. Async version of save_weights_for_sampler.

View file

@ -26,7 +26,7 @@ Weight decay for the optimizer. Uses decoupled weight decay.
#### `grad_clip_norm` #### `grad_clip_norm`
Gradient clip norm for the optimizer. 0.0 means no clipping. Maximum global gradient norm. If the global gradient norm is greater than this value, it will be clipped to this value. 0.0 means no clipping.
## `SupportedModel` Objects ## `SupportedModel` Objects
@ -159,6 +159,10 @@ The size of the checkpoint in bytes
Whether the checkpoint is publicly accessible Whether the checkpoint is publicly accessible
#### `expires_at`
When this checkpoint expires (None = never expires)
## `ParsedCheckpointTinkerPath` Objects ## `ParsedCheckpointTinkerPath` Objects
```python ```python
@ -725,6 +729,10 @@ class SaveWeightsRequest(StrictBase)
A file/directory name for the weights A file/directory name for the weights
#### `ttl_seconds`
TTL in seconds for this checkpoint (None = never expires)
## `LoraConfig` Objects ## `LoraConfig` Objects
```python ```python
@ -834,6 +842,10 @@ class SaveWeightsForSamplerRequest(StrictBase)
A file/directory name for the weights A file/directory name for the weights
#### `ttl_seconds`
TTL in seconds for this checkpoint (None = never expires)
## `SamplingParams` Objects ## `SamplingParams` Objects
```python ```python

View file

@ -1,6 +1,6 @@
[project] [project]
name = "tinker" name = "tinker"
version = "0.8.1" version = "0.9.0"
description = "The official Python SDK for the tinker API" description = "The official Python SDK for the tinker API"
readme = "README.md" readme = "README.md"
license = "Apache-2.0" license = "Apache-2.0"

View file

@ -88,7 +88,7 @@ class CheckpointListOutput(OutputBase):
def get_table_columns(self) -> List[str]: def get_table_columns(self) -> List[str]:
"""Return column headers for table output.""" """Return column headers for table output."""
return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Path"] return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Expires", "Path"]
def get_table_rows(self) -> List[List[str]]: def get_table_rows(self) -> List[List[str]]:
"""Return rows for table output.""" """Return rows for table output."""
@ -103,6 +103,7 @@ class CheckpointListOutput(OutputBase):
else "N/A", else "N/A",
format_bool(ckpt.public), format_bool(ckpt.public),
format_timestamp(ckpt.time), format_timestamp(ckpt.time),
format_timestamp(ckpt.expires_at) if ckpt.expires_at else "Never",
ckpt.tinker_path, ckpt.tinker_path,
] ]
) )
@ -161,6 +162,12 @@ class CheckpointInfoOutput(OutputBase):
# Creation time # Creation time
rows.append(["Created", format_timestamp(self.checkpoint.time)]) rows.append(["Created", format_timestamp(self.checkpoint.time)])
# Expiration time
if self.checkpoint.expires_at:
rows.append(["Expires", format_timestamp(self.checkpoint.expires_at)])
else:
rows.append(["Expires", "Never"])
# Parse training run ID from path # Parse training run ID from path
if self.checkpoint.tinker_path.startswith("tinker://"): if self.checkpoint.tinker_path.startswith("tinker://"):
parts = self.checkpoint.tinker_path.replace("tinker://", "").split("/") parts = self.checkpoint.tinker_path.replace("tinker://", "").split("/")

View file

@ -177,22 +177,33 @@ def format_timestamp(dt: Union[datetime, str, None]) -> str:
# Calculate time difference # Calculate time difference
delta = now - dt delta = now - dt
# Format based on age # Handle future times (negative delta means dt is in the future)
is_future = delta.total_seconds() < 0
if is_future:
delta = dt - now
suffix = ""
prefix = "in "
else:
suffix = " ago"
prefix = ""
# Format based on magnitude
total_seconds = int(delta.total_seconds())
if delta.days > 30: if delta.days > 30:
return dt.strftime("%Y-%m-%d") return dt.strftime("%Y-%m-%d")
elif delta.days > 7: elif delta.days > 7:
weeks = delta.days // 7 weeks = delta.days // 7
return f"{weeks} week{'s' if weeks > 1 else ''} ago" return f"{prefix}{weeks} week{'s' if weeks > 1 else ''}{suffix}"
elif delta.days > 0: elif delta.days > 0:
return f"{delta.days} day{'s' if delta.days > 1 else ''} ago" return f"{prefix}{delta.days} day{'s' if delta.days > 1 else ''}{suffix}"
elif delta.seconds > 3600: elif total_seconds > 3600:
hours = delta.seconds // 3600 hours = total_seconds // 3600
return f"{hours} hour{'s' if hours > 1 else ''} ago" return f"{prefix}{hours} hour{'s' if hours > 1 else ''}{suffix}"
elif delta.seconds > 60: elif total_seconds > 60:
minutes = delta.seconds // 60 minutes = total_seconds // 60
return f"{minutes} minute{'s' if minutes > 1 else ''} ago" return f"{prefix}{minutes} minute{'s' if minutes > 1 else ''}{suffix}"
else: else:
return "just now" return "just now" if not is_future else "in less than a minute"
except Exception: except Exception:
# If any error occurs, just return string representation # If any error occurs, just return string representation

View file

@ -486,11 +486,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
return self.optim_step(adam_params) return self.optim_step(adam_params)
@capture_exceptions(fatal=True) @capture_exceptions(fatal=True)
def save_state(self, name: str) -> APIFuture[types.SaveWeightsResponse]: def save_state(
self, name: str, ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsResponse]:
"""Save model weights to persistent storage. """Save model weights to persistent storage.
Args: Args:
- `name`: Name for the saved checkpoint - `name`: Name for the saved checkpoint
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
Returns: Returns:
- `APIFuture` containing the save response with checkpoint path - `APIFuture` containing the save response with checkpoint path
@ -514,6 +517,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
model_id=self._guaranteed_model_id(), model_id=self._guaranteed_model_id(),
path=name, path=name,
seq_id=request_id + 1, seq_id=request_id + 1,
ttl_seconds=ttl_seconds,
) )
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.save( return await client.weights.save(
@ -533,9 +537,11 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
return self.holder.run_coroutine_threadsafe(_save_state_async()) return self.holder.run_coroutine_threadsafe(_save_state_async())
async def save_state_async(self, name: str) -> APIFuture[types.SaveWeightsResponse]: async def save_state_async(
self, name: str, ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsResponse]:
"""Async version of save_state.""" """Async version of save_state."""
return self.save_state(name) return self.save_state(name, ttl_seconds=ttl_seconds)
@capture_exceptions(fatal=True) @capture_exceptions(fatal=True)
async def _load_state_impl( async def _load_state_impl(
@ -625,7 +631,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
@capture_exceptions(fatal=True) @capture_exceptions(fatal=True)
async def _save_weights_for_sampler_impl( async def _save_weights_for_sampler_impl(
self, request_id: int, name: str | None self, request_id: int, name: str | None, ttl_seconds: int | None = None
) -> types.SaveWeightsForSamplerResponseInternal: ) -> types.SaveWeightsForSamplerResponseInternal:
assert asyncio.get_event_loop() == self.holder.get_loop() assert asyncio.get_event_loop() == self.holder.get_loop()
start_time = time.time() start_time = time.time()
@ -636,6 +642,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
model_id=self._guaranteed_model_id(), model_id=self._guaranteed_model_id(),
path=name, path=name,
seq_id=request_id + 1, seq_id=request_id + 1,
ttl_seconds=ttl_seconds,
) )
else: else:
sampling_session_seq_id = self.holder._sampling_client_counter sampling_session_seq_id = self.holder._sampling_client_counter
@ -644,6 +651,7 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
model_id=self._guaranteed_model_id(), model_id=self._guaranteed_model_id(),
seq_id=request_id + 1, seq_id=request_id + 1,
sampling_session_seq_id=sampling_session_seq_id, sampling_session_seq_id=sampling_session_seq_id,
ttl_seconds=ttl_seconds,
) )
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.save_for_sampler( return await client.weights.save_for_sampler(
@ -662,11 +670,14 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
) )
@capture_exceptions(fatal=True) @capture_exceptions(fatal=True)
def save_weights_for_sampler(self, name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]: def save_weights_for_sampler(
self, name: str, ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsForSamplerResponse]:
"""Save model weights for use with a SamplingClient. """Save model weights for use with a SamplingClient.
Args: Args:
- `name`: Name for the saved sampler weights - `name`: Name for the saved sampler weights
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
Returns: Returns:
- `APIFuture` containing the save response with sampler path - `APIFuture` containing the save response with sampler path
@ -687,17 +698,17 @@ class TrainingClient(TelemetryProvider, QueueStateObserver):
request_id = self._get_request_id() request_id = self._get_request_id()
async def _save_weights_for_sampler_async(): async def _save_weights_for_sampler_async():
result = await self._save_weights_for_sampler_impl(request_id, name) result = await self._save_weights_for_sampler_impl(request_id, name, ttl_seconds)
assert result.path is not None assert result.path is not None
return types.SaveWeightsForSamplerResponse(path=result.path) return types.SaveWeightsForSamplerResponse(path=result.path)
return self.holder.run_coroutine_threadsafe(_save_weights_for_sampler_async()) return self.holder.run_coroutine_threadsafe(_save_weights_for_sampler_async())
async def save_weights_for_sampler_async( async def save_weights_for_sampler_async(
self, name: str self, name: str, ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsForSamplerResponse]: ) -> APIFuture[types.SaveWeightsForSamplerResponse]:
"""Async version of save_weights_for_sampler.""" """Async version of save_weights_for_sampler."""
return self.save_weights_for_sampler(name) return self.save_weights_for_sampler(name, ttl_seconds=ttl_seconds)
def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]: def _get_info_submit(self) -> AwaitableConcurrentFuture[types.GetInfoResponse]:
async def _get_info_async(): async def _get_info_async():

View file

@ -27,6 +27,9 @@ class Checkpoint(BaseModel):
public: bool = False public: bool = False
"""Whether the checkpoint is publicly accessible""" """Whether the checkpoint is publicly accessible"""
expires_at: datetime | None = None
"""When this checkpoint expires (None = never expires)"""
class ParsedCheckpointTinkerPath(BaseModel): class ParsedCheckpointTinkerPath(BaseModel):
tinker_path: str tinker_path: str

View file

@ -19,6 +19,9 @@ class SaveWeightsForSamplerRequest(StrictBase):
seq_id: Optional[int] = None seq_id: Optional[int] = None
ttl_seconds: Optional[int] = None
"""TTL in seconds for this checkpoint (None = never expires)"""
type: Literal["save_weights_for_sampler"] = "save_weights_for_sampler" type: Literal["save_weights_for_sampler"] = "save_weights_for_sampler"
if PYDANTIC_V2: if PYDANTIC_V2:

View file

@ -17,6 +17,9 @@ class SaveWeightsRequest(StrictBase):
seq_id: Optional[int] = None seq_id: Optional[int] = None
ttl_seconds: Optional[int] = None
"""TTL in seconds for this checkpoint (None = never expires)"""
type: Literal["save_weights"] = "save_weights" type: Literal["save_weights"] = "save_weights"
if PYDANTIC_V2: if PYDANTIC_V2:

253
tests/test_cli_output.py Normal file
View file

@ -0,0 +1,253 @@
"""Tests for CLI output formatting utilities."""
import re
from datetime import datetime, timedelta, timezone
from tinker.cli.output import format_bool, format_size, format_timestamp
class TestFormatTimestamp:
    """Unit tests for format_timestamp's relative-time rendering.

    Covers both past ("N <unit> ago") and future ("in N <unit>") phrasing,
    absolute-date fallback for timestamps more than ~30 days away, string
    parsing, timezone normalization, and error fallbacks.
    """

    def test_none_returns_na(self) -> None:
        # None has no sensible rendering; the formatter reports "N/A".
        assert format_timestamp(None) == "N/A"

    def test_empty_string_returns_na(self) -> None:
        # An empty string is treated like a missing value.
        assert format_timestamp("") == "N/A"

    def test_just_now_past(self) -> None:
        """A timestamp under a minute old renders as 'just now'."""
        recent = datetime.now(timezone.utc) - timedelta(seconds=30)
        assert format_timestamp(recent) == "just now"

    def test_just_now_future(self) -> None:
        """A timestamp under a minute ahead renders as 'in less than a minute'."""
        soon = datetime.now(timezone.utc) + timedelta(seconds=30)
        assert format_timestamp(soon) == "in less than a minute"

    def test_minutes_ago(self) -> None:
        """A few minutes in the past uses the plural 'minutes ago' form."""
        ts = datetime.now(timezone.utc) - timedelta(minutes=5, seconds=30)
        result = format_timestamp(ts)
        # Tolerate clock drift between fixture creation and formatting.
        assert re.match(r"[45] minutes ago", result), (
            f"Expected '4 minutes ago' or '5 minutes ago', got '{result}'"
        )

    def test_minutes_future(self) -> None:
        """A few minutes in the future uses the plural 'in N minutes' form."""
        ts = datetime.now(timezone.utc) + timedelta(minutes=5, seconds=30)
        result = format_timestamp(ts)
        # Tolerate clock drift between fixture creation and formatting.
        assert re.match(r"in [45] minutes", result), (
            f"Expected 'in 4 minutes' or 'in 5 minutes', got '{result}'"
        )

    def test_one_minute_ago(self) -> None:
        """Exactly one minute in the past uses the singular form."""
        ts = datetime.now(timezone.utc) - timedelta(minutes=1, seconds=30)
        assert format_timestamp(ts) == "1 minute ago"

    def test_one_minute_future(self) -> None:
        """Exactly one minute in the future uses the singular form."""
        ts = datetime.now(timezone.utc) + timedelta(minutes=1, seconds=30)
        assert format_timestamp(ts) == "in 1 minute"

    def test_hours_ago(self) -> None:
        """A few hours in the past uses the plural 'hours ago' form."""
        ts = datetime.now(timezone.utc) - timedelta(hours=3, minutes=30)
        result = format_timestamp(ts)
        # Tolerate clock drift between fixture creation and formatting.
        assert re.match(r"[23] hours ago", result), (
            f"Expected '2 hours ago' or '3 hours ago', got '{result}'"
        )

    def test_hours_future(self) -> None:
        """A few hours in the future uses the plural 'in N hours' form."""
        ts = datetime.now(timezone.utc) + timedelta(hours=3, minutes=30)
        result = format_timestamp(ts)
        # Tolerate clock drift between fixture creation and formatting.
        assert re.match(r"in [23] hours", result), (
            f"Expected 'in 2 hours' or 'in 3 hours', got '{result}'"
        )

    def test_one_hour_ago(self) -> None:
        """Exactly one hour in the past uses the singular form."""
        ts = datetime.now(timezone.utc) - timedelta(hours=1, minutes=30)
        assert format_timestamp(ts) == "1 hour ago"

    def test_one_hour_future(self) -> None:
        """Exactly one hour in the future uses the singular form."""
        ts = datetime.now(timezone.utc) + timedelta(hours=1, minutes=30)
        assert format_timestamp(ts) == "in 1 hour"

    def test_days_ago(self) -> None:
        """A few days in the past uses the plural 'days ago' form."""
        ts = datetime.now(timezone.utc) - timedelta(days=3, hours=12)
        result = format_timestamp(ts)
        # Tolerate clock drift between fixture creation and formatting.
        assert re.match(r"[23] days ago", result), (
            f"Expected '2 days ago' or '3 days ago', got '{result}'"
        )

    def test_days_future(self) -> None:
        """A few days in the future uses the plural 'in N days' form."""
        ts = datetime.now(timezone.utc) + timedelta(days=3, hours=12)
        result = format_timestamp(ts)
        # Tolerate clock drift between fixture creation and formatting.
        assert re.match(r"in [23] days", result), (
            f"Expected 'in 2 days' or 'in 3 days', got '{result}'"
        )

    def test_one_day_ago(self) -> None:
        """Exactly one day in the past uses the singular form."""
        ts = datetime.now(timezone.utc) - timedelta(days=1, hours=12)
        assert format_timestamp(ts) == "1 day ago"

    def test_one_day_future(self) -> None:
        """Exactly one day in the future uses the singular form."""
        ts = datetime.now(timezone.utc) + timedelta(days=1, hours=12)
        assert format_timestamp(ts) == "in 1 day"

    def test_weeks_ago(self) -> None:
        """More than a week in the past switches to week granularity."""
        ts = datetime.now(timezone.utc) - timedelta(weeks=2, days=3)
        result = format_timestamp(ts)
        # Tolerate clock drift between fixture creation and formatting.
        assert re.match(r"[12] weeks? ago", result), (
            f"Expected '1 week ago' or '2 weeks ago', got '{result}'"
        )

    def test_weeks_future(self) -> None:
        """More than a week in the future switches to week granularity."""
        ts = datetime.now(timezone.utc) + timedelta(weeks=2, days=3)
        result = format_timestamp(ts)
        # Tolerate clock drift between fixture creation and formatting.
        assert re.match(r"in [12] weeks?", result), (
            f"Expected 'in 1 week' or 'in 2 weeks', got '{result}'"
        )

    def test_one_week_ago(self) -> None:
        """Exactly one week in the past uses the singular form."""
        ts = datetime.now(timezone.utc) - timedelta(weeks=1, days=3)
        assert format_timestamp(ts) == "1 week ago"

    def test_one_week_future(self) -> None:
        """Exactly one week in the future uses the singular form."""
        ts = datetime.now(timezone.utc) + timedelta(weeks=1, days=3)
        assert format_timestamp(ts) == "in 1 week"

    def test_old_date_shows_absolute(self) -> None:
        """Timestamps more than 30 days in the past fall back to YYYY-MM-DD."""
        ts = datetime.now(timezone.utc) - timedelta(days=45)
        assert format_timestamp(ts) == ts.strftime("%Y-%m-%d")

    def test_far_future_date_shows_absolute(self) -> None:
        """Timestamps more than 30 days in the future fall back to YYYY-MM-DD."""
        ts = datetime.now(timezone.utc) + timedelta(days=45)
        assert format_timestamp(ts) == ts.strftime("%Y-%m-%d")

    def test_iso_string_input(self) -> None:
        """ISO-8601 strings are parsed before formatting."""
        ts = datetime.now(timezone.utc) - timedelta(minutes=5)
        result = format_timestamp(ts.isoformat())
        assert "minute" in result

    def test_iso_string_with_z_suffix(self) -> None:
        """A trailing 'Z' (Zulu/UTC) suffix is accepted."""
        ts = datetime.now(timezone.utc) - timedelta(hours=2)
        zulu = ts.strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"
        result = format_timestamp(zulu)
        assert "hour" in result

    def test_naive_datetime_treated_as_utc(self) -> None:
        """A datetime without tzinfo is interpreted as UTC."""
        naive = (datetime.now(timezone.utc) - timedelta(minutes=10)).replace(tzinfo=None)
        result = format_timestamp(naive)
        assert "minute" in result

    def test_non_utc_timezone(self) -> None:
        """Aware datetimes in other timezones are normalized before comparing."""
        offset_tz = timezone(timedelta(hours=5))
        # Two hours in the past, but expressed in a UTC+5 zone.
        ts = (datetime.now(timezone.utc) - timedelta(hours=2)).astimezone(offset_tz)
        result = format_timestamp(ts)
        assert "hour" in result

    def test_invalid_string_returns_string(self) -> None:
        """Unparseable strings are passed through unchanged."""
        assert format_timestamp("not a date") == "not a date"

    def test_non_datetime_object_returns_string(self) -> None:
        """Arbitrary non-datetime values degrade to str(value)."""
        assert format_timestamp(12345) == "12345"  # type: ignore
class TestFormatSize:
    """Unit tests covering each unit tier of format_size."""

    def test_bytes(self) -> None:
        # Below 1 KiB the raw byte count is shown.
        assert format_size(500) == "500 B"

    def test_kilobytes(self) -> None:
        # 1536 bytes == 1.5 * 1024.
        assert format_size(1536) == "1.5 KB"

    def test_megabytes(self) -> None:
        # 1572864 bytes == 1.5 * 1024**2.
        assert format_size(1572864) == "1.5 MB"

    def test_gigabytes(self) -> None:
        # 1610612736 bytes == 1.5 * 1024**3.
        assert format_size(1610612736) == "1.5 GB"

    def test_terabytes(self) -> None:
        # 1649267441664 bytes == 1.5 * 1024**4.
        assert format_size(1649267441664) == "1.5 TB"

    def test_zero_bytes(self) -> None:
        # Zero is a valid size, not a missing value.
        assert format_size(0) == "0 B"

    def test_negative_returns_na(self) -> None:
        # Negative sizes are nonsensical and render as "N/A".
        assert format_size(-1) == "N/A"
class TestFormatBool:
    """Unit tests for format_bool's Yes/No rendering."""

    def test_true(self) -> None:
        # True maps to the human-friendly "Yes".
        assert format_bool(True) == "Yes"

    def test_false(self) -> None:
        # False maps to the human-friendly "No".
        assert format_bool(False) == "No"