mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
604e00c700
commit
2a37c3afb4
18 changed files with 1811 additions and 4 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
{
|
{
|
||||||
"last_synced_sha": "d973e1d6eede81b853a167e8f999f43402e07c3a",
|
"last_synced_sha": "31ae0341ea2c6122c46c0597c5a862440a9cf31e",
|
||||||
"last_sync_time": "2025-11-11T05:56:15.874542"
|
"last_sync_time": "2025-11-18T22:06:38.289689"
|
||||||
}
|
}
|
||||||
|
|
@ -17,6 +17,8 @@ dependencies = [
|
||||||
"sniffio",
|
"sniffio",
|
||||||
"numpy",
|
"numpy",
|
||||||
"torch",
|
"torch",
|
||||||
|
"rich>=13.0.0",
|
||||||
|
"click>=8.0.0",
|
||||||
]
|
]
|
||||||
requires-python = ">= 3.11"
|
requires-python = ">= 3.11"
|
||||||
classifiers = [
|
classifiers = [
|
||||||
|
|
@ -36,6 +38,9 @@ classifiers = [
|
||||||
"License :: OSI Approved :: Apache Software License"
|
"License :: OSI Approved :: Apache Software License"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
tinker = "tinker.cli.__main__:cli"
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
Homepage = "https://thinkingmachines.ai/tinker"
|
Homepage = "https://thinkingmachines.ai/tinker"
|
||||||
Repository = "https://github.com/thinking-machines-lab/tinker"
|
Repository = "https://github.com/thinking-machines-lab/tinker"
|
||||||
|
|
|
||||||
281
src/tinker/cli/CLAUDE.md
Normal file
281
src/tinker/cli/CLAUDE.md
Normal file
|
|
@ -0,0 +1,281 @@
|
||||||
|
# Tinker CLI Design Documentation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance.
|
||||||
|
|
||||||
|
## Key Design Decisions
|
||||||
|
|
||||||
|
### 1. Lazy Import Strategy with Click
|
||||||
|
|
||||||
|
**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level.
|
||||||
|
|
||||||
|
**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands.
|
||||||
|
|
||||||
|
**Implementation**:
|
||||||
|
- The main module (`__main__.py`) imports only `sys`, `click`, and the local `lazy_group`, `context`, and `exceptions` modules at import time
|
||||||
|
- Command modules are loaded only when invoked via `LazyGroup`
|
||||||
|
- Output formatting imports `rich` only when table output is needed
|
||||||
|
- JSON module imported only when JSON output is requested
|
||||||
|
- Version information loaded from `_version.py` only when `tinker version` is used
|
||||||
|
|
||||||
|
### 2. Click Framework with LazyGroup
|
||||||
|
|
||||||
|
**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- Click provides cleaner command structure with decorators
|
||||||
|
- Better subcommand isolation - each command file is self-contained
|
||||||
|
- Automatic help generation with better formatting
|
||||||
|
- Built-in type conversion and validation
|
||||||
|
- LazyGroup enables fast startup by deferring imports
|
||||||
|
|
||||||
|
**LazyGroup Implementation**:
|
||||||
|
```python
|
||||||
|
class LazyGroup(click.Group):
|
||||||
|
def __init__(self, *args, lazy_subcommands=None, **kwargs):
|
||||||
|
# Map of command name to "module.path:command_name"
|
||||||
|
self.lazy_subcommands = lazy_subcommands or {}
|
||||||
|
|
||||||
|
def get_command(self, ctx, cmd_name):
|
||||||
|
if cmd_name in self.lazy_subcommands:
|
||||||
|
# Import only when command is actually invoked
|
||||||
|
import_path = self.lazy_subcommands[cmd_name]
|
||||||
|
module_name, attr_name = import_path.rsplit(":", 1)
|
||||||
|
mod = importlib.import_module(module_name)
|
||||||
|
return getattr(mod, attr_name)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Hierarchical Command Structure
|
||||||
|
|
||||||
|
**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- Provides a consistent, predictable interface
|
||||||
|
- Groups related functionality together
|
||||||
|
- Makes the CLI extensible for future commands
|
||||||
|
- Follows common CLI patterns (like `git`, `docker`, etc.)
|
||||||
|
|
||||||
|
**Examples**:
|
||||||
|
- `tinker version` - Show CLI and SDK version
|
||||||
|
- `tinker run list` - List all training runs
|
||||||
|
- `tinker run info <run-id>` - Show details of a specific run
|
||||||
|
- `tinker checkpoint list` - List all checkpoints
|
||||||
|
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
|
||||||
|
|
||||||
|
### 4. Output System with Inheritance
|
||||||
|
|
||||||
|
**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- Enforces consistent interface across all commands
|
||||||
|
- Encapsulates output logic with the command that generates it
|
||||||
|
- Makes it easy to support multiple output formats (table, JSON)
|
||||||
|
- Keeps related code together in the same module
|
||||||
|
|
||||||
|
**Implementation**:
|
||||||
|
- `OutputBase` in `output.py` defines the contract
|
||||||
|
- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`)
|
||||||
|
- Base class handles format selection and rendering
|
||||||
|
|
||||||
|
### 5. Self-Contained Command Modules
|
||||||
|
|
||||||
|
**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- Modular architecture - commands can be developed independently
|
||||||
|
- Clear separation of concerns
|
||||||
|
- Easy to add new commands without modifying core files
|
||||||
|
- Consistent pattern across all commands
|
||||||
|
|
||||||
|
**Command Structure**:
|
||||||
|
```python
|
||||||
|
# Each command file follows this pattern:
|
||||||
|
@click.group() # or @click.command() for simple commands
|
||||||
|
def cli():
|
||||||
|
"""Command description."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@cli.command() # For subcommands
|
||||||
|
def list():
|
||||||
|
"""Subcommand implementation."""
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6. Centralized Client Management
|
||||||
|
|
||||||
|
**Decision**: All SDK client creation and error handling is centralized in `client.py`.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- Single place to handle authentication and connection errors
|
||||||
|
- Consistent error messages across all commands
|
||||||
|
- Reusable error handling decorator
|
||||||
|
- Clean separation of concerns
|
||||||
|
|
||||||
|
### 7. Rich Tables for Human-Readable Output
|
||||||
|
|
||||||
|
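**Decision**: Use the `rich` library for table formatting. It is declared as a regular dependency (`rich>=13.0.0` in `pyproject.toml`) but is imported lazily, only when table output is rendered.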
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- Provides beautiful, formatted tables with colors and borders
|
||||||
|
- Handles column width adjustment automatically
|
||||||
|
- Supports both dark and light terminal themes
|
||||||
|
- Importing it lazily keeps CLI startup lightweight when table output is not needed
|
||||||
|
|
||||||
|
### 8. Unix-Style Default Output
|
||||||
|
|
||||||
|
**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- Follows Unix philosophy
|
||||||
|
- Tables are better for human consumption
|
||||||
|
- JSON is better for scripting and automation
|
||||||
|
- Single flag switches between formats consistently
|
||||||
|
|
||||||
|
## Performance Optimizations
|
||||||
|
|
||||||
|
1. **LazyGroup for deferred imports** - Commands only loaded when invoked
|
||||||
|
2. **No heavy imports at module level** - Only Click imported initially
|
||||||
|
3. **Lazy loading** of all SDK dependencies
|
||||||
|
4. **Progress indicators** that clear themselves
|
||||||
|
5. **Efficient data fetching** - listings are fetched in batches and capped by `--limit` (e.g. `tinker checkpoint list` shows 20 by default; `--limit=0` fetches everything)
|
||||||
|
|
||||||
|
## Error Handling Strategy
|
||||||
|
|
||||||
|
1. **User-friendly messages** - Technical errors are translated to helpful messages
|
||||||
|
2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes
|
||||||
|
3. **Graceful degradation** - Continue operation when possible
|
||||||
|
4. **Detailed error info** - Show details when available, traceback only in TTY
|
||||||
|
|
||||||
|
### TinkerCliError Exception Pattern
|
||||||
|
|
||||||
|
All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from ..exceptions import TinkerCliError
|
||||||
|
|
||||||
|
# Instead of:
|
||||||
|
print("Error: Something went wrong", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Use:
|
||||||
|
raise TinkerCliError(
|
||||||
|
"Something went wrong",
|
||||||
|
"Optional details or help text",
|
||||||
|
exit_code=1 # Optional, defaults to 1
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Benefits:**
|
||||||
|
- Better testability (can catch exceptions in tests)
|
||||||
|
- Centralized error formatting in `__main__.py`
|
||||||
|
- Consistent exit codes across the CLI
|
||||||
|
- Stack traces preserved for debugging
|
||||||
|
|
||||||
|
**Important Notes:**
|
||||||
|
- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification
|
||||||
|
- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages
|
||||||
|
- The main error handler in `__main__.py` handles printing to stderr and exiting
|
||||||
|
|
||||||
|
## Future Extensibility
|
||||||
|
|
||||||
|
The architecture supports easy addition of:
|
||||||
|
|
||||||
|
### New Commands
|
||||||
|
- Create new module in `commands/` directory
|
||||||
|
- Define output classes in the same module if needed
|
||||||
|
- Add command to `lazy_subcommands` in `__main__.py`
|
||||||
|
|
||||||
|
### New Subcommands
|
||||||
|
- Add new Click command decorator to existing command module
|
||||||
|
- Define corresponding output class if needed
|
||||||
|
- Subcommands automatically discovered by Click
|
||||||
|
|
||||||
|
### New Output Formats
|
||||||
|
- Override `print()` method in `OutputBase`
|
||||||
|
- Or add new format handling to base class
|
||||||
|
|
||||||
|
## Testing Guidelines
|
||||||
|
|
||||||
|
1. **Startup time**: `time tinker --help` should be <50ms
|
||||||
|
2. **Import verification**: Check that modules aren't imported unnecessarily
|
||||||
|
3. **Output formats**: Test both table and JSON output
|
||||||
|
4. **Error cases**: Test with missing auth, invalid IDs, network errors
|
||||||
|
5. **Empty results**: Ensure graceful handling of no data
|
||||||
|
|
||||||
|
## Module Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
cli/
|
||||||
|
├── __init__.py # Main entry with LazyGroup configuration
|
||||||
|
├── __main__.py # CLI entry point (`tinker` script and `python -m tinker.cli`): LazyGroup configuration and top-level error handling
|
||||||
|
├── lazy_group.py # LazyGroup implementation for lazy loading
|
||||||
|
├── output.py # OutputBase class and formatting utilities
|
||||||
|
├── client.py # SDK client creation and error handling
|
||||||
|
├── commands/
|
||||||
|
│ ├── __init__.py # Command module marker
|
||||||
|
│ ├── version.py # Version command
|
||||||
|
│ ├── run.py # Run commands and output classes
|
||||||
|
│ └── checkpoint.py # Checkpoint commands and output classes
|
||||||
|
└── CLAUDE.md # This documentation
|
||||||
|
```
|
||||||
|
|
||||||
|
## Command Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Show version
|
||||||
|
tinker version
|
||||||
|
|
||||||
|
# List all training runs
|
||||||
|
tinker run list
|
||||||
|
|
||||||
|
# Show run details
|
||||||
|
tinker run info run-abc123
|
||||||
|
|
||||||
|
# List all checkpoints
|
||||||
|
tinker checkpoint list
|
||||||
|
|
||||||
|
# List checkpoints for specific run
|
||||||
|
tinker checkpoint list --run-id run-abc123
|
||||||
|
|
||||||
|
# Show checkpoint details
|
||||||
|
tinker checkpoint info ckpt-xyz789
|
||||||
|
|
||||||
|
# JSON output
|
||||||
|
tinker --format json run list
|
||||||
|
tinker --format json checkpoint list
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
### Required
|
||||||
|
- Python 3.11+
|
||||||
|
- tinker SDK (main package)
|
||||||
|
- click>=8.0.0 (CLI framework)
|
||||||
|
|
||||||
|
### Optional
|
||||||
|
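- `rich` - For table formatting; only imported when table output is rendered (note: `pyproject.toml` lists `rich>=13.0.0` as a regular dependency, so it is installed with the base package)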
|
||||||
|
|
||||||
|
## Migration from Argparse to Click
|
||||||
|
|
||||||
|
### Key Changes:
|
||||||
|
1. **Command Definition**: Decorators instead of `parser.add_argument()`
|
||||||
|
2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch
|
||||||
|
3. **Context Passing**: Click's context system for sharing format option
|
||||||
|
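4. **Error Handling**: Click handles usage errors and their exit codes; application errors are raised as `TinkerCliError` and formatted by the handler in `__main__.py`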
|
||||||
|
5. **Help Generation**: Automatic from docstrings and decorators
|
||||||
|
|
||||||
|
### Benefits:
|
||||||
|
- Cleaner, more Pythonic code
|
||||||
|
- Better command organization
|
||||||
|
- Built-in testing utilities
|
||||||
|
- Easier to extend with plugins
|
||||||
|
- More consistent behavior across commands
|
||||||
|
|
||||||
|
## Maintenance Notes
|
||||||
|
|
||||||
|
1. **Keep imports lazy** - Use LazyGroup for all commands
|
||||||
|
2. **Test startup time** - Regularly verify fast startup is maintained
|
||||||
|
3. **Follow Click patterns** - Use decorators and context properly
|
||||||
|
4. **Document changes** - Update this file when making architectural changes
|
||||||
|
5. **Maintain consistency** - All commands should follow the same structure
|
||||||
60
src/tinker/cli/__main__.py
Normal file
60
src/tinker/cli/__main__.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""Tinker CLI - Command-line interface for the Tinker SDK.
|
||||||
|
|
||||||
|
This module implements lazy loading to ensure fast startup times.
|
||||||
|
Only Click is imported at the module level. All other dependencies
|
||||||
|
including command modules are imported on-demand.
|
||||||
|
|
||||||
|
Enable execution of the CLI via python -m tinker.cli
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import click
|
||||||
|
from .lazy_group import LazyGroup
|
||||||
|
from .context import CLIContext
|
||||||
|
from .exceptions import TinkerCliError
|
||||||
|
|
||||||
|
|
||||||
|
# Subcommand name -> "module.path:attribute". The modules are imported
# on demand (see the module docstring) rather than at CLI startup.
_LAZY_SUBCOMMANDS = {
    "checkpoint": "tinker.cli.commands.checkpoint:cli",
    "run": "tinker.cli.commands.run:cli",
    "version": "tinker.cli.commands.version:cli",
}


@click.group(
    cls=LazyGroup,
    lazy_subcommands=_LAZY_SUBCOMMANDS,
    context_settings={"help_option_names": ["-h", "--help"]},
)
@click.option(
    "--format",
    "-f",
    type=click.Choice(["table", "json"]),
    default="table",
    help="Output format (default: table)",
)
@click.pass_context
def main_cli(ctx: click.Context, format: str) -> None:
    """Tinker management CLI."""
    # NOTE: the docstring above is the help text Click displays for the root
    # group, so it is user-facing output rather than internal documentation.
    # The selected output format is handed to subcommands through ctx.obj.
    ctx.obj = CLIContext(format=format)  # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the CLI and turn known failures into process exit codes.

    A ``TinkerCliError`` is reported on stderr (message first, then any
    details) and the process exits with the error's ``exit_code``. A
    ``KeyboardInterrupt`` that reaches this function exits with status 130.
    """
    try:
        main_cli()
    except TinkerCliError as err:
        headline = f"Error: {err.message}" if err.message else None
        # Print whichever parts are present, headline before details.
        for text in (headline, err.details):
            if text:
                print(text, file=sys.stderr)
        sys.exit(err.exit_code)
    except KeyboardInterrupt:
        # Standard Unix exit code for Ctrl+C (128 + SIGINT).
        # NOTE(review): Click's default standalone mode appears to handle
        # Ctrl+C itself (printing "Aborted!" and exiting 1) before it can
        # reach this handler - confirm whether 130 is ever actually returned.
        sys.exit(130)
|
||||||
|
|
||||||
|
|
||||||
|
# Make main available for entry point
# (the `tinker` console script in pyproject.toml targets `tinker.cli.__main__:cli`).
cli = main


# Support running the CLI as `python -m tinker.cli`.
if __name__ == "__main__":
    main()
|
||||||
157
src/tinker/cli/client.py
Normal file
157
src/tinker/cli/client.py
Normal file
|
|
@ -0,0 +1,157 @@
|
||||||
|
"""Client utilities for tinker CLI - handles SDK client creation and configuration.
|
||||||
|
|
||||||
|
This module provides functions for creating and configuring the Tinker SDK
|
||||||
|
client, with proper error handling for common issues like authentication
|
||||||
|
and network errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from functools import wraps
|
||||||
|
from typing import TypeVar, Callable, Any, cast, TYPE_CHECKING
|
||||||
|
|
||||||
|
from .exceptions import TinkerCliError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tinker.lib.public_interfaces.rest_client import RestClient
|
||||||
|
|
||||||
|
|
||||||
|
def create_rest_client() -> "RestClient":
    """Create and configure a RestClient instance with proper error handling.

    Imports the SDK lazily, builds a ``ServiceClient`` and asks it for a
    ``RestClient``. Every failure along the way is converted into a
    ``TinkerCliError`` carrying a user-facing hint, with the original
    exception attached as its cause.

    Returns:
        A configured RestClient instance

    Raises:
        TinkerCliError: If the SDK cannot be imported or client creation fails
    """
    try:
        # Lazy import to avoid slow startup. It must live inside the ``try``:
        # the import itself is the statement most likely to raise ImportError,
        # and outside the block that failure would bypass the handler below.
        from tinker import ServiceClient

        service_client = ServiceClient()
        return service_client.create_rest_client()
    except ImportError as e:
        raise TinkerCliError(
            f"Failed to import Tinker SDK: {e}",
            "Please ensure the tinker package is properly installed.",
        ) from e
    except ValueError as e:
        # Often indicates missing or invalid API key
        raise TinkerCliError(
            f"Configuration error: {e}", "Please check your Tinker API credentials."
        ) from e
    except Exception as e:
        # Catch-all for other errors
        raise TinkerCliError(
            f"Failed to connect to Tinker API: {e}",
            "Please check your network connection and API configuration.",
        ) from e
|
||||||
|
|
||||||
|
|
||||||
|
# Type variable for decorator
|
||||||
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
def _api_error_details(exc: Any, hint: str | None) -> str | None:
    """Merge a fixed hint with the exception's ``message`` attribute, if it has one."""
    if not hasattr(exc, "message"):
        return hint
    reported = f"Details: {exc.message}"
    return reported if hint is None else f"{hint}\n{reported}"


def handle_api_errors(func: F) -> F:
    """Decorator that converts Tinker API exceptions into ``TinkerCliError``.

    The wrapped function's exceptions are matched against the SDK exception
    types in a fixed order and re-raised as ``TinkerCliError`` with a short
    headline and optional details. The original exception is chained as the
    cause. ``TinkerCliError`` passes through untouched, and non-``Exception``
    interrupts such as ``KeyboardInterrupt`` are not intercepted.

    Args:
        func: Function to wrap with error handling

    Returns:
        Wrapped function with error handling

    Raises:
        TinkerCliError: From the wrapper, for any ``Exception`` raised by ``func``
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Lazy import to avoid slow startup
        from tinker._exceptions import (
            APIError,
            BadRequestError,
            AuthenticationError,
            PermissionDeniedError,
            NotFoundError,
            UnprocessableEntityError,
            RateLimitError,
            InternalServerError,
            APIStatusError,
            APIConnectionError,
            APITimeoutError,
        )

        # (exception type, headline, fixed hint). Order matters: the first
        # matching entry wins, so the generic types come last. The headline
        # for APIStatusError is built at match time from the status code.
        handlers = (
            (NotFoundError, "Resource not found", None),
            (
                AuthenticationError,
                "Authentication failed",
                "Please check your API key or authentication credentials.",
            ),
            (
                PermissionDeniedError,
                "Permission denied",
                "You don't have permission to access this resource.",
            ),
            (BadRequestError, "Invalid request", None),
            (UnprocessableEntityError, "Invalid data provided", None),
            (
                RateLimitError,
                "Rate limit exceeded",
                "Please wait a moment before trying again.",
            ),
            (
                InternalServerError,
                "Internal server error",
                "The Tinker API encountered an internal error. Please try again later.",
            ),
            (
                APITimeoutError,
                "Request timeout",
                "The request to Tinker API timed out. Please try again.",
            ),
            (
                APIConnectionError,
                "Connection failed",
                "Could not connect to the Tinker API. Please check your network connection.",
            ),
            (APIStatusError, None, None),
            (APIError, "API error occurred", None),
        )

        try:
            return func(*args, **kwargs)
        except Exception as e:
            for exc_type, headline, hint in handlers:
                if not isinstance(e, exc_type):
                    continue
                if exc_type is APIStatusError:
                    status = getattr(e, "status_code", "unknown")
                    headline = f"API error (status {status})"
                raise TinkerCliError(headline, _api_error_details(e, hint)) from e

            if isinstance(e, TinkerCliError):
                # Re-raise our own errors without modification
                raise

            # Catch-all for unexpected errors
            details = None
            if sys.stderr.isatty():
                # Only include traceback if stderr is a terminal (for debugging)
                import traceback

                details = traceback.format_exc()
            raise TinkerCliError(f"Unexpected error occurred: {e}", details) from e

    return cast(F, wrapper)
|
||||||
1
src/tinker/cli/commands/__init__.py
Normal file
1
src/tinker/cli/commands/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
"""Command modules for the Tinker CLI."""
|
||||||
414
src/tinker/cli/commands/checkpoint.py
Normal file
414
src/tinker/cli/commands/checkpoint.py
Normal file
|
|
@ -0,0 +1,414 @@
|
||||||
|
"""Commands for managing checkpoints.
|
||||||
|
|
||||||
|
This module implements the 'tinker checkpoint' commands, including:
|
||||||
|
- list: List all checkpoints or checkpoints for a specific run
|
||||||
|
- info: Show details of a specific checkpoint
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tinker.types import Checkpoint
|
||||||
|
|
||||||
|
from ..client import create_rest_client, handle_api_errors
|
||||||
|
from ..context import CLIContext
|
||||||
|
from ..exceptions import TinkerCliError
|
||||||
|
from ..output import OutputBase, format_bool, format_size, format_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointListOutput(OutputBase):
    """Renderable result of the 'tinker checkpoint list' command."""

    def __init__(
        self,
        checkpoints: List["Checkpoint"],
        run_id: str | None = None,
        total_count: int | None = None,
        shown_count: int | None = None,
    ):
        """Store the checkpoints and the counts used to build the title.

        Args:
            checkpoints: List of Checkpoint objects
            run_id: Optional training run ID if filtering by run
            total_count: Total number of checkpoints available
            shown_count: Number of checkpoints shown in this response;
                defaults to ``len(checkpoints)`` when omitted
        """
        self.checkpoints = checkpoints
        self.run_id = run_id
        self.total_count = total_count
        if shown_count is None:
            shown_count = len(checkpoints)
        self.shown_count = shown_count

    def to_dict(self) -> Dict[str, Any]:
        """Build the JSON payload: optional ``run_id`` plus serialized checkpoints."""
        payload: Dict[str, Any] = {"run_id": self.run_id} if self.run_id else {}

        # Pydantic models expose model_dump(); anything else goes through dict().
        # Only the first element is inspected to decide which path to use.
        use_model_dump = bool(self.checkpoints) and hasattr(
            self.checkpoints[0], "model_dump"
        )
        payload["checkpoints"] = [
            item.model_dump() if use_model_dump else dict(item)
            for item in self.checkpoints
        ]
        return payload

    def get_title(self) -> str | None:
        """Return the table title, including a note about checkpoints not shown."""
        count = len(self.checkpoints)
        scope = f" for run {self.run_id}" if self.run_id else ""

        # The empty case returns early and never carries the "not shown" note.
        if count == 0:
            return f"No checkpoints found{scope}"

        noun = "checkpoint" if count == 1 else "checkpoints"
        title = f"{count} {noun}{scope}"

        # Add information about remaining checkpoints if available
        if self.total_count is not None:
            hidden = self.total_count - self.shown_count
            if hidden > 0:
                title += f" ({hidden} more not shown, use --limit to see more)"

        return title

    def get_table_columns(self) -> List[str]:
        """Return the column headers, in the order used by ``get_table_rows``."""
        return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Path"]

    def get_table_rows(self) -> List[List[str]]:
        """Return one formatted row per checkpoint."""
        return [
            [
                item.checkpoint_id,
                item.checkpoint_type,
                format_size(item.size_bytes) if hasattr(item, "size_bytes") else "N/A",
                format_bool(item.public),
                format_timestamp(item.time),
                item.tinker_path,
            ]
            for item in self.checkpoints
        ]
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointInfoOutput(OutputBase):
    """Output for 'tinker checkpoint info' command."""

    # URI scheme that prefixes checkpoint paths such as
    # "tinker://run-id/weights/0001".
    _TINKER_SCHEME = "tinker://"

    def __init__(self, checkpoint: "Checkpoint"):
        """Initialize with a single checkpoint.

        Args:
            checkpoint: Checkpoint object
        """
        self.checkpoint = checkpoint

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON output."""
        # Pydantic models expose model_dump(); fall back to dict() otherwise.
        if hasattr(self.checkpoint, "model_dump"):
            return self.checkpoint.model_dump()
        return dict(self.checkpoint)

    def get_title(self) -> str | None:
        """Return title for table output."""
        return f"Checkpoint: {self.checkpoint.checkpoint_id}"

    def get_table_columns(self) -> List[str]:
        """Return column headers for table output."""
        return ["Property", "Value"]

    def get_table_rows(self) -> List[List[str]]:
        """Return property/value rows for table output.

        The "Size" row is included only when the checkpoint has a
        ``size_bytes`` attribute, and the "Training Run ID" row only when a
        non-empty run ID can be read from the ``tinker://`` path.
        """
        ckpt = self.checkpoint
        rows = [
            ["Checkpoint ID", ckpt.checkpoint_id],
            ["Type", ckpt.checkpoint_type],
            ["Tinker Path", ckpt.tinker_path],
        ]

        # Size if available
        if hasattr(ckpt, "size_bytes"):
            rows.append(["Size", format_size(ckpt.size_bytes)])

        rows.append(["Public", format_bool(ckpt.public)])
        rows.append(["Created", format_timestamp(ckpt.time)])

        # The training run ID is the first path segment after the scheme.
        # Strip the scheme as a prefix only (not anywhere in the string), and
        # skip the row when the segment is empty: str.split() never returns an
        # empty list, so a bare truthiness check on its result cannot do that.
        if ckpt.tinker_path.startswith(self._TINKER_SCHEME):
            run_id = ckpt.tinker_path.removeprefix(self._TINKER_SCHEME).split("/", 1)[0]
            if run_id:
                rows.append(["Training Run ID", run_id])

        return rows
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_from_path(checkpoint_path: str) -> "Checkpoint":
    """Get checkpoint details from a tinker path.

    Parses the path to find its training run, lists that run's checkpoints
    through the REST client, and returns the one whose ``tinker_path`` equals
    ``checkpoint_path`` exactly.

    Args:
        checkpoint_path: A tinker path like "tinker://run-id/weights/0001"

    Returns:
        Checkpoint object

    Raises:
        TinkerCliError: If the path is invalid, the checkpoint does not exist,
            or the lookup fails for any other reason
    """
    # Lazy import
    from tinker import ParsedCheckpointTinkerPath

    try:
        # Only a ValueError from parsing means "invalid path". Keeping this
        # handler narrow stops a ValueError raised later by the API call from
        # being misreported as a malformed path.
        try:
            parsed = ParsedCheckpointTinkerPath.from_tinker_path(checkpoint_path)
        except ValueError as e:
            raise TinkerCliError(
                f"Invalid checkpoint path: {e}",
                "Checkpoint paths should be in the format: tinker://run-id/weights/0001",
            ) from e

        client = create_rest_client()

        # Get the checkpoint info
        checkpoints_response = client.list_checkpoints(parsed.training_run_id).result()

        # Find the matching checkpoint
        match = next(
            (
                ckpt
                for ckpt in checkpoints_response.checkpoints
                if ckpt.tinker_path == checkpoint_path
            ),
            None,
        )
        if match is None:
            raise TinkerCliError(f"Checkpoint not found: {checkpoint_path}")
        return match

    except TinkerCliError:
        # Re-raise our own errors
        raise
    except Exception as e:
        raise TinkerCliError(f"Failed to retrieve checkpoint: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# Click command group for checkpoint commands
|
||||||
|
@click.group()
def cli():
    """Manage checkpoints."""
    # Intentionally empty: the group only hosts the subcommands registered
    # on it below. The docstring above is the help text Click displays.
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.option("--run-id", help="Training run ID")
@click.option(
    "--limit",
    type=int,
    default=20,
    help="Maximum number of checkpoints to display when listing from all runs (default: 20, use --limit=0 to show all)",
)
@click.pass_obj
@handle_api_errors
def list(cli_context: CLIContext, run_id: str | None, limit: int) -> None:
    """List checkpoints.

    If --run-id is provided, list checkpoints for that specific training run.
    Otherwise, list checkpoints from all recent runs.
    """
    # NOTE: this function's name is also the Click command name, which is why
    # it shadows the ``list`` builtin. Do not call ``list(...)`` in this module
    # expecting the builtin.

    # Output format selected on the shared CLI context
    output_format = cli_context.format

    client = create_rest_client()

    if run_id:
        # List checkpoints for specific run.
        # Note that there's no pagination for listing checkpoints on a single training run.
        response = client.list_checkpoints(run_id).result()
        output = CheckpointListOutput(checkpoints=response.checkpoints, run_id=run_id)
    else:
        # List checkpoints from all user's training runs using list_user_checkpoints()
        # Fetch in batches of 1000 since the queries are so slow
        batch_size = 1000

        # First fetch to get initial data and total count. A non-positive
        # limit means "no limit", so a full batch is requested.
        first_response = client.list_user_checkpoints(
            limit=min(batch_size, limit) if limit > 0 else batch_size, offset=0
        ).result()
        all_checkpoints = []
        all_checkpoints.extend(first_response.checkpoints)
        total_count = (
            first_response.cursor.total_count
            if first_response.cursor
            else len(first_response.checkpoints)
        )

        # Determine target count: either user-specified limit or total available,
        # never more than what exists.
        target_count = min(limit if limit > 0 else total_count, total_count)

        # If we need to fetch more checkpoints, paginate with a progress bar
        if len(all_checkpoints) < target_count:
            with click.progressbar(
                length=target_count,
                label=f"Fetching {'all' if limit == 0 else str(target_count)} checkpoints",
                show_percent=True,
                show_pos=True,
                show_eta=True,
                # Progress goes to stderr so that stdout carries only the
                # table/JSON result (the default stream is stdout, which would
                # corrupt piped ``--format json`` output).
                file=click.get_text_stream("stderr"),
            ) as bar:
                bar.update(len(all_checkpoints))

                # Fetch remaining checkpoints in batches
                while len(all_checkpoints) < target_count:
                    remaining = target_count - len(all_checkpoints)
                    next_batch_size = min(batch_size, remaining)

                    response = client.list_user_checkpoints(
                        limit=next_batch_size, offset=len(all_checkpoints)
                    ).result()
                    all_checkpoints.extend(response.checkpoints)
                    bar.update(len(response.checkpoints))

                    # Break if we got fewer than requested (reached the end)
                    if len(response.checkpoints) < next_batch_size:
                        break

        # Create output object with pagination information
        output = CheckpointListOutput(
            checkpoints=all_checkpoints, total_count=total_count, shown_count=len(all_checkpoints)
        )

    # Print in requested format
    output.print(format=output_format)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.argument("checkpoint_path")
@click.pass_obj
@handle_api_errors
def info(cli_context: CLIContext, checkpoint_path: str) -> None:
    """Show details of a specific checkpoint.

    CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
    """
    # Output format selected on the shared CLI context
    output_format = cli_context.format

    # Only tinker:// paths are accepted; anything else is rejected up front.
    looks_like_tinker_path = checkpoint_path.startswith("tinker://")
    if not looks_like_tinker_path:
        raise TinkerCliError(
            f"Invalid checkpoint path: {checkpoint_path}",
            "Checkpoint path must be in the format: tinker://run-id/weights/0001",
        )

    # Resolve the checkpoint and render it in the requested format.
    resolved = get_checkpoint_from_path(checkpoint_path)
    CheckpointInfoOutput(checkpoint=resolved).print(format=output_format)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.argument("checkpoint_path")
@click.pass_obj
@handle_api_errors
def publish(cli_context: CLIContext, checkpoint_path: str) -> None:
    """Publish a checkpoint to make it publicly accessible.

    CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
    Only the owner of the training run can publish checkpoints.
    """
    if checkpoint_path.startswith("tinker://"):
        # Submit the publish request and call ``.result()`` on the returned
        # future; nothing is printed on success.
        rest_client = create_rest_client()
        rest_client.publish_checkpoint_from_tinker_path(checkpoint_path).result()
        return

    # Anything that is not a tinker:// path is rejected before a client is created.
    raise TinkerCliError(
        f"Invalid checkpoint path: {checkpoint_path}",
        "Checkpoint path must be in the format: tinker://run-id/weights/0001",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.argument("checkpoint_path")
@click.pass_obj
@handle_api_errors
def unpublish(cli_context: CLIContext, checkpoint_path: str) -> None:
    """Unpublish a checkpoint to make it private again.

    CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
    Only the owner of the training run can unpublish checkpoints.
    """
    if checkpoint_path.startswith("tinker://"):
        # Submit the unpublish request and call ``.result()`` on the returned
        # future; nothing is printed on success.
        rest_client = create_rest_client()
        rest_client.unpublish_checkpoint_from_tinker_path(checkpoint_path).result()
        return

    # Anything that is not a tinker:// path is rejected before a client is created.
    raise TinkerCliError(
        f"Invalid checkpoint path: {checkpoint_path}",
        "Checkpoint path must be in the format: tinker://run-id/weights/0001",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.argument("checkpoint_path")
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
@click.pass_obj
@handle_api_errors
def delete(cli_context: CLIContext, checkpoint_path: str, yes: bool) -> None:
    """Delete a checkpoint permanently.

    CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001).
    Only the owner of the training run can delete checkpoints.

    WARNING: This action is permanent and cannot be undone.
    """
    # Only tinker:// paths are accepted; anything else is rejected up front.
    if not checkpoint_path.startswith("tinker://"):
        raise TinkerCliError(
            f"Invalid checkpoint path: {checkpoint_path}",
            "Checkpoint path must be in the format: tinker://run-id/weights/0001",
        )

    # Output format selected on the shared CLI context
    output_format = cli_context.format

    needs_confirmation = not yes
    if needs_confirmation:
        try:
            # Show the same details as the 'info' command before asking.
            summary = CheckpointInfoOutput(get_checkpoint_from_path(checkpoint_path))
            summary.print(format=output_format)
            click.echo()
        except TinkerCliError:
            # Best effort only: if the details cannot be fetched, fall back to
            # echoing the path and still let the user attempt the deletion.
            click.echo(f"Checkpoint path: {checkpoint_path}")
            click.echo()

        click.echo("WARNING: This action is permanent and cannot be undone.")
        confirmed = click.confirm("Are you sure you want to delete this checkpoint?")
        if not confirmed:
            click.echo("Deletion cancelled.")
            return

    # Create client and delete
    rest_client = create_rest_client()
    rest_client.delete_checkpoint_from_tinker_path(checkpoint_path).result()
|
||||||
257
src/tinker/cli/commands/run.py
Normal file
257
src/tinker/cli/commands/run.py
Normal file
|
|
@ -0,0 +1,257 @@
|
||||||
|
"""Commands for managing training runs.
|
||||||
|
|
||||||
|
This module implements the 'tinker run' commands, including:
|
||||||
|
- list: List all training runs
|
||||||
|
- info: Show details of a specific run
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, TYPE_CHECKING
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tinker.types import TrainingRun
|
||||||
|
|
||||||
|
from ..client import create_rest_client, handle_api_errors
|
||||||
|
from ..context import CLIContext
|
||||||
|
from ..output import OutputBase, format_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class RunListOutput(OutputBase):
    """Renderable result of the 'tinker run list' command."""

    def __init__(
        self,
        runs: List["TrainingRun"],
        total_count: int | None = None,
        shown_count: int | None = None,
    ):
        """Store the runs to render together with pagination counters.

        Args:
            runs: List of TrainingRun objects
            total_count: Total number of runs available (from cursor)
            shown_count: Number of runs shown in this response; defaults to
                ``len(runs)`` when omitted
        """
        self.runs = runs
        self.total_count = total_count
        self.shown_count = len(runs) if shown_count is None else shown_count

    def to_dict(self) -> Dict[str, Any]:
        """Build the JSON payload: ``{"runs": [...]}``."""
        # The first element decides the strategy for the whole list: objects
        # exposing ``model_dump`` are dumped through it, anything else goes
        # through ``dict()``.
        if self.runs and hasattr(self.runs[0], "model_dump"):
            serialized = [run.model_dump() for run in self.runs]
        else:
            serialized = [dict(run) for run in self.runs]
        return {"runs": serialized}

    def get_title(self) -> str | None:
        """Build the table title, including a hint when runs were left out."""
        shown = len(self.runs)
        if shown == 0:
            return "No training runs found"

        title = "1 training run" if shown == 1 else f"{shown} training runs"

        # Mention how many runs exist beyond those displayed, when known.
        if self.total_count is not None and self.total_count > self.shown_count:
            hidden = self.total_count - self.shown_count
            title += f" ({hidden} more not shown, use --limit to see more)"

        return title

    def get_table_columns(self) -> List[str]:
        """Column headers, in the same order as the cells of each row."""
        return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Corrupted"]

    @staticmethod
    def _lora_label(run: Any) -> str:
        """Describe a run's LoRA setting: rank if set, else "Yes"/"No"."""
        if not run.is_lora:
            return "No"
        if run.lora_rank:
            return f"Rank {run.lora_rank}"
        return "Yes"

    def get_table_rows(self) -> List[List[str]]:
        """One row per run, matching ``get_table_columns``."""
        return [
            [
                run.training_run_id,
                run.base_model,
                run.model_owner,
                self._lora_label(run),
                format_timestamp(run.last_request_time),
                str(run.corrupted),
            ]
            for run in self.runs
        ]
|
||||||
|
|
||||||
|
|
||||||
|
class RunInfoOutput(OutputBase):
    """Renderable result of the 'tinker run info' command."""

    def __init__(self, run: "TrainingRun"):
        """Store the single training run to render.

        Args:
            run: TrainingRun object
        """
        self.run = run

    def to_dict(self) -> Dict[str, Any]:
        """Build the JSON payload from the run itself."""
        # Objects exposing ``model_dump`` are dumped through it; anything else
        # goes through ``dict()``.
        dump = getattr(self.run, "model_dump", None) if hasattr(self.run, "model_dump") else None
        if dump is not None:
            return dump()
        return dict(self.run)

    def get_title(self) -> str | None:
        """Table title naming the run."""
        return f"Training Run: {self.run.training_run_id}"

    def get_table_columns(self) -> List[str]:
        """Headers of the two-column property/value table."""
        return ["Property", "Value"]

    @staticmethod
    def _checkpoint_rows(label: str, checkpoint: Any) -> List[List[str]]:
        """Rows describing one checkpoint: its ID, time and tinker path."""
        return [
            [label, checkpoint.checkpoint_id],
            [" - Time", format_timestamp(checkpoint.time)],
            [" - Path", checkpoint.tinker_path],
        ]

    def get_table_rows(self) -> List[List[str]]:
        """Property/value rows in display order."""
        run = self.run

        # LoRA information: include the rank only when one is set.
        if not run.is_lora:
            lora_text = "No"
        elif run.lora_rank:
            lora_text = f"Yes (Rank {run.lora_rank})"
        else:
            lora_text = "Yes"

        rows = [
            ["Run ID", run.training_run_id],
            ["Base Model", run.base_model],
            ["Owner", run.model_owner],
            ["LoRA", lora_text],
            ["Last Update", format_timestamp(run.last_request_time)],
            ["Status", "Corrupted" if run.corrupted else "Active"],
        ]

        # Last checkpoints, each only when present
        if run.last_checkpoint:
            rows.extend(self._checkpoint_rows("Last Training Checkpoint", run.last_checkpoint))

        if run.last_sampler_checkpoint:
            rows.extend(
                self._checkpoint_rows("Last Sampler Checkpoint", run.last_sampler_checkpoint)
            )

        # User metadata if present: a header row followed by one row per entry
        if run.user_metadata:
            rows.append(["Metadata", ""])
            rows.extend([f" - {key}", value] for key, value in run.user_metadata.items())

        return rows
|
||||||
|
|
||||||
|
|
||||||
|
# Click command group for run commands
|
||||||
|
@click.group()
def cli():
    """Manage training runs."""
    # Intentionally empty: the group only hosts the run subcommands
    # registered on it via ``@cli.command()``.
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.option(
    "--limit",
    type=int,
    default=20,
    help="Maximum number of runs to fetch (default: 20, use --limit=0 to fetch all)",
)
@click.pass_obj
@handle_api_errors
def list(cli_context: CLIContext, limit: int) -> None:
    """List all training runs."""
    # NOTE: this function's name is also the Click command name, which is why
    # it shadows the ``list`` builtin. Do not call ``list(...)`` in this module
    # expecting the builtin.

    # Output format selected on the shared CLI context
    output_format = cli_context.format

    client = create_rest_client()

    batch_size = 100  # Fetch in batches of 100 for efficiency

    # First fetch to get initial data and total count. A non-positive limit
    # means "no limit", so a full batch is requested.
    first_response = client.list_training_runs(
        limit=min(batch_size, limit) if limit > 0 else batch_size, offset=0
    ).result()
    all_runs = []
    all_runs.extend(first_response.training_runs)

    # Fall back to the size of the first page when no cursor is returned,
    # mirroring the checkpoint listing command.
    # NOTE(review): whether the cursor can actually be missing for training
    # runs is not visible here - confirm against the response type.
    cursor = first_response.cursor
    total_count = cursor.total_count if cursor else len(first_response.training_runs)

    # Determine target count: either user-specified limit or total available,
    # never more than what exists.
    target_count = min(limit if limit > 0 else total_count, total_count)

    # If we need to fetch more runs, paginate with a progress bar
    if len(all_runs) < target_count:
        with click.progressbar(
            length=target_count,
            label=f"Fetching {'all' if limit == 0 else str(target_count)} training runs",
            show_percent=True,
            show_pos=True,
            show_eta=True,
            # Progress goes to stderr so that stdout carries only the
            # table/JSON result (the default stream is stdout, which would
            # corrupt piped ``--format json`` output).
            file=click.get_text_stream("stderr"),
        ) as bar:
            bar.update(len(all_runs))

            # Fetch remaining runs in batches
            while len(all_runs) < target_count:
                remaining = target_count - len(all_runs)
                next_batch_size = min(batch_size, remaining)

                response = client.list_training_runs(
                    limit=next_batch_size, offset=len(all_runs)
                ).result()
                all_runs.extend(response.training_runs)
                bar.update(len(response.training_runs))

                # Break if we got fewer than requested (reached the end)
                if len(response.training_runs) < next_batch_size:
                    break

    # Create output object with pagination information
    output = RunListOutput(runs=all_runs, total_count=total_count, shown_count=len(all_runs))

    # Print in requested format
    output.print(format=output_format)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
@click.argument("run_id")
@click.pass_obj
@handle_api_errors
def info(cli_context: CLIContext, run_id: str) -> None:
    """Show details of a specific run."""
    # Output format selected on the shared CLI context
    output_format = cli_context.format

    # Fetch the run through the REST client, then render it.
    training_run = create_rest_client().get_training_run(run_id).result()
    RunInfoOutput(run=training_run).print(format=output_format)
|
||||||
18
src/tinker/cli/commands/version.py
Normal file
18
src/tinker/cli/commands/version.py
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
"""Command for showing version information.
|
||||||
|
|
||||||
|
This module implements the 'tinker version' command.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
def cli():
    """Show version information."""
    try:
        # Lazy import version only when needed
        from tinker._version import __version__
    except ImportError:
        # The version module could not be imported; report that instead of failing.
        click.echo("tinker (version unavailable)")
        return

    click.echo(f"tinker {__version__}")
|
||||||
23
src/tinker/cli/context.py
Normal file
23
src/tinker/cli/context.py
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
"""Context object for the Tinker CLI.
|
||||||
|
|
||||||
|
This module provides a dataclass for sharing configuration and state
|
||||||
|
between CLI commands.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CLIContext:
    """Shared state handed to CLI commands.

    Commands in this package receive an instance through ``@click.pass_obj``
    and read their settings from it instead of walking the Click context tree.

    Attributes:
        format: Output format for command results ('table' or 'json')
    """

    # Rendering mode for command results; "table" unless overridden.
    format: Literal["table", "json"] = "table"
|
||||||
31
src/tinker/cli/exceptions.py
Normal file
31
src/tinker/cli/exceptions.py
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
"""Custom exceptions for the Tinker CLI.
|
||||||
|
|
||||||
|
This module defines exceptions used throughout the CLI for consistent
|
||||||
|
error handling and graceful exits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TinkerCliError(Exception):
|
||||||
|
"""Custom exception for CLI errors that should exit gracefully.
|
||||||
|
|
||||||
|
This exception is caught at the top level of the CLI and converted
|
||||||
|
to appropriate error messages and exit codes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
message: The main error message to display
|
||||||
|
details: Optional additional details or suggestions
|
||||||
|
exit_code: The exit code to use (default: 1)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, details: str | None = None, exit_code: int = 1):
|
||||||
|
"""Initialize a TinkerCliError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The main error message (will be prefixed with "Error: ")
|
||||||
|
details: Optional additional details or help text
|
||||||
|
exit_code: The exit code to use when exiting (default: 1)
|
||||||
|
"""
|
||||||
|
self.message = message
|
||||||
|
self.details = details
|
||||||
|
self.exit_code = exit_code
|
||||||
|
super().__init__(message)
|
||||||
89
src/tinker/cli/lazy_group.py
Normal file
89
src/tinker/cli/lazy_group.py
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
"""Lazy loading support for Click commands.
|
||||||
|
|
||||||
|
This module provides a LazyGroup class that extends Click's Group to support
|
||||||
|
lazy loading of subcommands, ensuring fast CLI startup times.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
import click
|
||||||
|
|
||||||
|
|
||||||
|
class LazyGroup(click.Group):
    """A Click Group whose subcommands can be imported on demand.

    Subcommands registered through ``lazy_subcommands`` are known by name only;
    the module behind a name is imported when ``get_command`` is asked for it.

    NOTE(review): Click resolves each listed command when it renders a group's
    help, so ``--help`` on the group presumably imports the lazy modules as
    well - confirm before relying on help being import-free.
    """

    def __init__(
        self, *args: Any, lazy_subcommands: Dict[str, str] | None = None, **kwargs: Any
    ) -> None:
        """Initialize the LazyGroup.

        Args:
            lazy_subcommands: A dictionary mapping command names to import paths.
                Format: {"command": "module.path:attribute_name"}
        """
        super().__init__(*args, **kwargs)
        # A missing (or empty) mapping means there are no lazy commands.
        self.lazy_subcommands = lazy_subcommands or {}

    def list_commands(self, ctx: click.Context) -> List[str]:
        """Return the names of all commands: eager ones first, then lazy ones sorted."""
        eager_names = super().list_commands(ctx)
        return eager_names + sorted(self.lazy_subcommands)

    def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
        """Get a command by name, loading it lazily if necessary.

        Args:
            ctx: The Click context
            cmd_name: The name of the command to retrieve

        Returns:
            The Click command object, or None if not found
        """
        # Names absent from the lazy mapping use Click's normal lookup.
        if cmd_name not in self.lazy_subcommands:
            return super().get_command(ctx, cmd_name)
        return self._lazy_load(cmd_name)

    def _lazy_load(self, cmd_name: str) -> click.Command:
        """Import the module behind ``cmd_name`` and return its command object.

        Args:
            cmd_name: The name of the command to load

        Returns:
            The loaded Click command object

        Raises:
            ValueError: If the imported object is not a Click Command
        """
        target = self.lazy_subcommands[cmd_name]

        # "module.path:attribute" -> module to import, attribute to fetch from it
        module_path, attribute = target.rsplit(":", 1)
        command = getattr(importlib.import_module(module_path), attribute)

        if isinstance(command, click.Command):
            return command

        raise ValueError(
            f"Lazy loading of {target} failed: '{attribute}' is not a Click Command"
        )
|
||||||
227
src/tinker/cli/output.py
Normal file
227
src/tinker/cli/output.py
Normal file
|
|
@ -0,0 +1,227 @@
|
||||||
|
"""Output formatting utilities for the Tinker CLI.
|
||||||
|
|
||||||
|
This module provides a base class for structured output that can be
|
||||||
|
rendered as either a table (using rich) or as JSON. Each command
|
||||||
|
defines its own output class that inherits from OutputBase.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Callable, Union
|
||||||
|
|
||||||
|
|
||||||
|
class OutputBase(ABC):
|
||||||
|
"""Virtual base class for all command outputs.
|
||||||
|
|
||||||
|
Subclasses must implement methods to convert data to various formats.
|
||||||
|
The base class provides the common print() method that handles
|
||||||
|
format selection and rendering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert output to dictionary for JSON serialization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation of the output data
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_table_columns(self) -> List[str]:
|
||||||
|
"""Return list of column names for table output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of column header strings
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_table_rows(self) -> List[List[str]]:
|
||||||
|
"""Return list of rows for table output.
|
||||||
|
|
||||||
|
Each row should be a list of string values corresponding
|
||||||
|
to the columns returned by get_table_columns().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of rows, where each row is a list of string values
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_title(self) -> str | None:
|
||||||
|
"""Optional title for the output display.
|
||||||
|
|
||||||
|
Override this method to provide a title for the table.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Title string or None
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def print(self, format: str = "table") -> None:
|
||||||
|
"""Print the output in the specified format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
format: Output format - either "table" or "json"
|
||||||
|
"""
|
||||||
|
if format == "json":
|
||||||
|
self._print_json()
|
||||||
|
else:
|
||||||
|
self._print_table()
|
||||||
|
|
||||||
|
def _print_json(self) -> None:
|
||||||
|
"""Print output as JSON."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = self.to_dict()
|
||||||
|
json.dump(data, sys.stdout, indent=2, default=str)
|
||||||
|
print() # Add newline after JSON output
|
||||||
|
|
||||||
|
def _print_table(self) -> None:
|
||||||
|
"""Print output as a rich table."""
|
||||||
|
# Lazy import rich to avoid slow startup
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
# Create table with optional title
|
||||||
|
title = self.get_title()
|
||||||
|
table = Table(title=title) if title else Table()
|
||||||
|
|
||||||
|
# Add columns
|
||||||
|
columns = self.get_table_columns()
|
||||||
|
for col in columns:
|
||||||
|
# First column (usually ID) gets special styling
|
||||||
|
if col == columns[0]:
|
||||||
|
table.add_column(col, style="bright_cyan", no_wrap=True)
|
||||||
|
else:
|
||||||
|
table.add_column(col)
|
||||||
|
|
||||||
|
# Add rows
|
||||||
|
rows = self.get_table_rows()
|
||||||
|
for row in rows:
|
||||||
|
table.add_row(*row)
|
||||||
|
|
||||||
|
# Print the table
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
# Utility formatting functions
|
||||||
|
|
||||||
|
|
||||||
|
def format_size(bytes: int) -> str:
    """Render a byte count as a human-readable size using 1024-based units.

    Args:
        bytes: Size in bytes

    Returns:
        "N/A" for negative input, a whole number of bytes below 1024
        (e.g., "512 B"), otherwise one decimal place with the largest
        fitting unit up to EB (e.g., "1.2 GB")
    """
    if bytes < 0:
        return "N/A"

    value = float(bytes)

    # Plain bytes are shown without a fractional part.
    if value < 1024.0:
        return f"{int(value)} B"

    # Step up one unit at a time until the value drops below 1024.
    for unit in ("KB", "MB", "GB", "TB", "PB"):
        value /= 1024.0
        if value < 1024.0:
            return f"{value:.1f} {unit}"

    # Anything beyond petabytes is reported in exabytes, however large.
    return f"{value / 1024.0:.1f} EB"
|
||||||
|
|
||||||
|
|
||||||
|
def format_timestamp(dt: Union[datetime, str, None]) -> str:
    """Format datetime as relative time or absolute date.

    Strings are parsed as ISO 8601 (a trailing "Z" is accepted); naive
    datetimes are treated as UTC. The age is measured against the current
    UTC time.

    Args:
        dt: datetime object, ISO string, or None

    Returns:
        "N/A" for a missing value; otherwise a relative time ("just now",
        "5 minutes ago", "2 hours ago", "3 days ago", "2 weeks ago") or, for
        anything older than 30 days, the date as "YYYY-MM-DD". Values that
        cannot be interpreted are returned as ``str(dt)``.
    """
    if not dt:
        return "N/A"

    # Function-scope import: ``timezone`` is not imported at module level.
    from datetime import datetime, timezone

    # Handle different input types
    if isinstance(dt, str):
        # Try to parse ISO format string
        try:
            dt = datetime.fromisoformat(dt.replace("Z", "+00:00"))
        except (ValueError, AttributeError):
            return str(dt)

    if not hasattr(dt, "replace"):
        # Not a datetime object
        return str(dt)

    try:
        now = datetime.now(timezone.utc)

        # Normalize to UTC; naive values are assumed to already be UTC.
        if dt.tzinfo is None:
            dt = dt.replace(tzinfo=timezone.utc)
        else:
            dt = dt.astimezone(timezone.utc)

        delta = now - dt

        # A timestamp in the future yields a negative timedelta, which Python
        # normalizes to days == -1 with a large positive ``seconds`` field.
        # Without this guard, slight clock skew would show as "23 hours ago".
        if delta.days < 0:
            # Less than a day ahead: treat as clock skew. Further ahead: show the date.
            return "just now" if delta.days == -1 else dt.strftime("%Y-%m-%d")

        # Format based on age
        if delta.days > 30:
            return dt.strftime("%Y-%m-%d")
        if delta.days > 7:
            weeks = delta.days // 7
            return f"{weeks} week{'s' if weeks > 1 else ''} ago"
        if delta.days > 0:
            return f"{delta.days} day{'s' if delta.days > 1 else ''} ago"
        # Inclusive bounds: exactly one hour is "1 hour ago" (not
        # "60 minutes ago") and exactly one minute is "1 minute ago".
        if delta.seconds >= 3600:
            hours = delta.seconds // 3600
            return f"{hours} hour{'s' if hours > 1 else ''} ago"
        if delta.seconds >= 60:
            minutes = delta.seconds // 60
            return f"{minutes} minute{'s' if minutes > 1 else ''} ago"
        return "just now"

    except Exception:
        # Display helper: on any error (e.g. a non-datetime object that has a
        # ``replace`` method), fall back to the plain string representation.
        return str(dt)
|
||||||
|
|
||||||
|
|
||||||
|
def format_bool(value: bool) -> str:
    """Turn a truth value into a display word.

    Args:
        value: Boolean value (any object is judged by its truthiness)

    Returns:
        "Yes" or "No"
    """
    if value:
        return "Yes"
    return "No"
|
||||||
|
|
||||||
|
|
||||||
|
def format_optional(value: Any, formatter: Callable[[Any], str] | None = None) -> str:
|
||||||
|
"""Format an optional value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to format (may be None)
|
||||||
|
formatter: Optional formatting function to apply if value is not None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string or "N/A" if value is None
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return "N/A"
|
||||||
|
if formatter:
|
||||||
|
return formatter(value)
|
||||||
|
return str(value)
|
||||||
|
|
@ -31,9 +31,12 @@ class RestClient(TelemetryProvider):
|
||||||
|
|
||||||
Key methods:
|
Key methods:
|
||||||
- list_checkpoints() - list available model checkpoints (both training and sampler)
|
- list_checkpoints() - list available model checkpoints (both training and sampler)
|
||||||
|
- list_user_checkpoints() - list all checkpoints across all user's training runs
|
||||||
- get_training_run() - get model information and metadata as ModelEntry
|
- get_training_run() - get model information and metadata as ModelEntry
|
||||||
- delete_checkpoint() - delete an existing checkpoint for a training run
|
- delete_checkpoint() - delete an existing checkpoint for a training run
|
||||||
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
|
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
|
||||||
|
- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
|
||||||
|
- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
holder: Internal client managing HTTP connections and async operations
|
holder: Internal client managing HTTP connections and async operations
|
||||||
|
|
@ -420,3 +423,227 @@ class RestClient(TelemetryProvider):
|
||||||
return await self._get_checkpoint_archive_url_submit(
|
return await self._get_checkpoint_archive_url_submit(
|
||||||
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _publish_checkpoint_submit(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]:
    """Schedule the "publish checkpoint" request and return its future.

    Sends a POST to the checkpoint's ``/publish`` endpoint using a client
    taken from the TRAIN connection pool. The request goes through
    ``self.holder.execute_with_retries`` and the enclosing coroutine is handed
    to ``self.holder.run_coroutine_threadsafe``.

    Args:
        training_run_id: Training run ID interpolated into the request path.
        checkpoint_id: Checkpoint ID interpolated into the request path.

    Returns:
        The object returned by ``self.holder.run_coroutine_threadsafe`` for the
        scheduled coroutine.
    """

    async def _publish_with_retries() -> None:
        async def _send_request() -> None:
            with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
                endpoint = (
                    f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish"
                )
                await client.post(endpoint, cast_to=NoneType)

        outcome = await self.holder.execute_with_retries(_send_request)
        return outcome

    submit = self.holder.run_coroutine_threadsafe
    return submit(_publish_with_retries())
|
||||||
|
|
||||||
|
@sync_only
@capture_exceptions(fatal=True)
def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
    """Make the checkpoint identified by a tinker path publicly accessible.

    Publishing is restricted to the exact owner of the training run. A
    published checkpoint can be made private again with
    unpublish_checkpoint_from_tinker_path.

    Args:
        tinker_path: Tinker path of the checkpoint
            (e.g., "tinker://run-id/weights/0001")

    Returns:
        A Future that completes once the checkpoint has been published

    Raises:
        HTTPException: 400 if checkpoint identifier is invalid
        HTTPException: 404 if checkpoint not found or user doesn't own the training run
        HTTPException: 409 if checkpoint is already public
        HTTPException: 500 if there's an error publishing the checkpoint

    Example:
        >>> future = rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001")
        >>> future.result()  # Wait for completion
        >>> print("Checkpoint published successfully")
    """
    # Split the tinker path into the run and checkpoint identifiers the
    # internal submit helper expects.
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    submitted = self._publish_checkpoint_submit(
        parsed.training_run_id,
        parsed.checkpoint_id,
    )
    return submitted.future()
|
||||||
|
|
||||||
|
@capture_exceptions(fatal=True)
async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
    """Coroutine counterpart of publish_checkpoint_from_tinker_path.

    Publishing is restricted to the exact owner of the training run. A
    published checkpoint can be made private again with
    unpublish_checkpoint_from_tinker_path_async.

    Args:
        tinker_path: Tinker path of the checkpoint
            (e.g., "tinker://run-id/weights/0001")

    Raises:
        HTTPException: 400 if checkpoint identifier is invalid
        HTTPException: 404 if checkpoint not found or user doesn't own the training run
        HTTPException: 409 if checkpoint is already public
        HTTPException: 500 if there's an error publishing the checkpoint

    Example:
        >>> await rest_client.publish_checkpoint_from_tinker_path_async("tinker://run-id/weights/0001")
        >>> print("Checkpoint published successfully")
    """
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    # Awaiting the submitted request blocks this coroutine until it finishes.
    pending = self._publish_checkpoint_submit(
        parsed.training_run_id,
        parsed.checkpoint_id,
    )
    await pending
|
||||||
|
|
||||||
|
def _unpublish_checkpoint_submit(
    self, training_run_id: types.ModelID, checkpoint_id: str
) -> AwaitableConcurrentFuture[None]:
    """Schedule the "unpublish checkpoint" request and return its future.

    Sends a DELETE to the checkpoint's ``/publish`` endpoint using a client
    taken from the TRAIN connection pool. The request goes through
    ``self.holder.execute_with_retries`` and the enclosing coroutine is handed
    to ``self.holder.run_coroutine_threadsafe``.

    Args:
        training_run_id: Training run ID interpolated into the request path.
        checkpoint_id: Checkpoint ID interpolated into the request path.

    Returns:
        The object returned by ``self.holder.run_coroutine_threadsafe`` for the
        scheduled coroutine.
    """

    async def _unpublish_with_retries() -> None:
        async def _send_request() -> None:
            with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
                endpoint = (
                    f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish"
                )
                await client.delete(endpoint, cast_to=NoneType)

        outcome = await self.holder.execute_with_retries(_send_request)
        return outcome

    submit = self.holder.run_coroutine_threadsafe
    return submit(_unpublish_with_retries())
|
||||||
|
|
||||||
|
@sync_only
@capture_exceptions(fatal=True)
def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
    """Make the checkpoint identified by a tinker path private again.

    Unpublishing is restricted to the exact owner of the training run. It
    undoes the effect of publishing the checkpoint.

    Args:
        tinker_path: Tinker path of the checkpoint
            (e.g., "tinker://run-id/weights/0001")

    Returns:
        A Future that completes once the checkpoint has been unpublished

    Raises:
        HTTPException: 400 if checkpoint identifier is invalid
        HTTPException: 404 if checkpoint not found or user doesn't own the training run
        HTTPException: 409 if checkpoint is already private
        HTTPException: 500 if there's an error unpublishing the checkpoint

    Example:
        >>> future = rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001")
        >>> future.result()  # Wait for completion
        >>> print("Checkpoint unpublished successfully")
    """
    # Split the tinker path into the run and checkpoint identifiers the
    # internal submit helper expects.
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    submitted = self._unpublish_checkpoint_submit(
        parsed.training_run_id,
        parsed.checkpoint_id,
    )
    return submitted.future()
|
||||||
|
|
||||||
|
@capture_exceptions(fatal=True)
async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
    """Coroutine counterpart of unpublish_checkpoint_from_tinker_path.

    Unpublishing is restricted to the exact owner of the training run. It
    undoes the effect of publishing the checkpoint.

    Args:
        tinker_path: Tinker path of the checkpoint
            (e.g., "tinker://run-id/weights/0001")

    Raises:
        HTTPException: 400 if checkpoint identifier is invalid
        HTTPException: 404 if checkpoint not found or user doesn't own the training run
        HTTPException: 409 if checkpoint is already private
        HTTPException: 500 if there's an error unpublishing the checkpoint

    Example:
        >>> await rest_client.unpublish_checkpoint_from_tinker_path_async("tinker://run-id/weights/0001")
        >>> print("Checkpoint unpublished successfully")
    """
    parsed = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path)
    # Awaiting the submitted request blocks this coroutine until it finishes.
    pending = self._unpublish_checkpoint_submit(
        parsed.training_run_id,
        parsed.checkpoint_id,
    )
    await pending
|
||||||
|
|
||||||
|
def _list_user_checkpoints_submit(
    self, limit: int = 100, offset: int = 0
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
    """Schedule the "list user checkpoints" request and return its future.

    Sends a GET to ``/api/v1/checkpoints`` with ``limit`` and ``offset`` as
    query parameters, using a client taken from the TRAIN connection pool.
    The request goes through ``self.holder.execute_with_retries`` and the
    enclosing coroutine is handed to ``self.holder.run_coroutine_threadsafe``.

    Args:
        limit: Value sent as the ``limit`` query parameter (default 100).
        offset: Value sent as the ``offset`` query parameter (default 0).

    Returns:
        The object returned by ``self.holder.run_coroutine_threadsafe`` for the
        scheduled coroutine; the response is cast to
        ``types.CheckpointsListResponse``.
    """

    async def _fetch_with_retries() -> types.CheckpointsListResponse:
        async def _send_request() -> types.CheckpointsListResponse:
            with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
                query: dict[str, object] = {"limit": limit, "offset": offset}
                response = await client.get(
                    "/api/v1/checkpoints",
                    options={"params": query},
                    cast_to=types.CheckpointsListResponse,
                )
                return response

        listing = await self.holder.execute_with_retries(_send_request)
        return listing

    submit = self.holder.run_coroutine_threadsafe
    return submit(_fetch_with_retries())
|
||||||
|
|
||||||
|
@sync_only
@capture_exceptions(fatal=True)
def list_user_checkpoints(
    self, limit: int = 100, offset: int = 0
) -> ConcurrentFuture[types.CheckpointsListResponse]:
    """List the current user's checkpoints across all of their training runs.

    Checkpoints from every training run owned by the authenticated user are
    returned, sorted by time (newest first). Results are paginated through
    ``limit`` and ``offset`` so large collections can be fetched page by page.

    Args:
        limit: Maximum number of checkpoints to return (default 100)
        offset: Offset for pagination (default 0)

    Returns:
        A Future containing the CheckpointsListResponse with checkpoints and cursor info

    Example:
        >>> future = rest_client.list_user_checkpoints(limit=50)
        >>> response = future.result()
        >>> print(f"Found {len(response.checkpoints)} checkpoints")
        >>> print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}")
        >>> for checkpoint in response.checkpoints:
        ...     print(f"  {checkpoint.training_run_id}/{checkpoint.checkpoint_id}")
        >>> # Get next page if there are more checkpoints
        >>> if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count:
        ...     next_page = rest_client.list_user_checkpoints(limit=50, offset=50)
    """
    submitted = self._list_user_checkpoints_submit(limit, offset)
    return submitted.future()
|
||||||
|
|
||||||
|
@capture_exceptions(fatal=True)
async def list_user_checkpoints_async(
    self, limit: int = 100, offset: int = 0
) -> types.CheckpointsListResponse:
    """Coroutine counterpart of list_user_checkpoints.

    Checkpoints from every training run owned by the authenticated user are
    returned, sorted by time (newest first). Results are paginated through
    ``limit`` and ``offset`` so large collections can be fetched page by page.

    Args:
        limit: Maximum number of checkpoints to return (default 100)
        offset: Offset for pagination (default 0)

    Returns:
        CheckpointsListResponse with checkpoints and cursor info

    Example:
        >>> response = await rest_client.list_user_checkpoints_async(limit=50)
        >>> print(f"Found {len(response.checkpoints)} checkpoints")
        >>> print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}")
        >>> for checkpoint in response.checkpoints:
        ...     print(f"  {checkpoint.training_run_id}/{checkpoint.checkpoint_id}")
        >>> # Get next page if there are more checkpoints
        >>> if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count:
        ...     next_page = await rest_client.list_user_checkpoints_async(limit=50, offset=50)
    """
    listing = await self._list_user_checkpoints_submit(limit, offset)
    return listing
|
||||||
|
|
|
||||||
|
|
@ -593,6 +593,7 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre
|
||||||
assert model_name is not None, "This shouldn't happen: model_name is None"
|
assert model_name is not None, "This shouldn't happen: model_name is None"
|
||||||
|
|
||||||
# Use tokenizer_id from get_info if available, otherwise fall back to heuristic logic
|
# Use tokenizer_id from get_info if available, otherwise fall back to heuristic logic
|
||||||
|
kwargs = {}
|
||||||
tokenizer_id = info.model_data.tokenizer_id
|
tokenizer_id = info.model_data.tokenizer_id
|
||||||
if tokenizer_id is None:
|
if tokenizer_id is None:
|
||||||
# We generally adhere to the huggingface convention of "<org>/<model>" but
|
# We generally adhere to the huggingface convention of "<org>/<model>" but
|
||||||
|
|
@ -608,4 +609,10 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre
|
||||||
else:
|
else:
|
||||||
tokenizer_id = model_name
|
tokenizer_id = model_name
|
||||||
|
|
||||||
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True)
|
if tokenizer_id == "moonshotai/Kimi-K2-Thinking":
|
||||||
|
kwargs = {
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"revision": "612681931a8c906ddb349f8ad0f582cb552189cd",
|
||||||
|
}
|
||||||
|
|
||||||
|
return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -185,7 +185,7 @@ class AsyncWeightsResource(AsyncAPIResource):
|
||||||
if not model_id:
|
if not model_id:
|
||||||
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
|
raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}")
|
||||||
return await self._get(
|
return await self._get(
|
||||||
f"/api/v1/models/{model_id}/checkpoints",
|
f"/api/v1/training_runs/{model_id}/checkpoints",
|
||||||
options=make_request_options(
|
options=make_request_options(
|
||||||
extra_headers=extra_headers,
|
extra_headers=extra_headers,
|
||||||
extra_query=extra_query,
|
extra_query=extra_query,
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,12 @@ class Checkpoint(BaseModel):
|
||||||
tinker_path: str
|
tinker_path: str
|
||||||
"""The tinker path to the checkpoint"""
|
"""The tinker path to the checkpoint"""
|
||||||
|
|
||||||
|
size_bytes: int | None = None
|
||||||
|
"""The size of the checkpoint in bytes"""
|
||||||
|
|
||||||
|
public: bool = False
|
||||||
|
"""Whether the checkpoint is publicly accessible"""
|
||||||
|
|
||||||
|
|
||||||
class ParsedCheckpointTinkerPath(BaseModel):
|
class ParsedCheckpointTinkerPath(BaseModel):
|
||||||
tinker_path: str
|
tinker_path: str
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from .._models import BaseModel
|
from .._models import BaseModel
|
||||||
from .checkpoint import Checkpoint
|
from .checkpoint import Checkpoint
|
||||||
|
from .cursor import Cursor
|
||||||
|
|
||||||
__all__ = ["CheckpointsListResponse"]
|
__all__ = ["CheckpointsListResponse"]
|
||||||
|
|
||||||
|
|
@ -7,3 +8,6 @@ __all__ = ["CheckpointsListResponse"]
|
||||||
class CheckpointsListResponse(BaseModel):
|
class CheckpointsListResponse(BaseModel):
|
||||||
checkpoints: list[Checkpoint]
|
checkpoints: list[Checkpoint]
|
||||||
"""List of available model checkpoints for the model"""
|
"""List of available model checkpoints for the model"""
|
||||||
|
|
||||||
|
cursor: Cursor | None = None
|
||||||
|
"""Pagination cursor information (None for unpaginated responses)"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue