diff --git a/.sync_state b/.sync_state index 4ec93c3..a47c376 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "d973e1d6eede81b853a167e8f999f43402e07c3a", - "last_sync_time": "2025-11-11T05:56:15.874542" + "last_synced_sha": "31ae0341ea2c6122c46c0597c5a862440a9cf31e", + "last_sync_time": "2025-11-18T22:06:38.289689" } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d0c9812..2acc3ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ dependencies = [ "sniffio", "numpy", "torch", + "rich>=13.0.0", + "click>=8.0.0", ] requires-python = ">= 3.11" classifiers = [ @@ -36,6 +38,9 @@ classifiers = [ "License :: OSI Approved :: Apache Software License" ] +[project.scripts] +tinker = "tinker.cli.__main__:cli" + [project.urls] Homepage = "https://thinkingmachines.ai/tinker" Repository = "https://github.com/thinking-machines-lab/tinker" diff --git a/src/tinker/cli/CLAUDE.md b/src/tinker/cli/CLAUDE.md new file mode 100644 index 0000000..f321106 --- /dev/null +++ b/src/tinker/cli/CLAUDE.md @@ -0,0 +1,281 @@ +# Tinker CLI Design Documentation + +## Overview + +The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance. + +## Key Design Decisions + +### 1. Lazy Import Strategy with Click + +**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level. + +**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands. + +**Implementation**: +- Main `__init__.py` only imports `click` and `lazy_group` +- Command modules are loaded only when invoked via `LazyGroup` +- Output formatting imports `rich` only when table output is needed +- JSON module imported only when JSON output is requested +- Version information loaded from `_version.py` only when `tinker version` is used + +### 2. Click Framework with LazyGroup + +**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading. + +**Rationale**: +- Click provides cleaner command structure with decorators +- Better subcommand isolation - each command file is self-contained +- Automatic help generation with better formatting +- Built-in type conversion and validation +- LazyGroup enables fast startup by deferring imports + +**LazyGroup Implementation**: +```python +class LazyGroup(click.Group): + def __init__(self, *args, lazy_subcommands=None, **kwargs): + # Map of command name to "module.path:command_name" + self.lazy_subcommands = lazy_subcommands or {} + + def get_command(self, ctx, cmd_name): + if cmd_name in self.lazy_subcommands: + # Import only when command is actually invoked + import_path = self.lazy_subcommands[cmd_name] + module_name, attr_name = import_path.rsplit(":", 1) + mod = importlib.import_module(module_name) + return getattr(mod, attr_name) +``` + +### 3. Hierarchical Command Structure + +**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`. + +**Rationale**: +- Provides a consistent, predictable interface +- Groups related functionality together +- Makes the CLI extensible for future commands +- Follows common CLI patterns (like `git`, `docker`, etc.) + +**Examples**: +- `tinker version` - Show CLI and SDK version +- `tinker run list` - List all training runs +- `tinker run info ` - Show details of a specific run +- `tinker checkpoint list` - List all checkpoints +- `tinker checkpoint info ` - Show checkpoint details + +### 4. Output System with Inheritance + +**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class. + +**Rationale**: +- Enforces consistent interface across all commands +- Encapsulates output logic with the command that generates it +- Makes it easy to support multiple output formats (table, JSON) +- Keeps related code together in the same module + +**Implementation**: +- `OutputBase` in `output.py` defines the contract +- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`) +- Base class handles format selection and rendering + +### 5. Self-Contained Command Modules + +**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point. + +**Rationale**: +- Modular architecture - commands can be developed independently +- Clear separation of concerns +- Easy to add new commands without modifying core files +- Consistent pattern across all commands + +**Command Structure**: +```python +# Each command file follows this pattern: +@click.group() # or @click.command() for simple commands +def cli(): + """Command description.""" + pass + +@cli.command() # For subcommands +def list(): + """Subcommand implementation.""" + pass +``` + +### 6. Centralized Client Management + +**Decision**: All SDK client creation and error handling is centralized in `client.py`. + +**Rationale**: +- Single place to handle authentication and connection errors +- Consistent error messages across all commands +- Reusable error handling decorator +- Clean separation of concerns + +### 7. Rich Tables for Human-Readable Output + +**Decision**: Use the `rich` library for table formatting, kept as an optional dependency. + +**Rationale**: +- Provides beautiful, formatted tables with colors and borders +- Handles column width adjustment automatically +- Supports both dark and light terminal themes +- Optional dependency keeps the core package lightweight + +### 8. Unix-Style Default Output + +**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output. + +**Rationale**: +- Follows Unix philosophy +- Tables are better for human consumption +- JSON is better for scripting and automation +- Single flag switches between formats consistently + +## Performance Optimizations + +1. **LazyGroup for deferred imports** - Commands only loaded when invoked +2. **No heavy imports at module level** - Only Click imported initially +3. **Lazy loading** of all SDK dependencies +4. **Progress indicators** that clear themselves +5. **Efficient data fetching** - fetch all data by default instead of pagination + +## Error Handling Strategy + +1. **User-friendly messages** - Technical errors are translated to helpful messages +2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes +3. **Graceful degradation** - Continue operation when possible +4. **Detailed error info** - Show details when available, traceback only in TTY + +### TinkerCliError Exception Pattern + +All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`: + +```python +from ..exceptions import TinkerCliError + +# Instead of: +print(f"Error: Something went wrong", file=sys.stderr) +sys.exit(1) + +# Use: +raise TinkerCliError( + "Something went wrong", + "Optional details or help text", + exit_code=1 # Optional, defaults to 1 +) +``` + +**Benefits:** +- Better testability (can catch exceptions in tests) +- Centralized error formatting in `__main__.py` +- Consistent exit codes across the CLI +- Stack traces preserved for debugging + +**Important Notes:** +- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification +- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages +- The main error handler in `__main__.py` handles printing to stderr and exiting + +## Future Extensibility + +The architecture supports easy addition of: + +### New Commands +- Create new module in `commands/` directory +- Define output classes in the same module if needed +- Add command to lazy_subcommands in `__init__.py` + +### New Subcommands +- Add new Click command decorator to existing command module +- Define corresponding output class if needed +- Subcommands automatically discovered by Click + +### New Output Formats +- Override `print()` method in `OutputBase` +- Or add new format handling to base class + +## Testing Guidelines + +1. **Startup time**: `time tinker --help` should be <50ms +2. **Import verification**: Check that modules aren't imported unnecessarily +3. **Output formats**: Test both table and JSON output +4. **Error cases**: Test with missing auth, invalid IDs, network errors +5. **Empty results**: Ensure graceful handling of no data + +## Module Structure + +``` +cli/ +├── __init__.py # Main entry with LazyGroup configuration +├── __main__.py # Module execution support +├── lazy_group.py # LazyGroup implementation for lazy loading +├── output.py # OutputBase class and formatting utilities +├── client.py # SDK client creation and error handling +├── commands/ +│ ├── __init__.py # Command module marker +│ ├── version.py # Version command +│ ├── run.py # Run commands and output classes +│ └── checkpoint.py # Checkpoint commands and output classes +└── CLAUDE.md # This documentation +``` + +## Command Examples + +```bash +# Show version +tinker version + +# List all training runs +tinker run list + +# Show run details +tinker run info run-abc123 + +# List all checkpoints +tinker checkpoint list + +# List checkpoints for specific run +tinker checkpoint list run-abc123 + +# Show checkpoint details +tinker checkpoint info ckpt-xyz789 + +# JSON output +tinker --format json run list +tinker --format json checkpoint list +``` + +## Dependencies + +### Required +- Python 3.11+ +- tinker SDK (main package) +- click>=8.0.0 (CLI framework) + +### Optional +- `rich` - For table formatting (installed with `pip install tinker[cli]`) + +## Migration from Argparse to Click + +### Key Changes: +1. **Command Definition**: Decorators instead of `parser.add_argument()` +2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch +3. **Context Passing**: Click's context system for sharing format option +4. **Error Handling**: Click handles exits and error formatting +5. **Help Generation**: Automatic from docstrings and decorators + +### Benefits: +- Cleaner, more Pythonic code +- Better command organization +- Built-in testing utilities +- Easier to extend with plugins +- More consistent behavior across commands + +## Maintenance Notes + +1. **Keep imports lazy** - Use LazyGroup for all commands +2. **Test startup time** - Regularly verify fast startup is maintained +3. **Follow Click patterns** - Use decorators and context properly +4. **Document changes** - Update this file when making architectural changes +5. **Maintain consistency** - All commands should follow the same structure diff --git a/src/tinker/cli/__main__.py b/src/tinker/cli/__main__.py new file mode 100644 index 0000000..41c54c6 --- /dev/null +++ b/src/tinker/cli/__main__.py @@ -0,0 +1,60 @@ +"""Tinker CLI - Command-line interface for the Tinker SDK. + +This module implements lazy loading to ensure fast startup times. +Only Click is imported at the module level. All other dependencies +including command modules are imported on-demand. + +Enable execution of the CLI via python -m tinker.cli +""" + +import sys +import click +from .lazy_group import LazyGroup +from .context import CLIContext +from .exceptions import TinkerCliError + + +@click.group( + cls=LazyGroup, + lazy_subcommands={ + "checkpoint": "tinker.cli.commands.checkpoint:cli", + "run": "tinker.cli.commands.run:cli", + "version": "tinker.cli.commands.version:cli", + }, + context_settings=dict(help_option_names=["-h", "--help"]), +) +@click.option( + "--format", + "-f", + type=click.Choice(["table", "json"]), + default="table", + help="Output format (default: table)", +) +@click.pass_context +def main_cli(ctx: click.Context, format: str) -> None: + """Tinker management CLI.""" + # Store format in context for subcommands to access + ctx.obj = CLIContext(format=format) # type: ignore[assignment] + + +def main(): + try: + main_cli() + except TinkerCliError as e: + # Print error message to stderr + if e.message: + print(f"Error: {e.message}", file=sys.stderr) + if e.details: + print(e.details, file=sys.stderr) + sys.exit(e.exit_code) + except KeyboardInterrupt: + # Standard Unix exit code for Ctrl+C + sys.exit(130) + + +# Make main available for entry point +cli = main + + +if __name__ == "__main__": + main() diff --git a/src/tinker/cli/client.py b/src/tinker/cli/client.py new file mode 100644 index 0000000..3e7a05a --- /dev/null +++ b/src/tinker/cli/client.py @@ -0,0 +1,157 @@ +"""Client utilities for tinker CLI - handles SDK client creation and configuration. + +This module provides functions for creating and configuring the Tinker SDK +client, with proper error handling for common issues like authentication +and network errors. +""" + +import sys +from functools import wraps +from typing import TypeVar, Callable, Any, cast, TYPE_CHECKING + +from .exceptions import TinkerCliError + +if TYPE_CHECKING: + from tinker.lib.public_interfaces.rest_client import RestClient + + +def create_rest_client() -> "RestClient": + """Create and configure a RestClient instance with proper error handling. + + This function handles the creation of the ServiceClient and RestClient, + with appropriate error messages for common failure cases. + + Returns: + A configured RestClient instance + + Raises: + TinkerCliError: If client creation fails + """ + # Lazy import to avoid slow startup + from tinker import ServiceClient + + try: + service_client = ServiceClient() + return service_client.create_rest_client() + except ImportError as e: + raise TinkerCliError( + f"Failed to import Tinker SDK: {e}", + "Please ensure the tinker package is properly installed.", + ) + except ValueError as e: + # Often indicates missing or invalid API key + raise TinkerCliError( + f"Configuration error: {e}", "Please check your Tinker API credentials." + ) + except Exception as e: + # Catch-all for other errors + raise TinkerCliError( + f"Failed to connect to Tinker API: {e}", + "Please check your network connection and API configuration.", + ) + + +# Type variable for decorator +F = TypeVar("F", bound=Callable[..., Any]) + + +def handle_api_errors(func: F) -> F: + """Decorator for handling common API errors. + + This decorator catches common exceptions from the Tinker API + and provides user-friendly error messages. + + Args: + func: Function to wrap with error handling + + Returns: + Wrapped function with error handling + """ + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + # Lazy import to avoid slow startup + from tinker._exceptions import ( + APIError, + BadRequestError, + AuthenticationError, + PermissionDeniedError, + NotFoundError, + UnprocessableEntityError, + RateLimitError, + InternalServerError, + APIStatusError, + APIConnectionError, + APITimeoutError, + ) + + try: + return func(*args, **kwargs) + except NotFoundError as e: + details = f"Details: {e.message}" if hasattr(e, "message") else None + raise TinkerCliError("Resource not found", details) + except AuthenticationError as e: + details = "Please check your API key or authentication credentials." + if hasattr(e, "message"): + details += f"\nDetails: {e.message}" + raise TinkerCliError("Authentication failed", details) + except PermissionDeniedError as e: + details = "You don't have permission to access this resource." + if hasattr(e, "message"): + details += f"\nDetails: {e.message}" + raise TinkerCliError("Permission denied", details) + except BadRequestError as e: + details = f"Details: {e.message}" if hasattr(e, "message") else None + raise TinkerCliError("Invalid request", details) + except UnprocessableEntityError as e: + details = f"Details: {e.message}" if hasattr(e, "message") else None + raise TinkerCliError("Invalid data provided", details) + except RateLimitError as e: + details = "Please wait a moment before trying again." + if hasattr(e, "message"): + details += f"\nDetails: {e.message}" + raise TinkerCliError("Rate limit exceeded", details) + except InternalServerError as e: + details = "The Tinker API encountered an internal error. Please try again later." + if hasattr(e, "message"): + details += f"\nDetails: {e.message}" + raise TinkerCliError("Internal server error", details) + except APITimeoutError as e: + details = "The request to Tinker API timed out. Please try again." + if hasattr(e, "message"): + details += f"\nDetails: {e.message}" + raise TinkerCliError("Request timeout", details) + except APIConnectionError as e: + details = "Could not connect to the Tinker API. Please check your network connection." + if hasattr(e, "message"): + details += f"\nDetails: {e.message}" + raise TinkerCliError("Connection failed", details) + except APIStatusError as e: + status = e.status_code if hasattr(e, "status_code") else "unknown" + details = f"Details: {e.message}" if hasattr(e, "message") else None + raise TinkerCliError(f"API error (status {status})", details) + except APIError as e: + # Generic API error + details = f"Details: {e.message}" if hasattr(e, "message") else None + raise TinkerCliError("API error occurred", details) + except TinkerCliError: + # Re-raise our own errors without modification + raise + except KeyboardInterrupt: + # Re-raise keyboard interrupt to be handled by main + raise + except Exception as e: + # Catch-all for unexpected errors + import traceback + + details = None + if sys.stderr.isatty(): + # Only include traceback if stderr is a terminal (for debugging) + import io + + tb_str = io.StringIO() + traceback.print_exc(file=tb_str) + details = tb_str.getvalue() + raise TinkerCliError(f"Unexpected error occurred: {e}", details) + + return cast(F, wrapper) diff --git a/src/tinker/cli/commands/__init__.py b/src/tinker/cli/commands/__init__.py new file mode 100644 index 0000000..ebe6ce0 --- /dev/null +++ b/src/tinker/cli/commands/__init__.py @@ -0,0 +1 @@ +"""Command modules for the Tinker CLI.""" diff --git a/src/tinker/cli/commands/checkpoint.py b/src/tinker/cli/commands/checkpoint.py new file mode 100644 index 0000000..546341c --- /dev/null +++ b/src/tinker/cli/commands/checkpoint.py @@ -0,0 +1,414 @@ +"""Commands for managing checkpoints. + +This module implements the 'tinker checkpoint' commands, including: +- list: List all checkpoints or checkpoints for a specific run +- info: Show details of a specific checkpoint +""" + +from typing import TYPE_CHECKING, Any, Dict, List + +import click + +if TYPE_CHECKING: + from tinker.types import Checkpoint + +from ..client import create_rest_client, handle_api_errors +from ..context import CLIContext +from ..exceptions import TinkerCliError +from ..output import OutputBase, format_bool, format_size, format_timestamp + + +class CheckpointListOutput(OutputBase): + """Output for 'tinker checkpoint list' command.""" + + def __init__( + self, + checkpoints: List["Checkpoint"], + run_id: str | None = None, + total_count: int | None = None, + shown_count: int | None = None, + ): + """Initialize with list of checkpoints. + + Args: + checkpoints: List of Checkpoint objects + run_id: Optional training run ID if filtering by run + total_count: Total number of checkpoints available + shown_count: Number of checkpoints shown in this response + """ + self.checkpoints = checkpoints + self.run_id = run_id + self.total_count = total_count + self.shown_count = shown_count if shown_count is not None else len(checkpoints) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON output.""" + result = {} + if self.run_id: + result["run_id"] = self.run_id + + # Check if these are Pydantic models + if self.checkpoints and hasattr(self.checkpoints[0], "model_dump"): + result["checkpoints"] = [c.model_dump() for c in self.checkpoints] + else: + result["checkpoints"] = [dict(c) for c in self.checkpoints] + + return result + + def get_title(self) -> str | None: + """Return title for table output.""" + count = len(self.checkpoints) + + if self.run_id: + if count == 0: + return f"No checkpoints found for run {self.run_id}" + elif count == 1: + title = f"1 checkpoint for run {self.run_id}" + else: + title = f"{count} checkpoints for run {self.run_id}" + else: + if count == 0: + return "No checkpoints found" + elif count == 1: + title = "1 checkpoint" + else: + title = f"{count} checkpoints" + + # Add information about remaining checkpoints if available + if self.total_count is not None and self.total_count > self.shown_count: + remaining = self.total_count - self.shown_count + if remaining == 1: + title += " (1 more not shown, use --limit to see more)" + else: + title += f" ({remaining} more not shown, use --limit to see more)" + + return title + + def get_table_columns(self) -> List[str]: + """Return column headers for table output.""" + return ["Checkpoint ID", "Type", "Size", "Public", "Created", "Path"] + + def get_table_rows(self) -> List[List[str]]: + """Return rows for table output.""" + rows = [] + for ckpt in self.checkpoints: + rows.append( + [ + ckpt.checkpoint_id, + ckpt.checkpoint_type, + format_size(ckpt.size_bytes) if hasattr(ckpt, "size_bytes") else "N/A", + format_bool(ckpt.public), + format_timestamp(ckpt.time), + ckpt.tinker_path, + ] + ) + return rows + + +class CheckpointInfoOutput(OutputBase): + """Output for 'tinker checkpoint info' command.""" + + def __init__(self, checkpoint: "Checkpoint"): + """Initialize with a single checkpoint. + + Args: + checkpoint: Checkpoint object + """ + self.checkpoint = checkpoint + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON output.""" + if hasattr(self.checkpoint, "model_dump"): + return self.checkpoint.model_dump() + return dict(self.checkpoint) + + def get_title(self) -> str | None: + """Return title for table output.""" + return f"Checkpoint: {self.checkpoint.checkpoint_id}" + + def get_table_columns(self) -> List[str]: + """Return column headers for table output.""" + return ["Property", "Value"] + + def get_table_rows(self) -> List[List[str]]: + """Return rows for table output.""" + rows = [ + ["Checkpoint ID", self.checkpoint.checkpoint_id], + ["Type", self.checkpoint.checkpoint_type], + ["Tinker Path", self.checkpoint.tinker_path], + ] + + # Size if available + if hasattr(self.checkpoint, "size_bytes"): + rows.append(["Size", format_size(self.checkpoint.size_bytes)]) + + # Public status + rows.append(["Public", format_bool(self.checkpoint.public)]) + + # Creation time + rows.append(["Created", format_timestamp(self.checkpoint.time)]) + + # Parse training run ID from path + if self.checkpoint.tinker_path.startswith("tinker://"): + parts = self.checkpoint.tinker_path.replace("tinker://", "").split("/") + if parts: + rows.append(["Training Run ID", parts[0]]) + + return rows + + +def get_checkpoint_from_path(checkpoint_path: str) -> "Checkpoint": + """Get checkpoint details from a tinker path. + + Args: + checkpoint_path: A tinker path like "tinker://run-id/weights/0001" + + Returns: + Checkpoint object + + Raises: + TinkerCliError: If the checkpoint cannot be retrieved + """ + # Lazy import + from tinker import ParsedCheckpointTinkerPath + + try: + parsed = ParsedCheckpointTinkerPath.from_tinker_path(checkpoint_path) + client = create_rest_client() + + # Get the checkpoint info + checkpoints_response = client.list_checkpoints(parsed.training_run_id).result() + + # Find the matching checkpoint + for ckpt in checkpoints_response.checkpoints: + if ckpt.tinker_path == checkpoint_path: + return ckpt + + raise TinkerCliError(f"Checkpoint not found: {checkpoint_path}") + + except ValueError as e: + raise TinkerCliError( + f"Invalid checkpoint path: {e}", + "Checkpoint paths should be in the format: tinker://run-id/weights/0001", + ) + except TinkerCliError: + # Re-raise our own errors + raise + except Exception as e: + raise TinkerCliError(f"Failed to retrieve checkpoint: {e}") + + +# Click command group for checkpoint commands +@click.group() +def cli(): + """Manage checkpoints.""" + pass + + +@cli.command() +@click.option("--run-id", help="Training run ID") +@click.option( + "--limit", + type=int, + default=20, + help="Maximum number of checkpoints to display when listing from all runs (default: 20, use --limit=0 to show all)", +) +@click.pass_obj +@handle_api_errors +def list(cli_context: CLIContext, run_id: str | None, limit: int) -> None: + """List checkpoints. + + If --run-id is provided, list checkpoints for that specific training run. + Otherwise, list checkpoints from all recent runs. + """ + # Get format from context object + format = cli_context.format + + # Create client + client = create_rest_client() + + if run_id: + # List checkpoints for specific run. + # Note that there's no pagination for listing checkpoints on a single training run. + response = client.list_checkpoints(run_id).result() + + # Create output object + output = CheckpointListOutput(checkpoints=response.checkpoints, run_id=run_id) + else: + # List checkpoints from all user's training runs using list_user_checkpoints() + all_checkpoints = [] + offset = 0 + # Fetch in batches of 1000 since the queries are so slow + BATCH_SIZE = 1000 + + # First fetch to get initial data and total count + first_response = client.list_user_checkpoints( + limit=min(BATCH_SIZE, limit) if limit > 0 else BATCH_SIZE, offset=0 + ).result() + all_checkpoints.extend(first_response.checkpoints) + total_count = ( + first_response.cursor.total_count + if first_response.cursor + else len(first_response.checkpoints) + ) + + # Determine target count: either user-specified limit or total available + target_count = limit if limit > 0 else total_count + target_count = min(target_count, total_count) # Can't fetch more than exists + + # If we need to fetch more checkpoints, paginate with a progress bar + if len(all_checkpoints) < target_count: + with click.progressbar( + length=target_count, + label=f"Fetching {'all' if limit == 0 else str(target_count)} checkpoints", + show_percent=True, + show_pos=True, + show_eta=True, + ) as bar: + bar.update(len(all_checkpoints)) + + # Fetch remaining checkpoints in batches + while len(all_checkpoints) < target_count: + offset = len(all_checkpoints) + remaining = target_count - len(all_checkpoints) + next_batch_size = min(BATCH_SIZE, remaining) + + response = client.list_user_checkpoints( + limit=next_batch_size, offset=offset + ).result() + all_checkpoints.extend(response.checkpoints) + bar.update(len(response.checkpoints)) + + # Break if we got fewer than requested (reached the end) + if len(response.checkpoints) < next_batch_size: + break + + # Create output object with pagination information + output = CheckpointListOutput( + checkpoints=all_checkpoints, total_count=total_count, shown_count=len(all_checkpoints) + ) + + # Print in requested format + output.print(format=format) + + +@cli.command() +@click.argument("checkpoint_path") +@click.pass_obj +@handle_api_errors +def info(cli_context: CLIContext, checkpoint_path: str) -> None: + """Show details of a specific checkpoint. + + CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001). + """ + # Get format from context object + format = cli_context.format + + # Validate it's a tinker path + if not checkpoint_path.startswith("tinker://"): + raise TinkerCliError( + f"Invalid checkpoint path: {checkpoint_path}", + "Checkpoint path must be in the format: tinker://run-id/weights/0001", + ) + + checkpoint = get_checkpoint_from_path(checkpoint_path) + + # Create output object + output = CheckpointInfoOutput(checkpoint=checkpoint) + + # Print in requested format + output.print(format=format) + + +@cli.command() +@click.argument("checkpoint_path") +@click.pass_obj +@handle_api_errors +def publish(cli_context: CLIContext, checkpoint_path: str) -> None: + """Publish a checkpoint to make it publicly accessible. + + CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001). + Only the owner of the training run can publish checkpoints. + """ + # Validate it's a tinker path + if not checkpoint_path.startswith("tinker://"): + raise TinkerCliError( + f"Invalid checkpoint path: {checkpoint_path}", + "Checkpoint path must be in the format: tinker://run-id/weights/0001", + ) + + # Create client and publish + client = create_rest_client() + client.publish_checkpoint_from_tinker_path(checkpoint_path).result() + + +@cli.command() +@click.argument("checkpoint_path") +@click.pass_obj +@handle_api_errors +def unpublish(cli_context: CLIContext, checkpoint_path: str) -> None: + """Unpublish a checkpoint to make it private again. + + CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001). + Only the owner of the training run can unpublish checkpoints. + """ + # Validate it's a tinker path + if not checkpoint_path.startswith("tinker://"): + raise TinkerCliError( + f"Invalid checkpoint path: {checkpoint_path}", + "Checkpoint path must be in the format: tinker://run-id/weights/0001", + ) + + # Create client and unpublish + client = create_rest_client() + client.unpublish_checkpoint_from_tinker_path(checkpoint_path).result() + + +@cli.command() +@click.argument("checkpoint_path") +@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt") +@click.pass_obj +@handle_api_errors +def delete(cli_context: CLIContext, checkpoint_path: str, yes: bool) -> None: + """Delete a checkpoint permanently. + + CHECKPOINT_PATH must be a tinker path (e.g., tinker://run-id/weights/0001). + Only the owner of the training run can delete checkpoints. + + WARNING: This action is permanent and cannot be undone. + """ + # Validate it's a tinker path + if not checkpoint_path.startswith("tinker://"): + raise TinkerCliError( + f"Invalid checkpoint path: {checkpoint_path}", + "Checkpoint path must be in the format: tinker://run-id/weights/0001", + ) + + # Get format from context object + format = cli_context.format + + # If not using --yes, show checkpoint info and prompt for confirmation + if not yes: + try: + checkpoint = get_checkpoint_from_path(checkpoint_path) + + # Display checkpoint info using the same format as 'info' command + output = CheckpointInfoOutput(checkpoint) + output.print(format=format) + click.echo() + + except TinkerCliError: + # If we can't get checkpoint info, still allow deletion attempt + # The API will return appropriate error if checkpoint doesn't exist + click.echo(f"Checkpoint path: {checkpoint_path}") + click.echo() + + # Confirmation prompt + click.echo("WARNING: This action is permanent and cannot be undone.") + if not click.confirm("Are you sure you want to delete this checkpoint?"): + click.echo("Deletion cancelled.") + return + + # Create client and delete + client = create_rest_client() + client.delete_checkpoint_from_tinker_path(checkpoint_path).result() diff --git a/src/tinker/cli/commands/run.py b/src/tinker/cli/commands/run.py new file mode 100644 index 0000000..68badbb --- /dev/null +++ b/src/tinker/cli/commands/run.py @@ -0,0 +1,257 @@ +"""Commands for managing training runs. + +This module implements the 'tinker run' commands, including: +- list: List all training runs +- info: Show details of a specific run +""" + +from typing import Any, Dict, List, TYPE_CHECKING + +import click + +if TYPE_CHECKING: + from tinker.types import TrainingRun + +from ..client import create_rest_client, handle_api_errors +from ..context import CLIContext +from ..output import OutputBase, format_timestamp + + +class RunListOutput(OutputBase): + """Output for 'tinker run list' command.""" + + def __init__( + self, + runs: List["TrainingRun"], + total_count: int | None = None, + shown_count: int | None = None, + ): + """Initialize with list of training runs. + + Args: + runs: List of TrainingRun objects + total_count: Total number of runs available (from cursor) + shown_count: Number of runs shown in this response + """ + self.runs = runs + self.total_count = total_count + self.shown_count = shown_count if shown_count is not None else len(runs) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON output.""" + # Check if these are Pydantic models + if self.runs and hasattr(self.runs[0], "model_dump"): + return {"runs": [run.model_dump() for run in self.runs]} + return {"runs": [dict(run) for run in self.runs]} + + def get_title(self) -> str | None: + """Return title for table output.""" + count = len(self.runs) + if count == 0: + return "No training runs found" + + # Build the base title + if count == 1: + title = "1 training run" + else: + title = f"{count} training runs" + + # Add information about remaining runs if available + if self.total_count is not None and self.total_count > self.shown_count: + remaining = self.total_count - self.shown_count + if remaining == 1: + title += f" (1 more not shown, use --limit to see more)" + else: + title += f" ({remaining} more not shown, use --limit to see more)" + + return title + + def get_table_columns(self) -> List[str]: + """Return column headers for table output.""" + return ["Run ID", "Base Model", "Owner", "LoRA", "Last Update", "Corrupted"] + + def get_table_rows(self) -> List[List[str]]: + """Return rows for table output.""" + rows = [] + for run in self.runs: + # Format LoRA information + if run.is_lora and run.lora_rank: + lora_info = f"Rank {run.lora_rank}" + elif run.is_lora: + lora_info = "Yes" + else: + lora_info = "No" + + rows.append( + [ + run.training_run_id, + run.base_model, + run.model_owner, + lora_info, + format_timestamp(run.last_request_time), + str(run.corrupted), + ] + ) + + return rows + + +class RunInfoOutput(OutputBase): + """Output for 'tinker run info' command.""" + + def __init__(self, run: "TrainingRun"): + """Initialize with a single training run. + + Args: + run: TrainingRun object + """ + self.run = run + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON output.""" + if hasattr(self.run, "model_dump"): + return self.run.model_dump() + return dict(self.run) + + def get_title(self) -> str | None: + """Return title for table output.""" + return f"Training Run: {self.run.training_run_id}" + + def get_table_columns(self) -> List[str]: + """Return column headers for table output.""" + return ["Property", "Value"] + + def get_table_rows(self) -> List[List[str]]: + """Return rows for table output.""" + rows = [ + ["Run ID", self.run.training_run_id], + ["Base Model", self.run.base_model], + ["Owner", self.run.model_owner], + ] + + # LoRA information + if self.run.is_lora: + if self.run.lora_rank: + rows.append(["LoRA", f"Yes (Rank {self.run.lora_rank})"]) + else: + rows.append(["LoRA", "Yes"]) + else: + rows.append(["LoRA", "No"]) + + # Last update time + rows.append(["Last Update", format_timestamp(self.run.last_request_time)]) + + # Corruption status + rows.append(["Status", "Corrupted" if self.run.corrupted else "Active"]) + + # Last checkpoints + if self.run.last_checkpoint: + rows.append(["Last Training Checkpoint", self.run.last_checkpoint.checkpoint_id]) + rows.append([" - Time", format_timestamp(self.run.last_checkpoint.time)]) + rows.append([" - Path", self.run.last_checkpoint.tinker_path]) + + if self.run.last_sampler_checkpoint: + rows.append(["Last Sampler Checkpoint", self.run.last_sampler_checkpoint.checkpoint_id]) + rows.append([" - Time", format_timestamp(self.run.last_sampler_checkpoint.time)]) + rows.append([" - Path", self.run.last_sampler_checkpoint.tinker_path]) + + # User metadata if present + if self.run.user_metadata: + rows.append(["Metadata", ""]) + for key, value in self.run.user_metadata.items(): + rows.append([f" - {key}", value]) + + return rows + + +# Click command group for run commands +@click.group() +def cli(): + """Manage training runs.""" + pass + + +@cli.command() +@click.option( + "--limit", + type=int, + default=20, + help="Maximum number of runs to fetch (default: 20, use --limit=0 to fetch all)", +) +@click.pass_obj +@handle_api_errors +def list(cli_context: CLIContext, limit: int) -> None: + """List all training runs.""" + # Get format from context object + format = cli_context.format + + # Create client + client = create_rest_client() + + all_runs = [] + offset = 0 + batch_size = 100 # Fetch in batches of 100 for efficiency + + # First fetch to get initial data and total count + first_response = client.list_training_runs( + limit=min(batch_size, limit) if limit > 0 else batch_size, offset=0 + ).result() + all_runs.extend(first_response.training_runs) + total_count = first_response.cursor.total_count + + # Determine target count: either user-specified limit or total available + target_count = limit if limit > 0 else total_count + target_count = min(target_count, total_count) # Can't fetch more than exists + + # If we need to fetch more runs, paginate with a progress bar + if len(all_runs) < target_count: + with click.progressbar( + length=target_count, + label=f"Fetching {'all' if limit == 0 else str(target_count)} training runs", + show_percent=True, + show_pos=True, + show_eta=True, + ) as bar: + bar.update(len(all_runs)) + + # Fetch remaining runs in batches + while len(all_runs) < target_count: + offset = len(all_runs) + remaining = target_count - len(all_runs) + next_batch_size = min(batch_size, remaining) + + response = client.list_training_runs(limit=next_batch_size, offset=offset).result() + all_runs.extend(response.training_runs) + bar.update(len(response.training_runs)) + + # Break if we got fewer than requested (reached the end) + if len(response.training_runs) < next_batch_size: + break + + # Create output object with pagination information + output = RunListOutput(runs=all_runs, total_count=total_count, shown_count=len(all_runs)) + + # Print in requested format + output.print(format=format) + + +@cli.command() +@click.argument("run_id") +@click.pass_obj +@handle_api_errors +def info(cli_context: CLIContext, run_id: str) -> None: + """Show details of a specific run.""" + # Get format from context object + format = cli_context.format + + # Create client + client = create_rest_client() + + # Fetch training run details + response = client.get_training_run(run_id).result() + + # Create output object + output = RunInfoOutput(run=response) + + # Print in requested format + output.print(format=format) diff --git a/src/tinker/cli/commands/version.py b/src/tinker/cli/commands/version.py new file mode 100644 index 0000000..275f23b --- /dev/null +++ b/src/tinker/cli/commands/version.py @@ -0,0 +1,18 @@ +"""Command for showing version information. + +This module implements the 'tinker version' command. +""" + +import click + + +@click.command() +def cli(): + """Show version information.""" + try: + # Lazy import version only when needed + from tinker._version import __version__ + + click.echo(f"tinker {__version__}") + except ImportError: + click.echo("tinker (version unavailable)") diff --git a/src/tinker/cli/context.py b/src/tinker/cli/context.py new file mode 100644 index 0000000..f1fa867 --- /dev/null +++ b/src/tinker/cli/context.py @@ -0,0 +1,23 @@ +"""Context object for the Tinker CLI. + +This module provides a dataclass for sharing configuration and state +between CLI commands. +""" + +from dataclasses import dataclass +from typing import Literal + + +@dataclass +class CLIContext: + """Context object for sharing state between CLI commands. + + This dataclass is passed through the Click command hierarchy + using the @click.pass_obj decorator, allowing commands to access + shared configuration without needing to traverse the context tree. + + Attributes: + format: Output format for command results ('table' or 'json') + """ + + format: Literal["table", "json"] = "table" diff --git a/src/tinker/cli/exceptions.py b/src/tinker/cli/exceptions.py new file mode 100644 index 0000000..94805dd --- /dev/null +++ b/src/tinker/cli/exceptions.py @@ -0,0 +1,31 @@ +"""Custom exceptions for the Tinker CLI. + +This module defines exceptions used throughout the CLI for consistent +error handling and graceful exits. +""" + + +class TinkerCliError(Exception): + """Custom exception for CLI errors that should exit gracefully. + + This exception is caught at the top level of the CLI and converted + to appropriate error messages and exit codes. + + Attributes: + message: The main error message to display + details: Optional additional details or suggestions + exit_code: The exit code to use (default: 1) + """ + + def __init__(self, message: str, details: str | None = None, exit_code: int = 1): + """Initialize a TinkerCliError. + + Args: + message: The main error message (will be prefixed with "Error: ") + details: Optional additional details or help text + exit_code: The exit code to use when exiting (default: 1) + """ + self.message = message + self.details = details + self.exit_code = exit_code + super().__init__(message) diff --git a/src/tinker/cli/lazy_group.py b/src/tinker/cli/lazy_group.py new file mode 100644 index 0000000..715fcf4 --- /dev/null +++ b/src/tinker/cli/lazy_group.py @@ -0,0 +1,89 @@ +"""Lazy loading support for Click commands. + +This module provides a LazyGroup class that extends Click's Group to support +lazy loading of subcommands, ensuring fast CLI startup times. +""" + +import importlib +from typing import Any, Dict, List +import click + + +class LazyGroup(click.Group): + """A Click Group that supports lazy loading of subcommands. + + This allows the CLI to have fast startup times by only importing + command modules when they are actually invoked, not when the CLI + is first loaded or when help is displayed. + """ + + def __init__( + self, *args: Any, lazy_subcommands: Dict[str, str] | None = None, **kwargs: Any + ) -> None: + """Initialize the LazyGroup. + + Args: + lazy_subcommands: A dictionary mapping command names to import paths. + Format: {"command": "module.path:attribute_name"} + """ + super().__init__(*args, **kwargs) + self.lazy_subcommands = lazy_subcommands or {} + + def list_commands(self, ctx: click.Context) -> List[str]: + """Return a list of all command names. + + This includes both eagerly loaded commands and lazy commands. + """ + # Get any eagerly loaded commands + base = super().list_commands(ctx) + # Add lazy command names + lazy = sorted(self.lazy_subcommands.keys()) + return base + lazy + + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: + """Get a command by name, loading it lazily if necessary. + + Args: + ctx: The Click context + cmd_name: The name of the command to retrieve + + Returns: + The Click command object, or None if not found + """ + # Check if it's a lazy command + if cmd_name in self.lazy_subcommands: + return self._lazy_load(cmd_name) + # Fall back to normal command loading + return super().get_command(ctx, cmd_name) + + def _lazy_load(self, cmd_name: str) -> click.Command: + """Lazily load a command by importing its module. + + Args: + cmd_name: The name of the command to load + + Returns: + The loaded Click command object + + Raises: + ValueError: If the imported object is not a Click Command + """ + # Get the import path for this command + import_path = self.lazy_subcommands[cmd_name] + + # Split into module path and attribute name + module_name, attr_name = import_path.rsplit(":", 1) + + # Import the module + mod = importlib.import_module(module_name) + + # Get the command object + cmd_object = getattr(mod, attr_name) + + # Verify it's a Click command + if not isinstance(cmd_object, click.Command): + raise ValueError( + f"Lazy loading of {import_path} failed: '{attr_name}' is not a Click Command" + ) + + return cmd_object diff --git a/src/tinker/cli/output.py b/src/tinker/cli/output.py new file mode 100644 index 0000000..694e511 --- /dev/null +++ b/src/tinker/cli/output.py @@ -0,0 +1,227 @@ +"""Output formatting utilities for the Tinker CLI. + +This module provides a base class for structured output that can be +rendered as either a table (using rich) or as JSON. Each command +defines its own output class that inherits from OutputBase. +""" + +import sys +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Callable, Union + + +class OutputBase(ABC): + """Virtual base class for all command outputs. + + Subclasses must implement methods to convert data to various formats. + The base class provides the common print() method that handles + format selection and rendering. + """ + + @abstractmethod + def to_dict(self) -> Dict[str, Any]: + """Convert output to dictionary for JSON serialization. + + Returns: + Dictionary representation of the output data + """ + pass + + @abstractmethod + def get_table_columns(self) -> List[str]: + """Return list of column names for table output. + + Returns: + List of column header strings + """ + pass + + @abstractmethod + def get_table_rows(self) -> List[List[str]]: + """Return list of rows for table output. + + Each row should be a list of string values corresponding + to the columns returned by get_table_columns(). + + Returns: + List of rows, where each row is a list of string values + """ + pass + + def get_title(self) -> str | None: + """Optional title for the output display. + + Override this method to provide a title for the table. + + Returns: + Title string or None + """ + return None + + def print(self, format: str = "table") -> None: + """Print the output in the specified format. + + Args: + format: Output format - either "table" or "json" + """ + if format == "json": + self._print_json() + else: + self._print_table() + + def _print_json(self) -> None: + """Print output as JSON.""" + import json + + data = self.to_dict() + json.dump(data, sys.stdout, indent=2, default=str) + print() # Add newline after JSON output + + def _print_table(self) -> None: + """Print output as a rich table.""" + # Lazy import rich to avoid slow startup + from rich.console import Console + from rich.table import Table + + console = Console() + + # Create table with optional title + title = self.get_title() + table = Table(title=title) if title else Table() + + # Add columns + columns = self.get_table_columns() + for col in columns: + # First column (usually ID) gets special styling + if col == columns[0]: + table.add_column(col, style="bright_cyan", no_wrap=True) + else: + table.add_column(col) + + # Add rows + rows = self.get_table_rows() + for row in rows: + table.add_row(*row) + + # Print the table + console.print(table) + + +# Utility formatting functions + + +def format_size(bytes: int) -> str: + """Format bytes as human-readable size. + + Args: + bytes: Size in bytes + + Returns: + Human-readable size string (e.g., "1.2 GB") + """ + if bytes < 0: + return "N/A" + + size = float(bytes) + for unit in ["B", "KB", "MB", "GB", "TB", "PB"]: + if size < 1024.0: + if unit == "B": + return f"{int(size)} {unit}" + return f"{size:.1f} {unit}" + size /= 1024.0 + + return f"{size:.1f} EB" + + +def format_timestamp(dt: Union[datetime, str, None]) -> str: + """Format datetime as relative time or absolute date. + + Args: + dt: datetime object, ISO string, or None + + Returns: + Formatted time string (e.g., "2 hours ago", "2024-01-15") + """ + if not dt: + return "N/A" + + # Lazy import datetime + from datetime import datetime, timezone + + # Handle different input types + if isinstance(dt, str): + # Try to parse ISO format string + try: + from datetime import datetime + + dt = datetime.fromisoformat(dt.replace("Z", "+00:00")) + except (ValueError, AttributeError): + return str(dt) + + if not hasattr(dt, "replace"): + # Not a datetime object + return str(dt) + + try: + # Get current time + now = datetime.now(timezone.utc) + + # Ensure dt has timezone info + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + else: + dt = dt.astimezone(timezone.utc) + + # Calculate time difference + delta = now - dt + + # Format based on age + if delta.days > 30: + return dt.strftime("%Y-%m-%d") + elif delta.days > 7: + weeks = delta.days // 7 + return f"{weeks} week{'s' if weeks > 1 else ''} ago" + elif delta.days > 0: + return f"{delta.days} day{'s' if delta.days > 1 else ''} ago" + elif delta.seconds > 3600: + hours = delta.seconds // 3600 + return f"{hours} hour{'s' if hours > 1 else ''} ago" + elif delta.seconds > 60: + minutes = delta.seconds // 60 + return f"{minutes} minute{'s' if minutes > 1 else ''} ago" + else: + return "just now" + + except Exception: + # If any error occurs, just return string representation + return str(dt) + + +def format_bool(value: bool) -> str: + """Format boolean for display. + + Args: + value: Boolean value + + Returns: + "Yes" or "No" + """ + return "Yes" if value else "No" + + +def format_optional(value: Any, formatter: Callable[[Any], str] | None = None) -> str: + """Format an optional value. + + Args: + value: Value to format (may be None) + formatter: Optional formatting function to apply if value is not None + + Returns: + Formatted string or "N/A" if value is None + """ + if value is None: + return "N/A" + if formatter: + return formatter(value) + return str(value) diff --git a/src/tinker/lib/public_interfaces/rest_client.py b/src/tinker/lib/public_interfaces/rest_client.py index 503df79..6716302 100644 --- a/src/tinker/lib/public_interfaces/rest_client.py +++ b/src/tinker/lib/public_interfaces/rest_client.py @@ -31,9 +31,12 @@ class RestClient(TelemetryProvider): Key methods: - list_checkpoints() - list available model checkpoints (both training and sampler) + - list_user_checkpoints() - list all checkpoints across all user's training runs - get_training_run() - get model information and metadata as ModelEntry - delete_checkpoint() - delete an existing checkpoint for a training run - get_checkpoint_archive_url() - get signed URL to download checkpoint archive + - publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public + - unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private Args: holder: Internal client managing HTTP connections and async operations @@ -420,3 +423,227 @@ class RestClient(TelemetryProvider): return await self._get_checkpoint_archive_url_submit( parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id ) + + def _publish_checkpoint_submit( + self, training_run_id: types.ModelID, checkpoint_id: str + ) -> AwaitableConcurrentFuture[None]: + """Internal method to submit publish checkpoint request.""" + + async def _publish_checkpoint_async() -> None: + async def _send_request() -> None: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + await client.post( + f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish", + cast_to=NoneType, + ) + + return await self.holder.execute_with_retries(_send_request) + + return self.holder.run_coroutine_threadsafe(_publish_checkpoint_async()) + + @sync_only + @capture_exceptions(fatal=True) + def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]: + """Publish a checkpoint referenced by a tinker path to make it publicly accessible. + + Only the exact owner of the training run can publish checkpoints. + Published checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path method. + + Args: + tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") + + Returns: + A Future that completes when the checkpoint is published + + Raises: + HTTPException: 400 if checkpoint identifier is invalid + HTTPException: 404 if checkpoint not found or user doesn't own the training run + HTTPException: 409 if checkpoint is already public + HTTPException: 500 if there's an error publishing the checkpoint + + Example: + >>> future = rest_client.publish_checkpoint_from_tinker_path("tinker://run-id/weights/0001") + >>> future.result() # Wait for completion + >>> print("Checkpoint published successfully") + """ + parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) + return self._publish_checkpoint_submit( + parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + ).future() + + @capture_exceptions(fatal=True) + async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None: + """Async version of publish_checkpoint_from_tinker_path. + + Only the exact owner of the training run can publish checkpoints. + Published checkpoints can be unpublished using the unpublish_checkpoint_from_tinker_path_async method. + + Args: + tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") + + Raises: + HTTPException: 400 if checkpoint identifier is invalid + HTTPException: 404 if checkpoint not found or user doesn't own the training run + HTTPException: 409 if checkpoint is already public + HTTPException: 500 if there's an error publishing the checkpoint + + Example: + >>> await rest_client.publish_checkpoint_from_tinker_path_async("tinker://run-id/weights/0001") + >>> print("Checkpoint published successfully") + """ + parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) + await self._publish_checkpoint_submit( + parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + ) + + def _unpublish_checkpoint_submit( + self, training_run_id: types.ModelID, checkpoint_id: str + ) -> AwaitableConcurrentFuture[None]: + """Internal method to submit unpublish checkpoint request.""" + + async def _unpublish_checkpoint_async() -> None: + async def _send_request() -> None: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + await client.delete( + f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish", + cast_to=NoneType, + ) + + return await self.holder.execute_with_retries(_send_request) + + return self.holder.run_coroutine_threadsafe(_unpublish_checkpoint_async()) + + @sync_only + @capture_exceptions(fatal=True) + def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]: + """Unpublish a checkpoint referenced by a tinker path to make it private again. + + Only the exact owner of the training run can unpublish checkpoints. + This reverses the effect of publishing a checkpoint. + + Args: + tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") + + Returns: + A Future that completes when the checkpoint is unpublished + + Raises: + HTTPException: 400 if checkpoint identifier is invalid + HTTPException: 404 if checkpoint not found or user doesn't own the training run + HTTPException: 409 if checkpoint is already private + HTTPException: 500 if there's an error unpublishing the checkpoint + + Example: + >>> future = rest_client.unpublish_checkpoint_from_tinker_path("tinker://run-id/weights/0001") + >>> future.result() # Wait for completion + >>> print("Checkpoint unpublished successfully") + """ + parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) + return self._unpublish_checkpoint_submit( + parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + ).future() + + @capture_exceptions(fatal=True) + async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None: + """Async version of unpublish_checkpoint_from_tinker_path. + + Only the exact owner of the training run can unpublish checkpoints. + This reverses the effect of publishing a checkpoint. + + Args: + tinker_path: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001") + + Raises: + HTTPException: 400 if checkpoint identifier is invalid + HTTPException: 404 if checkpoint not found or user doesn't own the training run + HTTPException: 409 if checkpoint is already private + HTTPException: 500 if there's an error unpublishing the checkpoint + + Example: + >>> await rest_client.unpublish_checkpoint_from_tinker_path_async("tinker://run-id/weights/0001") + >>> print("Checkpoint unpublished successfully") + """ + parsed_tinker_path = types.ParsedCheckpointTinkerPath.from_tinker_path(tinker_path) + await self._unpublish_checkpoint_submit( + parsed_tinker_path.training_run_id, parsed_tinker_path.checkpoint_id + ) + + def _list_user_checkpoints_submit( + self, limit: int = 100, offset: int = 0 + ) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]: + """Internal method to submit list user checkpoints request.""" + + async def _list_user_checkpoints_async() -> types.CheckpointsListResponse: + async def _send_request() -> types.CheckpointsListResponse: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + params: dict[str, object] = {"limit": limit, "offset": offset} + + return await client.get( + "/api/v1/checkpoints", + options={"params": params}, + cast_to=types.CheckpointsListResponse, + ) + + return await self.holder.execute_with_retries(_send_request) + + return self.holder.run_coroutine_threadsafe(_list_user_checkpoints_async()) + + @sync_only + @capture_exceptions(fatal=True) + def list_user_checkpoints( + self, limit: int = 100, offset: int = 0 + ) -> ConcurrentFuture[types.CheckpointsListResponse]: + """List all checkpoints for the current user across all their training runs. + + This method retrieves checkpoints from all training runs owned by the authenticated user, + sorted by time (newest first). It supports pagination for efficiently handling large + numbers of checkpoints. + + Args: + limit: Maximum number of checkpoints to return (default 100) + offset: Offset for pagination (default 0) + + Returns: + A Future containing the CheckpointsListResponse with checkpoints and cursor info + + Example: + >>> future = rest_client.list_user_checkpoints(limit=50) + >>> response = future.result() + >>> print(f"Found {len(response.checkpoints)} checkpoints") + >>> print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}") + >>> for checkpoint in response.checkpoints: + ... print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}") + >>> # Get next page if there are more checkpoints + >>> if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count: + ... next_page = rest_client.list_user_checkpoints(limit=50, offset=50) + """ + return self._list_user_checkpoints_submit(limit, offset).future() + + @capture_exceptions(fatal=True) + async def list_user_checkpoints_async( + self, limit: int = 100, offset: int = 0 + ) -> types.CheckpointsListResponse: + """Async version of list_user_checkpoints. + + This method retrieves checkpoints from all training runs owned by the authenticated user, + sorted by time (newest first). It supports pagination for efficiently handling large + numbers of checkpoints. + + Args: + limit: Maximum number of checkpoints to return (default 100) + offset: Offset for pagination (default 0) + + Returns: + CheckpointsListResponse with checkpoints and cursor info + + Example: + >>> response = await rest_client.list_user_checkpoints_async(limit=50) + >>> print(f"Found {len(response.checkpoints)} checkpoints") + >>> print(f"Total: {response.cursor.total_count if response.cursor else 'Unknown'}") + >>> for checkpoint in response.checkpoints: + ... print(f" {checkpoint.training_run_id}/{checkpoint.checkpoint_id}") + >>> # Get next page if there are more checkpoints + >>> if response.cursor and response.cursor.offset + response.cursor.limit < response.cursor.total_count: + ... next_page = await rest_client.list_user_checkpoints_async(limit=50, offset=50) + """ + return await self._list_user_checkpoints_submit(limit, offset) diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index c8a163e..a35d2bc 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -593,6 +593,7 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre assert model_name is not None, "This shouldn't happen: model_name is None" # Use tokenizer_id from get_info if available, otherwise fall back to heuristic logic + kwargs = {} tokenizer_id = info.model_data.tokenizer_id if tokenizer_id is None: # We generally adhere to the huggingface convention of "/" but @@ -608,4 +609,10 @@ def _get_tokenizer(model_id: types.ModelID, holder: InternalClientHolder) -> Pre else: tokenizer_id = model_name - return AutoTokenizer.from_pretrained(tokenizer_id, fast=True) + if tokenizer_id == "moonshotai/Kimi-K2-Thinking": + kwargs = { + "trust_remote_code": True, + "revision": "612681931a8c906ddb349f8ad0f582cb552189cd", + } + + return AutoTokenizer.from_pretrained(tokenizer_id, fast=True, **kwargs) diff --git a/src/tinker/resources/weights.py b/src/tinker/resources/weights.py index 9cd8dca..17c6480 100644 --- a/src/tinker/resources/weights.py +++ b/src/tinker/resources/weights.py @@ -185,7 +185,7 @@ class AsyncWeightsResource(AsyncAPIResource): if not model_id: raise ValueError(f"Expected a non-empty value for `model_id` but received {model_id!r}") return await self._get( - f"/api/v1/models/{model_id}/checkpoints", + f"/api/v1/training_runs/{model_id}/checkpoints", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, diff --git a/src/tinker/types/checkpoint.py b/src/tinker/types/checkpoint.py index 7dd0e98..70e3973 100644 --- a/src/tinker/types/checkpoint.py +++ b/src/tinker/types/checkpoint.py @@ -21,6 +21,12 @@ class Checkpoint(BaseModel): tinker_path: str """The tinker path to the checkpoint""" + size_bytes: int | None = None + """The size of the checkpoint in bytes""" + + public: bool = False + """Whether the checkpoint is publicly accessible""" + class ParsedCheckpointTinkerPath(BaseModel): tinker_path: str diff --git a/src/tinker/types/checkpoints_list_response.py b/src/tinker/types/checkpoints_list_response.py index 3ab9a5b..67471bb 100644 --- a/src/tinker/types/checkpoints_list_response.py +++ b/src/tinker/types/checkpoints_list_response.py @@ -1,5 +1,6 @@ from .._models import BaseModel from .checkpoint import Checkpoint +from .cursor import Cursor __all__ = ["CheckpointsListResponse"] @@ -7,3 +8,6 @@ __all__ = ["CheckpointsListResponse"] class CheckpointsListResponse(BaseModel): checkpoints: list[Checkpoint] """List of available model checkpoints for the model""" + + cursor: Cursor | None = None + """Pagination cursor information (None for unpaginated responses)"""