Sync contents

This commit is contained in:
Daniel Xu 2025-11-18 22:06:38 +00:00
parent 604e00c700
commit 2a37c3afb4
18 changed files with 1811 additions and 4 deletions

View file

@ -1,4 +1,4 @@
{ {
"last_synced_sha": "d973e1d6eede81b853a167e8f999f43402e07c3a", "last_synced_sha": "31ae0341ea2c6122c46c0597c5a862440a9cf31e",
"last_sync_time": "2025-11-11T05:56:15.874542" "last_sync_time": "2025-11-18T22:06:38.289689"
} }

View file

@ -17,6 +17,8 @@ dependencies = [
"sniffio", "sniffio",
"numpy", "numpy",
"torch", "torch",
"rich>=13.0.0",
"click>=8.0.0",
] ]
requires-python = ">= 3.11" requires-python = ">= 3.11"
classifiers = [ classifiers = [
@ -36,6 +38,9 @@ classifiers = [
"License :: OSI Approved :: Apache Software License" "License :: OSI Approved :: Apache Software License"
] ]
[project.scripts]
tinker = "tinker.cli.__main__:cli"
[project.urls] [project.urls]
Homepage = "https://thinkingmachines.ai/tinker" Homepage = "https://thinkingmachines.ai/tinker"
Repository = "https://github.com/thinking-machines-lab/tinker" Repository = "https://github.com/thinking-machines-lab/tinker"

281
src/tinker/cli/CLAUDE.md Normal file
View file

@ -0,0 +1,281 @@
# Tinker CLI Design Documentation
## Overview
The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance.
## Key Design Decisions
### 1. Lazy Import Strategy with Click
**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level.
**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands.
**Implementation**:
- Main `__init__.py` only imports `click` and `lazy_group`
- Command modules are loaded only when invoked via `LazyGroup`
- Output formatting imports `rich` only when table output is needed
- JSON module imported only when JSON output is requested
- Version information loaded from `_version.py` only when `tinker version` is used
### 2. Click Framework with LazyGroup
**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading.
**Rationale**:
- Click provides cleaner command structure with decorators
- Better subcommand isolation - each command file is self-contained
- Automatic help generation with better formatting
- Built-in type conversion and validation
- LazyGroup enables fast startup by deferring imports
**LazyGroup Implementation**:
```python
class LazyGroup(click.Group):
def __init__(self, *args, lazy_subcommands=None, **kwargs):
# Map of command name to "module.path:command_name"
self.lazy_subcommands = lazy_subcommands or {}
def get_command(self, ctx, cmd_name):
if cmd_name in self.lazy_subcommands:
# Import only when command is actually invoked
import_path = self.lazy_subcommands[cmd_name]
module_name, attr_name = import_path.rsplit(":", 1)
mod = importlib.import_module(module_name)
return getattr(mod, attr_name)
```
### 3. Hierarchical Command Structure
**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`.
**Rationale**:
- Provides a consistent, predictable interface
- Groups related functionality together
- Makes the CLI extensible for future commands
- Follows common CLI patterns (like `git`, `docker`, etc.)
**Examples**:
- `tinker version` - Show CLI and SDK version
- `tinker run list` - List all training runs
- `tinker run info <run-id>` - Show details of a specific run
- `tinker checkpoint list` - List all checkpoints
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
### 4. Output System with Inheritance
**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class.
**Rationale**:
- Enforces consistent interface across all commands
- Encapsulates output logic with the command that generates it
- Makes it easy to support multiple output formats (table, JSON)
- Keeps related code together in the same module
**Implementation**:
- `OutputBase` in `output.py` defines the contract
- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`)
- Base class handles format selection and rendering
### 5. Self-Contained Command Modules
**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point.
**Rationale**:
- Modular architecture - commands can be developed independently
- Clear separation of concerns
- Easy to add new commands without modifying core files
- Consistent pattern across all commands
**Command Structure**:
```python
# Each command file follows this pattern:
@click.group() # or @click.command() for simple commands
def cli():
"""Command description."""
pass
@cli.command() # For subcommands
def list():
"""Subcommand implementation."""
pass
```
### 6. Centralized Client Management
**Decision**: All SDK client creation and error handling is centralized in `client.py`.
**Rationale**:
- Single place to handle authentication and connection errors
- Consistent error messages across all commands
- Reusable error handling decorator
- Clean separation of concerns
### 7. Rich Tables for Human-Readable Output
**Decision**: Use the `rich` library for table formatting, kept as an optional dependency.
**Rationale**:
- Provides beautiful, formatted tables with colors and borders
- Handles column width adjustment automatically
- Supports both dark and light terminal themes
- Optional dependency keeps the core package lightweight
### 8. Unix-Style Default Output
**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output.
**Rationale**:
- Follows Unix philosophy
- Tables are better for human consumption
- JSON is better for scripting and automation
- Single flag switches between formats consistently
## Performance Optimizations
1. **LazyGroup for deferred imports** - Commands only loaded when invoked
2. **No heavy imports at module level** - Only Click imported initially
3. **Lazy loading** of all SDK dependencies
4. **Progress indicators** that clear themselves
5. **Efficient data fetching** - fetch all data by default instead of pagination
## Error Handling Strategy
1. **User-friendly messages** - Technical errors are translated to helpful messages
2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes
3. **Graceful degradation** - Continue operation when possible
4. **Detailed error info** - Show details when available, traceback only in TTY
### TinkerCliError Exception Pattern
All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`:
```python
from ..exceptions import TinkerCliError
# Instead of:
print(f"Error: Something went wrong", file=sys.stderr)
sys.exit(1)
# Use:
raise TinkerCliError(
"Something went wrong",
"Optional details or help text",
exit_code=1 # Optional, defaults to 1
)
```
**Benefits:**
- Better testability (can catch exceptions in tests)
- Centralized error formatting in `__main__.py`
- Consistent exit codes across the CLI
- Stack traces preserved for debugging
**Important Notes:**
- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification
- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages
- The main error handler in `__main__.py` handles printing to stderr and exiting
## Future Extensibility
The architecture supports easy addition of:
### New Commands
- Create new module in `commands/` directory
- Define output classes in the same module if needed
- Add command to lazy_subcommands in `__init__.py`
### New Subcommands
- Add new Click command decorator to existing command module
- Define corresponding output class if needed
- Subcommands automatically discovered by Click
### New Output Formats
- Override `print()` method in `OutputBase`
- Or add new format handling to base class
## Testing Guidelines
1. **Startup time**: `time tinker --help` should be <50ms
2. **Import verification**: Check that modules aren't imported unnecessarily
3. **Output formats**: Test both table and JSON output
4. **Error cases**: Test with missing auth, invalid IDs, network errors
5. **Empty results**: Ensure graceful handling of no data
## Module Structure
```
cli/
├── __init__.py # Main entry with LazyGroup configuration
├── __main__.py # Module execution support
├── lazy_group.py # LazyGroup implementation for lazy loading
├── output.py # OutputBase class and formatting utilities
├── client.py # SDK client creation and error handling
├── commands/
│ ├── __init__.py # Command module marker
│ ├── version.py # Version command
│ ├── run.py # Run commands and output classes
│ └── checkpoint.py # Checkpoint commands and output classes
└── CLAUDE.md # This documentation
```
## Command Examples
```bash
# Show version
tinker version
# List all training runs
tinker run list
# Show run details
tinker run info run-abc123
# List all checkpoints
tinker checkpoint list
# List checkpoints for specific run
tinker checkpoint list run-abc123
# Show checkpoint details
tinker checkpoint info ckpt-xyz789
# JSON output
tinker --format json run list
tinker --format json checkpoint list
```
## Dependencies
### Required
- Python 3.11+
- tinker SDK (main package)
- click>=8.0.0 (CLI framework)
### Optional
- `rich` - For table formatting (installed with `pip install tinker[cli]`)
## Migration from Argparse to Click
### Key Changes:
1. **Command Definition**: Decorators instead of `parser.add_argument()`
2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch
3. **Context Passing**: Click's context system for sharing format option
4. **Error Handling**: Click handles exits and error formatting
5. **Help Generation**: Automatic from docstrings and decorators
### Benefits:
- Cleaner, more Pythonic code
- Better command organization
- Built-in testing utilities
- Easier to extend with plugins
- More consistent behavior across commands
## Maintenance Notes
1. **Keep imports lazy** - Use LazyGroup for all commands
2. **Test startup time** - Regularly verify fast startup is maintained
3. **Follow Click patterns** - Use decorators and context properly
4. **Document changes** - Update this file when making architectural changes
5. **Maintain consistency** - All commands should follow the same structure

View file

@ -0,0 +1,60 @@
"""Tinker CLI - Command-line interface for the Tinker SDK.
This module implements lazy loading to ensure fast startup times.
Only Click is imported at the module level. All other dependencies
including command modules are imported on-demand.
Enable execution of the CLI via python -m tinker.cli
"""
import sys
import click
from .lazy_group import LazyGroup
from .context import CLIContext
from .exceptions import TinkerCliError
@click.group(
cls=LazyGroup,
lazy_subcommands={
"checkpoint": "tinker.cli.commands.checkpoint:cli",
"run": "tinker.cli.commands.run:cli",
"version": "tinker.cli.commands.version:cli",
},
context_settings=dict(help_option_names=["-h", "--help"]),
)
@click.option(
"--format",
"-f",
type=click.Choice(["table", "json"]),
default="table",
help="Output format (default: table)",
)
@click.pass_context
def main_cli(ctx: click.Context, format: str) -> None:
"""Tinker management CLI."""
# Store format in context for subcommands to access
ctx.obj = CLIContext(format=format) # type: ignore[assignment]
def main():
try:
main_cli()
except TinkerCliError as e:
# Print error message to stderr
if e.message:
print(f"Error: {e.message}", file=sys.stderr)
if e.details:
print(e.details, file=sys.stderr)
sys.exit(e.exit_code)
except KeyboardInterrupt:
# Standard Unix exit code for Ctrl+C
sys.exit(130)
# Make main available for entry point
cli = main
if __name__ == "__main__":
main()

157
src/tinker/cli/client.py Normal file
View file

@ -0,0 +1,157 @@
"""Client utilities for tinker CLI - handles SDK client creation and configuration.
This module provides functions for creating and configuring the Tinker SDK
client, with proper error handling for common issues like authentication
and network errors.
"""
import sys
from functools import wraps
from typing import TypeVar, Callable, Any, cast, TYPE_CHECKING
from .exceptions import TinkerCliError
if TYPE_CHECKING:
from tinker.lib.public_interfaces.rest_client import RestClient
def create_rest_client() -> "RestClient":
"""Create and configure a RestClient instance with proper error handling.
This function handles the creation of the ServiceClient and RestClient,
with appropriate error messages for common failure cases.
Returns:
A configured RestClient instance
Raises:
TinkerCliError: If client creation fails
"""
# Lazy import to avoid slow startup
from tinker import ServiceClient
try:
service_client = ServiceClient()
return service_client.create_rest_client()
except ImportError as e:
raise TinkerCliError(
f"Failed to import Tinker SDK: {e}",
"Please ensure the tinker package is properly installed.",
)
except ValueError as e:
# Often indicates missing or invalid API key
raise TinkerCliError(
f"Configuration error: {e}", "Please check your Tinker API credentials."
)
except Exception as e:
# Catch-all for other errors
raise TinkerCliError(
f"Failed to connect to Tinker API: {e}",
"Please check your network connection and API configuration.",
)
# Type variable for decorator
F = TypeVar("F", bound=Callable[..., Any])
def handle_api_errors(func: F) -> F:
"""Decorator for handling common API errors.
This decorator catches common exceptions from the Tinker API
and provides user-friendly error messages.
Args:
func: Function to wrap with error handling
Returns:
Wrapped function with error handling
"""
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
# Lazy import to avoid slow startup
from tinker._exceptions import (
APIError,
BadRequestError,
AuthenticationError,
PermissionDeniedError,
NotFoundError,
UnprocessableEntityError,
RateLimitError,
InternalServerError,
APIStatusError,
APIConnectionError,
APITimeoutError,
)
try:
return func(*args, **kwargs)
except NotFoundError as e:
details = f"Details: {e.message}" if hasattr(e, "message") else None
raise TinkerCliError("Resource not found", details)
except AuthenticationError as e:
details = "Please check your API key or authentication credentials."
if hasattr(e, "message"):
details += f"\nDetails: {e.message}"
raise TinkerCliError("Authentication failed", details)
except PermissionDeniedError as e:
details = "You don't have permission to access this resource."
if hasattr(e, "message"):
details += f"\nDetails: {e.message}"
raise TinkerCliError("Permission denied", details)
except BadRequestError as e:
details = f"Details: {e.message}" if hasattr(e, "message") else None
raise TinkerCliError("Invalid request", details)
except UnprocessableEntityError as e:
details = f"Details: {e.message}" if hasattr(e, "message") else None
raise TinkerCliError("Invalid data provided", details)
except RateLimitError as e:
details = "Please wait a moment before trying again."
if hasattr(e, "message"):
details += f"\nDetails: {e.message}"
raise TinkerCliError("Rate limit exceeded", details)
except InternalServerError as e:
details = "The Tinker API encountered an internal error. Please try again later."
if hasattr(e, "message"):
details += f"\nDetails: {e.message}"
raise TinkerCliError("Internal server error", details)
except APITimeoutError as e:
details = "The request to Tinker API timed out. Please try again."
if hasattr(e, "message"):
details += f"\nDetails: {e.message}"
raise TinkerCliError("Request timeout", details)
except APIConnectionError as e:
details = "Could not connect to the Tinker API. Please check your network connection."
if hasattr(e, "message"):
details += f"\nDetails: {e.message}"
raise TinkerCliError("Connection failed", details)
except APIStatusError as e:
status = e.status_code if hasattr(e, "status_code") else "unknown"
details = f"Details: {e.message}" if hasattr(e, "message") else None
raise TinkerCliError(f"API error (status {status})", details)
except APIError as e:
# Generic API error
details = f"Details: {e.message}" if hasattr(e, "message") else None
raise TinkerCliError("API error occurred", details)
except TinkerCliError:
# Re-raise our own errors without modification
raise
except KeyboardInterrupt:
# Re-raise keyboard interrupt to be handled by main
raise
except Exception as e:
# Catch-all for unexpected errors
import traceback
details = None
if sys.stderr.isatty():
# Only include traceback if stderr is a terminal (for debugging)
import io
tb_str = io.StringIO()
traceback.print_exc(file=tb_str)
details = tb_str.getvalue()
raise TinkerCliError(f"Unexpected error occurred: {e}", details)
return cast(F, wrapper)

View file

@ -0,0 +1 @@
"""Command modules for the Tinker CLI."""

View file

@ -0,0 +1,414 @@
"""Commands for managing checkpoints.
This module implements the 'tinker checkpoint' commands, including:
- list: List all checkpoints or checkpoints for a specific run
- info: Show details of a specific checkpoint
"""
from typing import TYPE_CHECKING, Any, Dict, List
import click
if TYPE_CHECKING:
from tinker.types import Checkpoint
from ..client import create_rest_client, handle_api_errors
from ..context import CLIContext
from ..exceptions import TinkerCliError
from ..output import OutputBase, format_bool, format_size, format_timestamp
class CheckpointListOutput(OutputBase):
"""Output for 'tinker checkpoint list' command."""
def __init__(
self,
checkpoints: List["Checkpoint"],
run_id: str | None = None,
total_count: int | None = None,
shown_count: int | None = None,
):
"""Initialize with list of checkpoints.
Args:
checkpoints: List of Checkpoint objects
run_id: Optional training run ID if filtering by run
total_count: Total number of checkpoints available
shown_count: Number of checkpoints shown in this response
"""
self.checkpoints = checkpoints
self.run_id = run_id
self.total_count = total_count
self.shown_count = shown_count if shown_count is not None else len(checkpoints)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON output."""
result = {}
if self.run_id:
result["run_id"] = self.run_id
# Check if these are Pydantic models
if self.checkpoints and hasattr(self.checkpoints[0], "model_dump"):
result["checkpoints"] = [c.model_dump() for c in self.checkpoints]
else:
result["checkpoints"] = [dict(c) for c in self.checkpoints]
return result
def get_title(self) -> str | None:
"""Return title for table output."""
count = len(self.checkpoints)
if self.run_id:
if count == 0:
return f"No checkpoints found for run {self.run_id}"
elif count == 1:
title = f"1 checkpoint for run {self.run_id}"
else:
title = f"{count} checkpoints for run {self.run_id}"
else:
if count == 0:
return "No checkpoints found"
elif count == 1:
title = "1 checkpoint"
else:
title = f"{count} checkpoints"
# Add information about remaining checkpoints if available
if self.total_count is not None and self.total_count > self.shown_count:
remaining = self.total_count - self.shown_count
if remaining == 1:
title += " (1 more not shown, use --limit to see more)"
else:
title += f" ({remaining} more not shown, use --limit to see more)"
return title
def get_table_columns(self) -> List[str]:
"""Return column headers for table output."""
return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Path"]
def get_table_rows(self) -> List[List[str]]:
"""Return rows for table output."""
rows = []
for ckpt in self.checkpoints:
rows.append(
[
ckpt.checkpoint_id,
ckpt.checkpoint_type,
format_size(ckpt.size_bytes) if hasattr(ckpt, "size_bytes") else "N/A",
format_bool(ckpt.public),
format_timestamp(ckpt.time),
ckpt.tinker_path,
]
)
return rows
class CheckpointInfoOutput(OutputBase):
"""Output for 'tinker checkpoint info' command."""
def __init__(self, checkpoint: "Checkpoint"):
"""Initialize with a single checkpoint.
Args:
checkpoint: Checkpoint object
"""
self.checkpoint = checkpoint
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON output."""
if hasattr(self.checkpoint, "model_dump"):
return self.checkpoint.model_dump()
return dict(self.checkpoint)
def get_title(self) -> str | None:
"""Return title for table output."""
return f"Checkpoint: {self.checkpoint.checkpoint_id}"
def get_table_columns(self) -> List[str]:
"""Return column headers for table output."""
return ["Property", "Value"]
def get_table_rows(self) -> List[List[str]]:
"""Return rows for table output."""
rows = [
["Checkpoint ID", self.checkpoint.checkpoint_id],
["Type", self.checkpoint.checkpoint_type],
["Tinker Path", self.checkpoint.tinker_path],
]
# Size if available
if hasattr(self.checkpoint, "size_bytes"):
rows.append(["Size", format_size(self.checkpoint.size_bytes)])
# Public status
rows.append(["Public", format_bool(self.checkpoint.public)])
# Creation time
rows.append(["Created", format_timestamp(self.checkpoint.time)])
# Parse training run ID from path
if self.checkpoint.tinker_path.startswith("tinker://"):
parts = self.checkpoint.tinker_path.replace("tinker://", "").split("/")
if parts:
rows.append(["Training Run ID", parts[0]])
return rows
def get_checkpoint_from_path(checkpoint_path: str) -> "Checkpoint":
"""Get checkpoint details from a tinker path.
Args:
checkpoint_path: A tinker path like "tinker://run-id/weights/0001"
Returns:
Checkpoint object
Raises:
TinkerCliError: If the checkpoint cannot be retrieved
"""
# Lazy import
from tinker import ParsedCheckpointTinkerPath
try:
parsed = ParsedCheckpointTinkerPath.from_tinker_path(checkpoint_path)
client = create_rest_client()
# Get the checkpoint info
checkpoints_response = client.list_checkpoints(parsed.training_run_id).result()
# Find the matching checkpoint
for ckpt in checkpoints_response.checkpoints:
if ckpt.tinker_path == checkpoint_path:
return ckpt
raise TinkerCliError(f"Checkpoint not found: {checkpoint_path}")
except ValueError as e:
raise TinkerCliError(
f"Invalid checkpoint path: {e}",
"Checkpoint paths should be in the format: tinker://run-id/weights/0001",
)
except TinkerCliError:
# Re-raise our own errors
raise
except Exception as e:
raise TinkerCliError(f"Failed to retrieve checkpoint: {e}")
# Click command group for checkpoint commands
@click.group()
def cli():
"""Manage checkpoints."""
pass
@cli.command()
@click.option("--run-id", help="Training run ID")
@click.option(
"--limit",
type=int,
default=20,
help="Maximum number of checkpoints to display when listing from all runs (default: 20, use --limit=0 to show all)",
)
@click.pass_obj
@handle_api_errors
def list(cli_context: CLIContext, run_id: str | None, limit: int) -> None:
"""List checkpoints.
If --run-id is provided, list checkpoints for that specific training run.
Otherwise, list checkpoints from all recent runs.
"""
# Get format from context object
format = cli_context.format
# Create client
client = create_rest_client()
if run_id:
# List checkpoints for specific run.
# Note that there's no pagination for listing checkpoints on a single training run.
response = client.list_checkpoints(run_id).result()
# Create output object
output = CheckpointListOutput(checkpoints=response.checkpoints, run_id=run_id)
else:
# List checkpoints from all user's training runs using list_user_checkpoints()
all_checkpoints = []
offset = 0
# Fetch in batches of 1000 since the queries are so slow
BATCH_SIZE = 1000
# First fetch to get initial data and total count
first_response = client.list_user_checkpoints(
limit=min(BATCH_SIZE, limit) if limit > 0 else BATCH_SIZE, offset=0
).result()
all_checkpoints.extend(first_response.checkpoints)
total_count = (
first_response.cursor.total_count
if first_response.cursor
else len(first_response.checkpoints)
)
# Determine target count: either user-specified limit or total available
target_count = limit if limit > 0 else total_count
target_count = min(target_count, total_count) # Can't fetch more than exists
# If we need to fetch more checkpoints, paginate with a progress bar
if len(all_checkpoints) < target_count:
with click.progressbar(
length=target_count,
label=f"Fetching {'all' if limit == 0 else str(target_count)} checkpoints",
show_percent=True,
show_pos=True,
show_eta=True,
) as bar:
bar.update(len(all_checkpoints))
# Fetch remaining checkpoints in batches
while len(all_checkpoints) < target_count:
offset = len(all_checkpoints)
remaining = target_count - len(all_checkpoints)
next_batch_size = min(BATCH_SIZE, remaining)
response = client.list_user_checkpoints(
limit=next_batch_size, offset=offset
).result()
all_checkpoints.extend(response.checkpoints)
bar.update(len(response.checkpoints))
# Break if we got fewer than requested (reached the end)
if len(response.checkpoints) < next_batch_size:
break
# Create output object with pagination information
output = CheckpointListOutput(
checkpoints=all_checkpoints, total_count=total_count, shown_count=len(all_checkpoints)
)
# Print in requested format
output.print(format=format)
@cli.command()
@click.argument("checkpoint_path")
@click.pass_obj
@handle_api_errors
def info(cli_context: CLIContext, checkpoint_path: str) -> None:
"""Show details of a specific checkpoint.
CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
"""
# Get format from context object
format = cli_context.format
# Validate it's a tinker path
if not checkpoint_path.startswith("tinker://"):
raise TinkerCliError(
f"Invalid checkpoint path: {checkpoint_path}",
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
)
checkpoint = get_checkpoint_from_path(checkpoint_path)
# Create output object
output = CheckpointInfoOutput(checkpoint=checkpoint)
# Print in requested format
output.print(format=format)
@cli.command()
@click.argument("checkpoint_path")
@click.pass_obj
@handle_api_errors
def publish(cli_context: CLIContext, checkpoint_path: str) -> None:
"""Publish a checkpoint to make it publicly accessible.
CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
Only the owner of the training run can publish checkpoints.
"""
# Validate it's a tinker path
if not checkpoint_path.startswith("tinker://"):
raise TinkerCliError(
f"Invalid checkpoint path: {checkpoint_path}",
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
)
# Create client and publish
client = create_rest_client()
client.publish_checkpoint_from_tinker_path(checkpoint_path).result()
@cli.command()
@click.argument("checkpoint_path")
@click.pass_obj
@handle_api_errors
def unpublish(cli_context: CLIContext, checkpoint_path: str) -> None:
"""Unpublish a checkpoint to make it private again.
CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
Only the owner of the training run can unpublish checkpoints.
"""
# Validate it's a tinker path
if not checkpoint_path.startswith("tinker://"):
raise TinkerCliError(
f"Invalid checkpoint path: {checkpoint_path}",
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
)
# Create client and unpublish
client = create_rest_client()
client.unpublish_checkpoint_from_tinker_path(checkpoint_path).result()
@cli.command()
@click.argument("checkpoint_path")
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
@click.pass_obj
@handle_api_errors
def delete(cli_context: CLIContext, checkpoint_path: str, yes: bool) -> None:
"""Delete a checkpoint permanently.
CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
Only the owner of the training run can delete checkpoints.
WARNING: This action is permanent and cannot be undone.
"""
# Validate it's a tinker path
if not checkpoint_path.startswith("tinker://"):
raise TinkerCliError(
f"Invalid checkpoint path: {checkpoint_path}",
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
)
# Get format from context object
format = cli_context.format
# If not using --yes, show checkpoint info and prompt for confirmation
if not yes:
try:
checkpoint = get_checkpoint_from_path(checkpoint_path)
# Display checkpoint info using the same format as 'info' command
output = CheckpointInfoOutput(checkpoint)
output.print(format=format)
click.echo()
except TinkerCliError:
# If we can't get checkpoint info, still allow deletion attempt
# The API will return appropriate error if checkpoint doesn't exist
click.echo(f"Checkpoint path: {checkpoint_path}")
click.echo()
# Confirmation prompt
click.echo("WARNING: This action is permanent and cannot be undone.")
if not click.confirm("Are you sure you want to delete this checkpoint?"):
click.echo("Deletion cancelled.")
return
# Create client and delete
client = create_rest_client()
client.delete_checkpoint_from_tinker_path(checkpoint_path).result()

View file

@ -0,0 +1,257 @@
"""Commands for managing training runs.
This module implements the 'tinker run' commands, including:
- list: List all training runs
- info: Show details of a specific run
"""
from typing import Any, Dict, List, TYPE_CHECKING
import click
if TYPE_CHECKING:
from tinker.types import TrainingRun
from ..client import create_rest_client, handle_api_errors
from ..context import CLIContext
from ..output import OutputBase, format_timestamp
class RunListOutput(OutputBase):
"""Output for 'tinker run list' command."""
def __init__(
self,
runs: List["TrainingRun"],
total_count: int | None = None,
shown_count: int | None = None,
):
"""Initialize with list of training runs.
Args:
runs: List of TrainingRun objects
total_count: Total number of runs available (from cursor)
shown_count: Number of runs shown in this response
"""
self.runs = runs
self.total_count = total_count
self.shown_count = shown_count if shown_count is not None else len(runs)
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON output."""
# Check if these are Pydantic models
if self.runs and hasattr(self.runs[0], "model_dump"):
return {"runs": [run.model_dump() for run in self.runs]}
return {"runs": [dict(run) for run in self.runs]}
def get_title(self) -> str | None:
"""Return title for table output."""
count = len(self.runs)
if count == 0:
return "No training runs found"
# Build the base title
if count == 1:
title = "1 training run"
else:
title = f"{count} training runs"
# Add information about remaining runs if available
if self.total_count is not None and self.total_count > self.shown_count:
remaining = self.total_count - self.shown_count
if remaining == 1:
title += f" (1 more not shown, use --limit to see more)"
else:
title += f" ({remaining} more not shown, use --limit to see more)"
return title
def get_table_columns(self) -> List[str]:
"""Return column headers for table output."""
return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Corrupted"]
def get_table_rows(self) -> List[List[str]]:
"""Return rows for table output."""
rows = []
for run in self.runs:
# Format LoRA information
if run.is_lora and run.lora_rank:
lora_info = f"Rank {run.lora_rank}"
elif run.is_lora:
lora_info = "Yes"
else:
lora_info = "No"
rows.append(
[
run.training_run_id,
run.base_model,
run.model_owner,
lora_info,
format_timestamp(run.last_request_time),
str(run.corrupted),
]
)
return rows
class RunInfoOutput(OutputBase):
"""Output for 'tinker run info' command."""
def __init__(self, run: "TrainingRun"):
"""Initialize with a single training run.
Args:
run: TrainingRun object
"""
self.run = run
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON output."""
if hasattr(self.run, "model_dump"):
return self.run.model_dump()
return dict(self.run)
def get_title(self) -> str | None:
"""Return title for table output."""
return f"Training Run: {self.run.training_run_id}"
def get_table_columns(self) -> List[str]:
"""Return column headers for table output."""
return ["Property", "Value"]
def get_table_rows(self) -> List[List[str]]:
"""Return rows for table output."""
rows = [
["Run ID", self.run.training_run_id],
["Base Model", self.run.base_model],
["Owner", self.run.model_owner],
]
# LoRA information
if self.run.is_lora:
if self.run.lora_rank:
rows.append(["LoRA", f"Yes (Rank {self.run.lora_rank})"])
else:
rows.append(["LoRA", "Yes"])
else:
rows.append(["LoRA", "No"])
# Last update time
rows.append(["Last Update", format_timestamp(self.run.last_request_time)])
# Corruption status
rows.append(["Status", "Corrupted" if self.run.corrupted else "Active"])
# Last checkpoints
if self.run.last_checkpoint:
rows.append(["Last Training Checkpoint", self.run.last_checkpoint.checkpoint_id])
rows.append([" - Time", format_timestamp(self.run.last_checkpoint.time)])
rows.append([" - Path", self.run.last_checkpoint.tinker_path])
if self.run.last_sampler_checkpoint:
rows.append(["Last Sampler Checkpoint", self.run.last_sampler_checkpoint.checkpoint_id])
rows.append([" - Time", format_timestamp(self.run.last_sampler_checkpoint.time)])
rows.append([" - Path", self.run.last_sampler_checkpoint.tinker_path])
# User metadata if present
if self.run.user_metadata:
rows.append(["Metadata", ""])
for key, value in self.run.user_metadata.items():
rows.append([f" - {key}", value])
return rows
# Click command group for run commands
@click.group()
def cli():
"""Manage training runs."""
pass
@cli.command()
@click.option(
"--limit",
type=int,
default=20,
help="Maximum number of runs to fetch (default: 20, use --limit=0 to fetch all)",
)
@click.pass_obj
@handle_api_errors
def list(cli_context: CLIContext, limit: int) -> None:
"""List all training runs."""
# Get format from context object
format = cli_context.format
# Create client
client = create_rest_client()
all_runs = []
offset = 0
batch_size = 100 # Fetch in batches of 100 for efficiency
# First fetch to get initial data and total count
first_response = client.list_training_runs(
limit=min(batch_size, limit) if limit > 0 else batch_size, offset=0
).result()
all_runs.extend(first_response.training_runs)
total_count = first_response.cursor.total_count
# Determine target count: either user-specified limit or total available
target_count = limit if limit > 0 else total_count
target_count = min(target_count, total_count) # Can't fetch more than exists
# If we need to fetch more runs, paginate with a progress bar
if len(all_runs) < target_count:
with click.progressbar(
length=target_count,
label=f"Fetching {'all' if limit == 0 else str(target_count)} training runs",
show_percent=True,
show_pos=True,
show_eta=True,
) as bar:
bar.update(len(all_runs))
# Fetch remaining runs in batches
while len(all_runs) < target_count:
offset = len(all_runs)
remaining = target_count - len(all_runs)
next_batch_size = min(batch_size, remaining)
response = client.list_training_runs(limit=next_batch_size, offset=offset).result()
all_runs.extend(response.training_runs)
bar.update(len(response.training_runs))
# Break if we got fewer than requested (reached the end)
if len(response.training_runs) < next_batch_size:
break
# Create output object with pagination information
output = RunListOutput(runs=all_runs, total_count=total_count, shown_count=len(all_runs))
# Print in requested format
output.print(format=format)
@cli.command()
@click.argument("run_id")
@click.pass_obj
@handle_api_errors
def info(cli_context: CLIContext, run_id: str) -> None:
"""Show details of a specific run."""
# Get format from context object
format = cli_context.format
# Create client
client = create_rest_client()
# Fetch training run details
response = client.get_training_run(run_id).result()
# Create output object
output = RunInfoOutput(run=response)
# Print in requested format
output.print(format=format)

View file

@ -0,0 +1,18 @@
"""Command for showing version information.
This module implements the 'tinker version' command.
"""
import click
@click.command()
def cli():
"""Show version information."""
try:
# Lazy import version only when needed
from tinker._version import __version__
click.echo(f"tinker {__version__}")
except ImportError:
click.echo("tinker (version unavailable)")

23
src/tinker/cli/context.py Normal file
View file

@ -0,0 +1,23 @@
"""Context object for the Tinker CLI.
This module provides a dataclass for sharing configuration and state
between CLI commands.
"""
from dataclasses import dataclass
from typing import Literal
@dataclass
class CLIContext:
"""Context object for sharing state between CLI commands.
This dataclass is passed through the Click command hierarchy
using the @click.pass_obj decorator, allowing commands to access
shared configuration without needing to traverse the context tree.
Attributes:
format: Output format for command results ('table' or 'json')
"""
format: Literal["table", "json"] = "table"

View file

@ -0,0 +1,31 @@
"""Custom exceptions for the Tinker CLI.
This module defines exceptions used throughout the CLI for consistent
error handling and graceful exits.
"""
class TinkerCliError(Exception):
"""Custom exception for CLI errors that should exit gracefully.
This exception is caught at the top level of the CLI and converted
to appropriate error messages and exit codes.
Attributes:
message: The main error message to display
details: Optional additional details or suggestions
exit_code: The exit code to use (default: 1)
"""
def __init__(self, message: str, details: str | None = None, exit_code: int = 1):
"""Initialize a TinkerCliError.
Args:
message: The main error message (will be prefixed with "Error: ")
details: Optional additional details or help text
exit_code: The exit code to use when exiting (default: 1)
"""
self.message = message
self.details = details
self.exit_code = exit_code
super().__init__(message)

View file

@ -0,0 +1,89 @@
"""Lazy loading support for Click commands.
This module provides a LazyGroup class that extends Click's Group to support
lazy loading of subcommands, ensuring fast CLI startup times.
"""
import importlib
from typing import Any, Dict, List
import click
class LazyGroup(click.Group):
"""A Click Group that supports lazy loading of subcommands.
This allows the CLI to have fast startup times by only importing
command modules when they are actually invoked, not when the CLI
is first loaded or when help is displayed.
"""
def __init__(
self, *args: Any, lazy_subcommands: Dict[str, str] | None = None, **kwargs: Any
) -> None:
"""Initialize the LazyGroup.
Args:
lazy_subcommands: A dictionary mapping command names to import paths.
Format: {"command": "module.path:attribute_name"}
"""
super().__init__(*args, **kwargs)
self.lazy_subcommands = lazy_subcommands or {}
def list_commands(self, ctx: click.Context) -> List[str]:
"""Return a list of all command names.
This includes both eagerly loaded commands and lazy commands.
"""
# Get any eagerly loaded commands
base = super().list_commands(ctx)
# Add lazy command names
lazy = sorted(self.lazy_subcommands.keys())
return base + lazy
def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
"""Get a command by name, loading it lazily if necessary.
Args:
ctx: The Click context
cmd_name: The name of the command to retrieve
Returns:
The Click command object, or None if not found
"""
# Check if it's a lazy command
if cmd_name in self.lazy_subcommands:
return self._lazy_load(cmd_name)
# Fall back to normal command loading
return super().get_command(ctx, cmd_name)
def _lazy_load(self, cmd_name: str) -> click.Command:
"""Lazily load a command by importing its module.
Args:
cmd_name: The name of the command to load
Returns:
The loaded Click command object
Raises:
ValueError: If the imported object is not a Click Command
"""
# Get the import path for this command
import_path = self.lazy_subcommands[cmd_name]
# Split into module path and attribute name
module_name, attr_name = import_path.rsplit(":", 1)
# Import the module
mod = importlib.import_module(module_name)
# Get the command object
cmd_object = getattr(mod, attr_name)
# Verify it's a Click command
if not isinstance(cmd_object, click.Command):
raise ValueError(
f"Lazy loading of {import_path} failed: '{attr_name}' is not a Click Command"
)
return cmd_object

227
src/tinker/cli/output.py Normal file
View file

@ -0,0 +1,227 @@
"""Output formatting utilities for the Tinker CLI.
This module provides a base class for structured output that can be
rendered as either a table (using rich) or as JSON. Each command
defines its own output class that inherits from OutputBase.
"""
import sys
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Callable, Union
class OutputBase(ABC):
"""Virtual base class for all command outputs.
Subclasses must implement methods to convert data to various formats.
The base class provides the common print() method that handles
format selection and rendering.
"""
@abstractmethod
def to_dict(self) -> Dict[str, Any]:
"""Convert output to dictionary for JSON serialization.
Returns:
Dictionary representation of the output data
"""
pass
@abstractmethod
def get_table_columns(self) -> List[str]:
"""Return list of column names for table output.
Returns:
List of column header strings
"""
pass
@abstractmethod
def get_table_rows(self) -> List[List[str]]:
"""Return list of rows for table output.
Each row should be a list of string values corresponding
to the columns returned by get_table_columns().
Returns:
List of rows, where each row is a list of string values
"""
pass
def get_title(self) -> str | None:
"""Optional title for the output display.
Override this method to provide a title for the table.
Returns:
Title string or None
"""
return None
def print(self, format: str = "table") -> None:
"""Print the output in the specified format.
Args:
format: Output format - either "table" or "json"
"""
if format == "json":
self._print_json()
else:
self._print_table()
def _print_json(self) -> None:
"""Print output as JSON."""
import json
data = self.to_dict()
json.dump(data, sys.stdout, indent=2, default=str)
print() # Add newline after JSON output
def _print_table(self) -> None:
"""Print output as a rich table."""
# Lazy import rich to avoid slow startup
from rich.console import Console
from rich.table import Table
console = Console()
# Create table with optional title
title = self.get_title()
table = Table(title=title) if title else Table()
# Add columns
columns = self.get_table_columns()
for col in columns:
# First column (usually ID) gets special styling
if col == columns[0]:
table.add_column(col, style="bright_cyan", no_wrap=True)
else:
table.add_column(col)
# Add rows
rows = self.get_table_rows()
for row in rows:
table.add_row(*row)
# Print the table
console.print(table)
# Utility formatting functions
def format_size(bytes: int) -> str:
"""Format bytes as human-readable size.
Args:
bytes: Size in bytes
Returns:
Human-readable size string (e.g., "1.2 GB")
"""
if bytes < 0:
return "N/A"
size = float(bytes)
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
if size < 1024.0:
if unit == "B":
return f"{int(size)} {unit}"
return f"{size:.1f} {unit}"
size /= 1024.0
return f"{size:.1f} EB"
def format_timestamp(dt: Union[datetime, str, None]) -> str:
"""Format datetime as relative time or absolute date.
Args:
dt: datetime object, ISO string, or None
Returns:
Formatted time string (e.g., "2 hours ago", "2024-01-15")
"""
if not dt:
return "N/A"
# Lazy import datetime
from datetime import datetime, timezone
# Handle different input types
if isinstance(dt, str):
# Try to parse ISO format string
try:
from datetime import datetime
dt = datetime.fromisoformat(dt.replace("Z", "+00:00"))
except (ValueError, AttributeError):
return str(dt)
if not hasattr(dt, "replace"):
# Not a datetime object
return str(dt)
try:
# Get current time
now = datetime.now(timezone.utc)
# Ensure dt has timezone info
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
else:
dt = dt.astimezone(timezone.utc)
# Calculate time difference
delta = now - dt
# Format based on age
if delta.days > 30:
return dt.strftime("%Y-%m-%d")
elif delta.days > 7:
weeks = delta.days // 7
return f"{weeks} week{'s' if weeks > 1 else ''} ago"
elif delta.days > 0:
return f"{delta.days} day{'s' if delta.days > 1 else ''} ago"
elif delta.seconds > 3600:
hours = delta.seconds // 3600
return f"{hours} hour{'s' if hours > 1 else ''} ago"
elif delta.seconds > 60:
minutes = delta.seconds // 60
return f"{minutes} minute{'s' if minutes > 1 else ''} ago"
else:
return "just now"
except Exception:
# If any error occurs, just return string representation
return str(dt)
def format_bool(value: bool) -> str:
"""Format boolean for display.
Args:
value: Boolean value
Returns:
"Yes" or "No"
"""
return "Yes" if value else "No"
def format_optional(value: Any, formatter: Callable[[Any], str] | None = None) -> str:
"""Format an optional value.
Args:
value: Value to format (may be None)
formatter: Optional formatting function to apply if value is not None
Returns:
Formatted string or "N/A" if value is None
"""
if value is None:
return "N/A"
if formatter:
return formatter(value)
return str(value)

View file

@ -31,9 +31,12 @@ class RestClient(TelemetryProvider):
Key methods: Key methods:
- list_checkpoints() - list available model checkpoints (both training and sampler) - list_checkpoints() - list available model checkpoints (both training and sampler)
- list_user_checkpoints() - list all checkpoints across all user's training runs
- get_training_run() - get model information and metadata as ModelEntry - get_training_run() - get model information and metadata as ModelEntry
- delete_checkpoint() - delete an existing checkpoint for a training run - delete_checkpoint() - delete an existing checkpoint for a training run
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive - get_checkpoint_archive_url() - get signed URL to download checkpoint archive
- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private
Args: Args:
holder: Internal client managing HTTP connections and async operations holder: Internal client managing HTTP connections and async operations
@ -420,3 +423,227 @@ class RestClient(TelemetryProvider):
return await self._get_checkpoint_archive_url_submit( return await self._get_checkpoint_archive_url_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
) )
def _publish_checkpoint_submit(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]:
"""Internal method to submit publish checkpoint request."""
async def _publish_checkpoint_async() -> None:
async def _send_request() -> None:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
await client.post(
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish",
cast_to=NoneType,
)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_publish_checkpoint_async())
@sync_only
@capture_exceptions(fatal=True)
def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
"""Publish a checkpoint referenced by a tinker path to make it publicly accessible.
Only the exact owner of the training run can publish checkpoints.
Published checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path method.
Args:
tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
Returns:
A Future that completes when the checkpoint is published
Raises:
HTTPException: 400 if checkpoint identifier is invalid
HTTPException: 404 if checkpoint not found or user doesn't own the training run
HTTPException: 409 if checkpoint is already public
HTTPException: 500 if there's an error publishing the checkpoint
Example:
>>> future = rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001")
>>> future.result() # Wait for completion
>>> print("Checkpoint published successfully")
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._publish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
).future()
@capture_exceptions(fatal=True)
async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of publish_checkpoint_from_tinker_path.
Only the exact owner of the training run can publish checkpoints.
Published checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path_async method.
Args:
tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
Raises:
HTTPException: 400 if checkpoint identifier is invalid
HTTPException: 404 if checkpoint not found or user doesn't own the training run
HTTPException: 409 if checkpoint is already public
HTTPException: 500 if there's an error publishing the checkpoint
Example:
>>> await rest_client.publish_checkpoint_from_tinker_path_async("tinker://run-id/weights/0001")
>>> print("Checkpoint published successfully")
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._publish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
)
def _unpublish_checkpoint_submit(
self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]:
"""Internal method to submit unpublish checkpoint request."""
async def _unpublish_checkpoint_async() -> None:
async def _send_request() -> None:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
await client.delete(
f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish",
cast_to=NoneType,
)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_unpublish_checkpoint_async())
@sync_only
@capture_exceptions(fatal=True)
def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
"""Unpublish a checkpoint referenced by a tinker path to make it private again.
Only the exact owner of the training run can unpublish checkpoints.
This reverses the effect of publishing a checkpoint.
Args:
tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
Returns:
A Future that completes when the checkpoint is unpublished
Raises:
HTTPException: 400 if checkpoint identifier is invalid
HTTPException: 404 if checkpoint not found or user doesn't own the training run
HTTPException: 409 if checkpoint is already private
HTTPException: 500 if there's an error unpublishing the checkpoint
Example:
>>> future = rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001")
>>> future.result() # Wait for completion
>>> print("Checkpoint unpublished successfully")
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
return self._unpublish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
).future()
@capture_exceptions(fatal=True)
async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
"""Async version of unpublish_checkpoint_from_tinker_path.
Only the exact owner of the training run can unpublish checkpoints.
This reverses the effect of publishing a checkpoint.
Args:
tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
Raises:
HTTPException: 400 if checkpoint identifier is invalid
HTTPException: 404 if checkpoint not found or user doesn't own the training run
HTTPException: 409 if checkpoint is already private
HTTPException: 500 if there's an error unpublishing the checkpoint
Example:
>>> await rest_client.unpublish_checkpoint_from_tinker_path_async("tinker://run-id/weights/0001")
>>> print("Checkpoint unpublished successfully")
"""
parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
await self._unpublish_checkpoint_submit(
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
)
def _list_user_checkpoints_submit(
self, limit: int = 100, offset: int = 0
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
"""Internal method to submit list user checkpoints request."""
async def _list_user_checkpoints_async() -> types.CheckpointsListResponse:
async def _send_request() -> types.CheckpointsListResponse:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
params: dict[str, object] = {"limit": limit, "offset": offset}
return await client.get(
"/api/v1/checkpoints",
options={"params": params},
cast_to=types.CheckpointsListResponse,
)
return await self.holder.execute_with_retries(_send_request)
return self.holder.run_coroutine_threadsafe(_list_user_checkpoints_async())
@sync_only
@capture_exceptions(fatal=True)
def list_user_checkpoints(
self, limit: int = 100, offset: int = 0
) -> ConcurrentFuture[types.CheckpointsListResponse]:
"""List all checkpoints for the current user across all their training runs.
This method retrieves checkpoints from all training runs owned by the authenticated user,
sorted by time (newest first). It supports pagination for efficiently handling large
numbers of checkpoints.
Args:
limit: Maximum number of checkpoints to return (default 100)
offset: Offset for pagination (default 0)
Returns:
A Future containing the CheckpointsListResponse with checkpoints and cursor info
Example:
>>> future = rest_client.list_user_checkpoints(limit=50)
>>> response = future.result()
>>> print(f"Found {len(response.checkpoints)} checkpoints")
>>> print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}")
>>> for checkpoint in response.checkpoints:
... print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}")
>>> # Get next page if there are more checkpoints
>>> if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count:
... next_page = rest_client.list_user_checkpoints(limit=50, offset=50)
"""
return self._list_user_checkpoints_submit(limit, offset).future()
@capture_exceptions(fatal=True)
async def list_user_checkpoints_async(
self, limit: int = 100, offset: int = 0
) -> types.CheckpointsListResponse:
"""Async version of list_user_checkpoints.
This method retrieves checkpoints from all training runs owned by the authenticated user,
sorted by time (newest first). It supports pagination for efficiently handling large
numbers of checkpoints.
Args:
limit: Maximum number of checkpoints to return (default 100)
offset: Offset for pagination (default 0)
Returns:
CheckpointsListResponse with checkpoints and cursor info
Example:
>>> response = await rest_client.list_user_checkpoints_async(limit=50)
>>> print(f"Found {len(response.checkpoints)} checkpoints")
>>> print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}")
>>> for checkpoint in response.checkpoints:
... print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}")
>>> # Get next page if there are more checkpoints
>>> if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count:
... next_page = await rest_client.list_user_checkpoints_async(limit=50, offset=50)
"""
return await self._list_user_checkpoints_submit(limit, offset)

View file

@ -593,6 +593,7 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre
assert model_name is not None, "This shouldn't happen: model_name is None" assert model_name is not None, "This shouldn't happen: model_name is None"
# Use tokenizer_id from get_info if available, otherwise fall back to heuristic logic # Use tokenizer_id from get_info if available, otherwise fall back to heuristic logic
kwargs = {}
tokenizer_id = info.model_data.tokenizer_id tokenizer_id = info.model_data.tokenizer_id
if tokenizer_id is None: if tokenizer_id is None:
# We generally adhere to the huggingface convention of "<org>/<model>" but # We generally adhere to the huggingface convention of "<org>/<model>" but
@ -608,4 +609,10 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre
else: else:
tokenizer_id = model_name tokenizer_id = model_name
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True) if tokenizer_id == "moonshotai/Kimi-K2-Thinking":
kwargs = {
"trust_remote_code": True,
"revision": "612681931a8c906ddb349f8ad0f582cb552189cd",
}
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs)

View file

@ -185,7 +185,7 @@ class AsyncWeightsResource(AsyncAPIResource):
if not model_id: if not model_id:
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}") raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
return await self._get( return await self._get(
f"/api/v1/models/{model_id}/checkpoints", f"/api/v1/training_runs/{model_id}/checkpoints",
options=make_request_options( options=make_request_options(
extra_headers=extra_headers, extra_headers=extra_headers,
extra_query=extra_query, extra_query=extra_query,

View file

@ -21,6 +21,12 @@ class Checkpoint(BaseModel):
tinker_path: str tinker_path: str
"""The tinker path to the checkpoint""" """The tinker path to the checkpoint"""
size_bytes: int | None = None
"""The size of the checkpoint in bytes"""
public: bool = False
"""Whether the checkpoint is publicly accessible"""
class ParsedCheckpointTinkerPath(BaseModel): class ParsedCheckpointTinkerPath(BaseModel):
tinker_path: str tinker_path: str

View file

@ -1,5 +1,6 @@
from .._models import BaseModel from .._models import BaseModel
from .checkpoint import Checkpoint from .checkpoint import Checkpoint
from .cursor import Cursor
__all__ = ["CheckpointsListResponse"] __all__ = ["CheckpointsListResponse"]
@ -7,3 +8,6 @@ __all__ = ["CheckpointsListResponse"]
class CheckpointsListResponse(BaseModel): class CheckpointsListResponse(BaseModel):
checkpoints: list[Checkpoint] checkpoints: list[Checkpoint]
"""List of available model checkpoints for the model""" """List of available model checkpoints for the model"""
cursor: Cursor | None = None
"""Pagination cursor information (None for unpaginated responses)"""