mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
07bd3c2dd3
commit
30517b667f
33 changed files with 1272 additions and 371 deletions
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "db025e90079a19c36090a13aa88e4b2494d5a502",
|
||||
"last_sync_time": "2026-03-19T02:39:30.785199"
|
||||
"last_synced_sha": "d117d1692821faa297ea5d2ee7e4dc21b5c8bd0a",
|
||||
"last_sync_time": "2026-04-14T00:00:48.831738"
|
||||
}
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.16.1"
|
||||
version = "0.18.0"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@ -428,7 +428,9 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
|
|||
retries_taken: int = 0,
|
||||
) -> httpx.Request:
|
||||
if log.isEnabledFor(logging.DEBUG):
|
||||
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
|
||||
log.debug(
|
||||
"Request options: %s", model_dump(options, exclude_unset=False, exclude_none=True)
|
||||
)
|
||||
|
||||
kwargs: dict[str, Any] = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from ._base_client import (
|
|||
AsyncAPIClient,
|
||||
)
|
||||
from ._compat import cached_property
|
||||
from ._exceptions import APIStatusError, TinkerError
|
||||
from ._exceptions import APIStatusError
|
||||
from ._qs import Querystring
|
||||
from ._streaming import AsyncStream as AsyncStream
|
||||
from ._types import (
|
||||
|
|
@ -26,6 +26,7 @@ from ._types import (
|
|||
)
|
||||
from ._utils import get_async_library, is_given
|
||||
from ._version import __version__
|
||||
from .lib._auth_token_provider import ApiKeyAuthProvider, AuthTokenProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .resources import futures, telemetry
|
||||
|
|
@ -47,9 +48,6 @@ __all__ = [
|
|||
|
||||
|
||||
class AsyncTinker(AsyncAPIClient):
|
||||
# client options
|
||||
api_key: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -72,20 +70,16 @@ class AsyncTinker(AsyncAPIClient):
|
|||
# outlining your use-case to help us decide if it should be
|
||||
# part of our public interface in the future.
|
||||
_strict_response_validation: bool = False,
|
||||
_auth: AuthTokenProvider | None = None,
|
||||
) -> None:
|
||||
"""Construct a new async AsyncTinker client instance.
|
||||
|
||||
This automatically infers the `api_key` argument from the `TINKER_API_KEY` environment variable if it is not provided.
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("TINKER_API_KEY")
|
||||
if api_key is None:
|
||||
raise TinkerError(
|
||||
"The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable"
|
||||
)
|
||||
if not api_key.startswith("tml-"):
|
||||
raise TinkerError("The api_key must start with the 'tml-' prefix")
|
||||
self.api_key = api_key
|
||||
if _auth is not None:
|
||||
self._auth = _auth
|
||||
else:
|
||||
self._auth = ApiKeyAuthProvider(api_key=api_key)
|
||||
|
||||
if base_url is None:
|
||||
base_url = os.environ.get("TINKER_BASE_URL")
|
||||
|
|
@ -158,9 +152,8 @@ class AsyncTinker(AsyncAPIClient):
|
|||
|
||||
@property
|
||||
@override
|
||||
def auth_headers(self) -> dict[str, str]:
|
||||
api_key = self.api_key
|
||||
return {"X-API-Key": api_key}
|
||||
def custom_auth(self) -> AuthTokenProvider:
|
||||
return self._auth
|
||||
|
||||
@property
|
||||
@override
|
||||
|
|
@ -174,7 +167,6 @@ class AsyncTinker(AsyncAPIClient):
|
|||
def copy(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
|
|
@ -212,7 +204,7 @@ class AsyncTinker(AsyncAPIClient):
|
|||
|
||||
http_client = http_client or self._client
|
||||
return self.__class__(
|
||||
api_key=api_key or self.api_key,
|
||||
_auth=self._auth,
|
||||
base_url=base_url or self.base_url,
|
||||
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
|
||||
http_client=http_client,
|
||||
|
|
|
|||
|
|
@ -136,6 +136,7 @@ def model_dump(
|
|||
exclude: IncEx | None = None,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
warnings: bool = True,
|
||||
mode: Literal["json", "python"] = "python",
|
||||
) -> dict[str, Any]:
|
||||
|
|
@ -145,6 +146,7 @@ def model_dump(
|
|||
exclude=exclude,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
# warnings are not supported in Pydantic v1
|
||||
warnings=warnings if PYDANTIC_V2 else True,
|
||||
)
|
||||
|
|
@ -154,6 +156,7 @@ def model_dump(
|
|||
exclude=exclude,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from typing import (
|
|||
|
||||
import anyio
|
||||
import httpx
|
||||
import orjson
|
||||
import pydantic
|
||||
from typing_extensions import Awaitable, ParamSpec, get_origin, override
|
||||
|
||||
|
|
@ -351,7 +352,7 @@ class APIResponse(BaseAPIResponse[R]):
|
|||
def json(self) -> object:
|
||||
"""Read and decode the JSON response content."""
|
||||
self.read()
|
||||
return self.http_response.json()
|
||||
return orjson.loads(self.http_response.content)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the response and release the connection.
|
||||
|
|
@ -451,7 +452,7 @@ class AsyncAPIResponse(BaseAPIResponse[R]):
|
|||
async def json(self) -> object:
|
||||
"""Read and decode the JSON response content."""
|
||||
await self.read()
|
||||
return self.http_response.json()
|
||||
return orjson.loads(self.http_response.content)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the response and release the connection.
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ def _transform_recursive(
|
|||
return data
|
||||
|
||||
if isinstance(data, pydantic.BaseModel):
|
||||
return model_dump(data, exclude_unset=True, mode="json")
|
||||
return model_dump(data, exclude_unset=False, exclude_none=True, mode="json")
|
||||
|
||||
annotated_type = _get_annotated_type(annotation)
|
||||
if annotated_type is None:
|
||||
|
|
@ -382,7 +382,7 @@ async def _async_transform_recursive(
|
|||
return data
|
||||
|
||||
if isinstance(data, pydantic.BaseModel):
|
||||
return model_dump(data, exclude_unset=True, mode="json")
|
||||
return model_dump(data, exclude_unset=False, exclude_none=True, mode="json")
|
||||
|
||||
annotated_type = _get_annotated_type(annotation)
|
||||
if annotated_type is None:
|
||||
|
|
|
|||
285
src/tinker/cli/AGENTS.md
Normal file
285
src/tinker/cli/AGENTS.md
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
# Tinker CLI Design Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance.
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### 1. Lazy Import Strategy with Click
|
||||
|
||||
**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level.
|
||||
|
||||
**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands.
|
||||
|
||||
**Implementation**:
|
||||
- Main `__init__.py` only imports `click` and `lazy_group`
|
||||
- Command modules are loaded only when invoked via `LazyGroup`
|
||||
- Output formatting imports `rich` only when table output is needed
|
||||
- JSON module imported only when JSON output is requested
|
||||
- Version information loaded from `_version.py` only when `tinker version` is used
|
||||
|
||||
### 2. Click Framework with LazyGroup
|
||||
|
||||
**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading.
|
||||
|
||||
**Rationale**:
|
||||
- Click provides cleaner command structure with decorators
|
||||
- Better subcommand isolation - each command file is self-contained
|
||||
- Automatic help generation with better formatting
|
||||
- Built-in type conversion and validation
|
||||
- LazyGroup enables fast startup by deferring imports
|
||||
|
||||
**LazyGroup Implementation**:
|
||||
```python
|
||||
class LazyGroup(click.Group):
|
||||
def __init__(self, *args, lazy_subcommands=None, **kwargs):
|
||||
# Map of command name to "module.path:command_name"
|
||||
self.lazy_subcommands = lazy_subcommands or {}
|
||||
|
||||
def get_command(self, ctx, cmd_name):
|
||||
if cmd_name in self.lazy_subcommands:
|
||||
# Import only when command is actually invoked
|
||||
import_path = self.lazy_subcommands[cmd_name]
|
||||
module_name, attr_name = import_path.rsplit(":", 1)
|
||||
mod = importlib.import_module(module_name)
|
||||
return getattr(mod, attr_name)
|
||||
```
|
||||
|
||||
### 3. Hierarchical Command Structure
|
||||
|
||||
**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`.
|
||||
|
||||
**Rationale**:
|
||||
- Provides a consistent, predictable interface
|
||||
- Groups related functionality together
|
||||
- Makes the CLI extensible for future commands
|
||||
- Follows common CLI patterns (like `git`, `docker`, etc.)
|
||||
|
||||
**Examples**:
|
||||
- `tinker version` - Show CLI and SDK version
|
||||
- `tinker run list` - List all training runs
|
||||
- `tinker run info <run-id>` - Show details of a specific run
|
||||
- `tinker checkpoint list` - List all checkpoints
|
||||
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
|
||||
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
|
||||
|
||||
### 4. Output System with Inheritance
|
||||
|
||||
**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class.
|
||||
|
||||
**Rationale**:
|
||||
- Enforces consistent interface across all commands
|
||||
- Encapsulates output logic with the command that generates it
|
||||
- Makes it easy to support multiple output formats (table, JSON)
|
||||
- Keeps related code together in the same module
|
||||
|
||||
**Implementation**:
|
||||
- `OutputBase` in `output.py` defines the contract
|
||||
- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`)
|
||||
- Base class handles format selection and rendering
|
||||
|
||||
### 5. Self-Contained Command Modules
|
||||
|
||||
**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point.
|
||||
|
||||
**Rationale**:
|
||||
- Modular architecture - commands can be developed independently
|
||||
- Clear separation of concerns
|
||||
- Easy to add new commands without modifying core files
|
||||
- Consistent pattern across all commands
|
||||
|
||||
**Command Structure**:
|
||||
```python
|
||||
# Each command file follows this pattern:
|
||||
@click.group() # or @click.command() for simple commands
|
||||
def cli():
|
||||
"""Command description."""
|
||||
pass
|
||||
|
||||
@cli.command() # For subcommands
|
||||
def list():
|
||||
"""Subcommand implementation."""
|
||||
pass
|
||||
```
|
||||
|
||||
### 6. Centralized Client Management
|
||||
|
||||
**Decision**: All SDK client creation and error handling is centralized in `client.py`.
|
||||
|
||||
**Rationale**:
|
||||
- Single place to handle authentication and connection errors
|
||||
- Consistent error messages across all commands
|
||||
- Reusable error handling decorator
|
||||
- Clean separation of concerns
|
||||
|
||||
### 7. Rich Tables for Human-Readable Output
|
||||
|
||||
**Decision**: Use the `rich` library for table formatting, kept as an optional dependency.
|
||||
|
||||
**Rationale**:
|
||||
- Provides beautiful, formatted tables with colors and borders
|
||||
- Handles column width adjustment automatically
|
||||
- Supports both dark and light terminal themes
|
||||
- Optional dependency keeps the core package lightweight
|
||||
|
||||
### 8. Unix-Style Default Output
|
||||
|
||||
**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output.
|
||||
|
||||
**Rationale**:
|
||||
- Follows Unix philosophy
|
||||
- Tables are better for human consumption
|
||||
- JSON is better for scripting and automation
|
||||
- Single flag switches between formats consistently
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
1. **LazyGroup for deferred imports** - Commands only loaded when invoked
|
||||
2. **No heavy imports at module level** - Only Click imported initially
|
||||
3. **Lazy loading** of all SDK dependencies
|
||||
4. **Progress indicators** that clear themselves
|
||||
5. **Efficient data fetching** - fetch all data by default instead of pagination
|
||||
|
||||
## Error Handling Strategy
|
||||
|
||||
1. **User-friendly messages** - Technical errors are translated to helpful messages
|
||||
2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes
|
||||
3. **Graceful degradation** - Continue operation when possible
|
||||
4. **Detailed error info** - Show details when available, traceback only in TTY
|
||||
|
||||
### TinkerCliError Exception Pattern
|
||||
|
||||
All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`:
|
||||
|
||||
```python
|
||||
from ..exceptions import TinkerCliError
|
||||
|
||||
# Instead of:
|
||||
print(f"Error: Something went wrong", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Use:
|
||||
raise TinkerCliError(
|
||||
"Something went wrong",
|
||||
"Optional details or help text",
|
||||
exit_code=1 # Optional, defaults to 1
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Better testability (can catch exceptions in tests)
|
||||
- Centralized error formatting in `__main__.py`
|
||||
- Consistent exit codes across the CLI
|
||||
- Stack traces preserved for debugging
|
||||
|
||||
**Important Notes:**
|
||||
- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification
|
||||
- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages
|
||||
- The main error handler in `__main__.py` handles printing to stderr and exiting
|
||||
|
||||
## Future Extensibility
|
||||
|
||||
The architecture supports easy addition of:
|
||||
|
||||
### New Commands
|
||||
- Create new module in `commands/` directory
|
||||
- Define output classes in the same module if needed
|
||||
- Add command to lazy_subcommands in `__init__.py`
|
||||
|
||||
### New Subcommands
|
||||
- Add new Click command decorator to existing command module
|
||||
- Define corresponding output class if needed
|
||||
- Subcommands automatically discovered by Click
|
||||
|
||||
### New Output Formats
|
||||
- Override `print()` method in `OutputBase`
|
||||
- Or add new format handling to base class
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
1. **Startup time**: `time tinker --help` should be <50ms
|
||||
2. **Import verification**: Check that modules aren't imported unnecessarily
|
||||
3. **Output formats**: Test both table and JSON output
|
||||
4. **Error cases**: Test with missing auth, invalid IDs, network errors
|
||||
5. **Empty results**: Ensure graceful handling of no data
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
cli/
|
||||
├── __init__.py # Main entry with LazyGroup configuration
|
||||
├── __main__.py # Module execution support
|
||||
├── lazy_group.py # LazyGroup implementation for lazy loading
|
||||
├── output.py # OutputBase class and formatting utilities
|
||||
├── client.py # SDK client creation and error handling
|
||||
├── commands/
|
||||
│ ├── __init__.py # Command module marker
|
||||
│ ├── version.py # Version command
|
||||
│ ├── run.py # Run commands and output classes
|
||||
│ └── checkpoint.py # Checkpoint commands and output classes
|
||||
└── CLAUDE.md # This documentation
|
||||
```
|
||||
|
||||
## Command Examples
|
||||
|
||||
```bash
|
||||
# Show version
|
||||
tinker version
|
||||
|
||||
# List all training runs
|
||||
tinker run list
|
||||
|
||||
# Show run details
|
||||
tinker run info run-abc123
|
||||
|
||||
# List all checkpoints
|
||||
tinker checkpoint list
|
||||
|
||||
# List checkpoints for specific run
|
||||
tinker checkpoint list run-abc123
|
||||
|
||||
# Show checkpoint details
|
||||
tinker checkpoint info ckpt-xyz789
|
||||
|
||||
# Upload checkpoint to Hugging Face Hub
|
||||
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
|
||||
|
||||
# JSON output
|
||||
tinker --format json run list
|
||||
tinker --format json checkpoint list
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Required
|
||||
- Python 3.11+
|
||||
- tinker SDK (main package)
|
||||
- click>=8.0.0 (CLI framework)
|
||||
|
||||
### Optional
|
||||
- `rich` - For table formatting (installed with `pip install tinker[cli]`)
|
||||
|
||||
## Migration from Argparse to Click
|
||||
|
||||
### Key Changes:
|
||||
1. **Command Definition**: Decorators instead of `parser.add_argument()`
|
||||
2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch
|
||||
3. **Context Passing**: Click's context system for sharing format option
|
||||
4. **Error Handling**: Click handles exits and error formatting
|
||||
5. **Help Generation**: Automatic from docstrings and decorators
|
||||
|
||||
### Benefits:
|
||||
- Cleaner, more Pythonic code
|
||||
- Better command organization
|
||||
- Built-in testing utilities
|
||||
- Easier to extend with plugins
|
||||
- More consistent behavior across commands
|
||||
|
||||
## Maintenance Notes
|
||||
|
||||
1. **Keep imports lazy** - Use LazyGroup for all commands
|
||||
2. **Test startup time** - Regularly verify fast startup is maintained
|
||||
3. **Follow Click patterns** - Use decorators and context properly
|
||||
4. **Document changes** - Update this file when making architectural changes
|
||||
5. **Maintain consistency** - All commands should follow the same structure
|
||||
|
|
@ -1,285 +1 @@
|
|||
# Tinker CLI Design Documentation
|
||||
|
||||
## Overview
|
||||
|
||||
The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance.
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### 1. Lazy Import Strategy with Click
|
||||
|
||||
**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level.
|
||||
|
||||
**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands.
|
||||
|
||||
**Implementation**:
|
||||
- Main `__init__.py` only imports `click` and `lazy_group`
|
||||
- Command modules are loaded only when invoked via `LazyGroup`
|
||||
- Output formatting imports `rich` only when table output is needed
|
||||
- JSON module imported only when JSON output is requested
|
||||
- Version information loaded from `_version.py` only when `tinker version` is used
|
||||
|
||||
### 2. Click Framework with LazyGroup
|
||||
|
||||
**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading.
|
||||
|
||||
**Rationale**:
|
||||
- Click provides cleaner command structure with decorators
|
||||
- Better subcommand isolation - each command file is self-contained
|
||||
- Automatic help generation with better formatting
|
||||
- Built-in type conversion and validation
|
||||
- LazyGroup enables fast startup by deferring imports
|
||||
|
||||
**LazyGroup Implementation**:
|
||||
```python
|
||||
class LazyGroup(click.Group):
|
||||
def __init__(self, *args, lazy_subcommands=None, **kwargs):
|
||||
# Map of command name to "module.path:command_name"
|
||||
self.lazy_subcommands = lazy_subcommands or {}
|
||||
|
||||
def get_command(self, ctx, cmd_name):
|
||||
if cmd_name in self.lazy_subcommands:
|
||||
# Import only when command is actually invoked
|
||||
import_path = self.lazy_subcommands[cmd_name]
|
||||
module_name, attr_name = import_path.rsplit(":", 1)
|
||||
mod = importlib.import_module(module_name)
|
||||
return getattr(mod, attr_name)
|
||||
```
|
||||
|
||||
### 3. Hierarchical Command Structure
|
||||
|
||||
**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`.
|
||||
|
||||
**Rationale**:
|
||||
- Provides a consistent, predictable interface
|
||||
- Groups related functionality together
|
||||
- Makes the CLI extensible for future commands
|
||||
- Follows common CLI patterns (like `git`, `docker`, etc.)
|
||||
|
||||
**Examples**:
|
||||
- `tinker version` - Show CLI and SDK version
|
||||
- `tinker run list` - List all training runs
|
||||
- `tinker run info <run-id>` - Show details of a specific run
|
||||
- `tinker checkpoint list` - List all checkpoints
|
||||
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
|
||||
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
|
||||
|
||||
### 4. Output System with Inheritance
|
||||
|
||||
**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class.
|
||||
|
||||
**Rationale**:
|
||||
- Enforces consistent interface across all commands
|
||||
- Encapsulates output logic with the command that generates it
|
||||
- Makes it easy to support multiple output formats (table, JSON)
|
||||
- Keeps related code together in the same module
|
||||
|
||||
**Implementation**:
|
||||
- `OutputBase` in `output.py` defines the contract
|
||||
- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`)
|
||||
- Base class handles format selection and rendering
|
||||
|
||||
### 5. Self-Contained Command Modules
|
||||
|
||||
**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point.
|
||||
|
||||
**Rationale**:
|
||||
- Modular architecture - commands can be developed independently
|
||||
- Clear separation of concerns
|
||||
- Easy to add new commands without modifying core files
|
||||
- Consistent pattern across all commands
|
||||
|
||||
**Command Structure**:
|
||||
```python
|
||||
# Each command file follows this pattern:
|
||||
@click.group() # or @click.command() for simple commands
|
||||
def cli():
|
||||
"""Command description."""
|
||||
pass
|
||||
|
||||
@cli.command() # For subcommands
|
||||
def list():
|
||||
"""Subcommand implementation."""
|
||||
pass
|
||||
```
|
||||
|
||||
### 6. Centralized Client Management
|
||||
|
||||
**Decision**: All SDK client creation and error handling is centralized in `client.py`.
|
||||
|
||||
**Rationale**:
|
||||
- Single place to handle authentication and connection errors
|
||||
- Consistent error messages across all commands
|
||||
- Reusable error handling decorator
|
||||
- Clean separation of concerns
|
||||
|
||||
### 7. Rich Tables for Human-Readable Output
|
||||
|
||||
**Decision**: Use the `rich` library for table formatting, kept as an optional dependency.
|
||||
|
||||
**Rationale**:
|
||||
- Provides beautiful, formatted tables with colors and borders
|
||||
- Handles column width adjustment automatically
|
||||
- Supports both dark and light terminal themes
|
||||
- Optional dependency keeps the core package lightweight
|
||||
|
||||
### 8. Unix-Style Default Output
|
||||
|
||||
**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output.
|
||||
|
||||
**Rationale**:
|
||||
- Follows Unix philosophy
|
||||
- Tables are better for human consumption
|
||||
- JSON is better for scripting and automation
|
||||
- Single flag switches between formats consistently
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
1. **LazyGroup for deferred imports** - Commands only loaded when invoked
|
||||
2. **No heavy imports at module level** - Only Click imported initially
|
||||
3. **Lazy loading** of all SDK dependencies
|
||||
4. **Progress indicators** that clear themselves
|
||||
5. **Efficient data fetching** - fetch all data by default instead of pagination
|
||||
|
||||
## Error Handling Strategy
|
||||
|
||||
1. **User-friendly messages** - Technical errors are translated to helpful messages
|
||||
2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes
|
||||
3. **Graceful degradation** - Continue operation when possible
|
||||
4. **Detailed error info** - Show details when available, traceback only in TTY
|
||||
|
||||
### TinkerCliError Exception Pattern
|
||||
|
||||
All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`:
|
||||
|
||||
```python
|
||||
from ..exceptions import TinkerCliError
|
||||
|
||||
# Instead of:
|
||||
print(f"Error: Something went wrong", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Use:
|
||||
raise TinkerCliError(
|
||||
"Something went wrong",
|
||||
"Optional details or help text",
|
||||
exit_code=1 # Optional, defaults to 1
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Better testability (can catch exceptions in tests)
|
||||
- Centralized error formatting in `__main__.py`
|
||||
- Consistent exit codes across the CLI
|
||||
- Stack traces preserved for debugging
|
||||
|
||||
**Important Notes:**
|
||||
- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification
|
||||
- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages
|
||||
- The main error handler in `__main__.py` handles printing to stderr and exiting
|
||||
|
||||
## Future Extensibility
|
||||
|
||||
The architecture supports easy addition of:
|
||||
|
||||
### New Commands
|
||||
- Create new module in `commands/` directory
|
||||
- Define output classes in the same module if needed
|
||||
- Add command to lazy_subcommands in `__init__.py`
|
||||
|
||||
### New Subcommands
|
||||
- Add new Click command decorator to existing command module
|
||||
- Define corresponding output class if needed
|
||||
- Subcommands automatically discovered by Click
|
||||
|
||||
### New Output Formats
|
||||
- Override `print()` method in `OutputBase`
|
||||
- Or add new format handling to base class
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
1. **Startup time**: `time tinker --help` should be <50ms
|
||||
2. **Import verification**: Check that modules aren't imported unnecessarily
|
||||
3. **Output formats**: Test both table and JSON output
|
||||
4. **Error cases**: Test with missing auth, invalid IDs, network errors
|
||||
5. **Empty results**: Ensure graceful handling of no data
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
cli/
|
||||
├── __init__.py # Main entry with LazyGroup configuration
|
||||
├── __main__.py # Module execution support
|
||||
├── lazy_group.py # LazyGroup implementation for lazy loading
|
||||
├── output.py # OutputBase class and formatting utilities
|
||||
├── client.py # SDK client creation and error handling
|
||||
├── commands/
|
||||
│ ├── __init__.py # Command module marker
|
||||
│ ├── version.py # Version command
|
||||
│ ├── run.py # Run commands and output classes
|
||||
│ └── checkpoint.py # Checkpoint commands and output classes
|
||||
└── CLAUDE.md # This documentation
|
||||
```
|
||||
|
||||
## Command Examples
|
||||
|
||||
```bash
|
||||
# Show version
|
||||
tinker version
|
||||
|
||||
# List all training runs
|
||||
tinker run list
|
||||
|
||||
# Show run details
|
||||
tinker run info run-abc123
|
||||
|
||||
# List all checkpoints
|
||||
tinker checkpoint list
|
||||
|
||||
# List checkpoints for specific run
|
||||
tinker checkpoint list run-abc123
|
||||
|
||||
# Show checkpoint details
|
||||
tinker checkpoint info ckpt-xyz789
|
||||
|
||||
# Upload checkpoint to Hugging Face Hub
|
||||
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
|
||||
|
||||
# JSON output
|
||||
tinker --format json run list
|
||||
tinker --format json checkpoint list
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Required
|
||||
- Python 3.11+
|
||||
- tinker SDK (main package)
|
||||
- click>=8.0.0 (CLI framework)
|
||||
|
||||
### Optional
|
||||
- `rich` - For table formatting (installed with `pip install tinker[cli]`)
|
||||
|
||||
## Migration from Argparse to Click
|
||||
|
||||
### Key Changes:
|
||||
1. **Command Definition**: Decorators instead of `parser.add_argument()`
|
||||
2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch
|
||||
3. **Context Passing**: Click's context system for sharing format option
|
||||
4. **Error Handling**: Click handles exits and error formatting
|
||||
5. **Help Generation**: Automatic from docstrings and decorators
|
||||
|
||||
### Benefits:
|
||||
- Cleaner, more Pythonic code
|
||||
- Better command organization
|
||||
- Built-in testing utilities
|
||||
- Easier to extend with plugins
|
||||
- More consistent behavior across commands
|
||||
|
||||
## Maintenance Notes
|
||||
|
||||
1. **Keep imports lazy** - Use LazyGroup for all commands
|
||||
2. **Test startup time** - Regularly verify fast startup is maintained
|
||||
3. **Follow Click patterns** - Use decorators and context properly
|
||||
4. **Document changes** - Update this file when making architectural changes
|
||||
5. **Maintain consistency** - All commands should follow the same structure
|
||||
@AGENTS.md
|
||||
|
|
|
|||
104
src/tinker/lib/_auth_token_provider.py
Normal file
104
src/tinker/lib/_auth_token_provider.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
"""Authentication credential management for the Tinker SDK.
|
||||
|
||||
Provides composable credential providers that plug into httpx's async auth flow:
|
||||
- AuthTokenProvider: abstract base (httpx.Auth) — subclasses implement get_token()
|
||||
- ApiKeyAuthProvider: resolves from api_key arg or TINKER_API_KEY env var
|
||||
- CredentialCmdAuthProvider: runs a command on every call for fresh credentials
|
||||
- resolve_auth_provider(): factory that picks the right provider
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from tinker._exceptions import TinkerError
|
||||
|
||||
|
||||
class AuthTokenProvider(httpx.Auth):
|
||||
"""Abstract base auth provider. Subclasses implement get_token()."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_token(self) -> str | None: ...
|
||||
|
||||
async def async_auth_flow(
|
||||
self, request: httpx.Request
|
||||
) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
||||
token = await self.get_token()
|
||||
if token:
|
||||
request.headers["X-API-Key"] = token
|
||||
yield request
|
||||
|
||||
|
||||
class ApiKeyAuthProvider(AuthTokenProvider):
|
||||
"""Resolves api_key from constructor arg or TINKER_API_KEY env var."""
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
resolved = api_key or os.environ.get("TINKER_API_KEY")
|
||||
if not resolved:
|
||||
raise TinkerError(
|
||||
"The api_key client option must be set either by passing api_key to the client"
|
||||
" or by setting the TINKER_API_KEY environment variable"
|
||||
)
|
||||
if not resolved.startswith("tml-") and not resolved.startswith("eyJ"):
|
||||
raise TinkerError("The api_key must start with the 'tml-' prefix")
|
||||
self._token = resolved
|
||||
|
||||
async def get_token(self) -> str | None:
|
||||
return self._token
|
||||
|
||||
|
||||
class CredentialCmdAuthProvider(AuthTokenProvider):
|
||||
"""Runs TINKER_CREDENTIAL_CMD on every get_token() call.
|
||||
|
||||
Always produces a fresh credential (e.g. short-lived bearer tokens).
|
||||
Uses async subprocess to avoid blocking the event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, cmd: str) -> None:
|
||||
if not cmd:
|
||||
raise TinkerError(
|
||||
"Your organization requires dynamic credentials — set TINKER_CREDENTIAL_CMD"
|
||||
" to a command that prints a valid credential."
|
||||
)
|
||||
self._cmd = cmd
|
||||
|
||||
async def get_token(self) -> str | None:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
self._cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, _ = await proc.communicate()
|
||||
credential = stdout.decode().strip()
|
||||
if not credential:
|
||||
raise TinkerError("TINKER_CREDENTIAL_CMD returned an empty credential.")
|
||||
return credential
|
||||
|
||||
|
||||
def resolve_auth_provider(api_key: str | None, enforce_cmd: bool) -> AuthTokenProvider:
|
||||
"""Construct the appropriate auth provider based on available credentials.
|
||||
|
||||
- enforce_cmd=True: uses TINKER_CREDENTIAL_CMD, unless the api_key is
|
||||
already a JWT (dynamic credential) — in which case it's used directly.
|
||||
- enforce_cmd=False: tries api_key first, falls back to TINKER_CREDENTIAL_CMD
|
||||
"""
|
||||
credential_cmd = os.environ.get("TINKER_CREDENTIAL_CMD", "")
|
||||
|
||||
# A JWT passed as api_key is already a dynamic credential — use it
|
||||
# directly even when credential_cmd is enforced.
|
||||
resolved = api_key or os.environ.get("TINKER_API_KEY", "")
|
||||
if resolved and resolved.startswith("eyJ"):
|
||||
return ApiKeyAuthProvider(api_key=resolved)
|
||||
|
||||
if enforce_cmd:
|
||||
return CredentialCmdAuthProvider(credential_cmd)
|
||||
|
||||
try:
|
||||
return ApiKeyAuthProvider(api_key=api_key)
|
||||
except TinkerError:
|
||||
if credential_cmd:
|
||||
return CredentialCmdAuthProvider(credential_cmd)
|
||||
raise
|
||||
90
src/tinker/lib/_jwt_auth.py
Normal file
90
src/tinker/lib/_jwt_auth.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""JWT authentication for Tinker SDK.
|
||||
|
||||
Internal to the SDK; not part of the public API.
|
||||
|
||||
When the server sets pjwt_auth_enabled, the SDK exchanges the caller's
|
||||
credential for a short-lived JWT minted by the Tinker server. The JWT is
|
||||
cached and refreshed in the background before it expires, so callers always
|
||||
send a valid token without any per-request overhead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager
|
||||
|
||||
from tinker.lib._auth_token_provider import AuthTokenProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_REFRESH_BEFORE_EXPIRY_SECS = 300 # refresh 5 min before expiry
|
||||
_RETRY_DELAY_SECS = 60
|
||||
|
||||
|
||||
def _jwt_expiry(jwt: str) -> float:
|
||||
"""Return the exp claim of a JWT as a Unix timestamp."""
|
||||
try:
|
||||
payload = jwt.split(".")[1]
|
||||
payload += "=" * (-len(payload) % 4)
|
||||
return float(json.loads(base64.urlsafe_b64decode(payload))["exp"])
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse JWT expiry: {e}") from e
|
||||
|
||||
|
||||
class JwtAuthProvider(AuthTokenProvider):
|
||||
"""AuthTokenProvider that exchanges a credential for a short-lived JWT.
|
||||
|
||||
After init(), get_token() returns the current JWT. A background task
|
||||
refreshes the JWT before it expires.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aclient_fn: Callable[[], AbstractContextManager],
|
||||
seed_token: str | None = None,
|
||||
) -> None:
|
||||
self._token = seed_token or ""
|
||||
self._aclient_fn = aclient_fn
|
||||
|
||||
async def get_token(self) -> str | None:
|
||||
return self._token
|
||||
|
||||
async def init(self) -> None:
|
||||
"""Fetch a JWT (unless seeded) then start the background refresh loop.
|
||||
|
||||
When seed_token was provided, skips the initial fetch and starts
|
||||
refreshing from the seed — useful for shadow holders that already
|
||||
have a valid JWT from the primary holder.
|
||||
"""
|
||||
token = self._token if self._token else await self._fetch()
|
||||
self._refresh_task = asyncio.create_task(self._refresh_loop(token))
|
||||
|
||||
async def _fetch(self) -> str:
|
||||
"""Exchange the current credential for a JWT via /api/v1/auth/token."""
|
||||
with self._aclient_fn() as client:
|
||||
response = await client.service.auth_token()
|
||||
self._token = response.jwt
|
||||
return response.jwt
|
||||
|
||||
async def _refresh_loop(self, token: str) -> None:
|
||||
while True:
|
||||
try:
|
||||
delay = max(
|
||||
_RETRY_DELAY_SECS,
|
||||
_jwt_expiry(token) - time.time() - _REFRESH_BEFORE_EXPIRY_SECS,
|
||||
)
|
||||
except ValueError:
|
||||
logger.debug("Failed to parse JWT expiry, retrying in %ds", _RETRY_DELAY_SECS)
|
||||
delay = _RETRY_DELAY_SECS
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
token = await self._fetch()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("JWT refresh failed, retrying in %ds: %s", _RETRY_DELAY_SECS, e)
|
||||
156
src/tinker/lib/_jwt_auth_test.py
Normal file
156
src/tinker/lib/_jwt_auth_test.py
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
"""Tests for JWT authentication helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tinker._exceptions import TinkerError
|
||||
from tinker.lib._auth_token_provider import (
|
||||
ApiKeyAuthProvider,
|
||||
CredentialCmdAuthProvider,
|
||||
resolve_auth_provider,
|
||||
)
|
||||
from tinker.lib._jwt_auth import (
|
||||
JwtAuthProvider,
|
||||
_jwt_expiry,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_jwt(exp: float) -> str:
|
||||
"""Build a minimal fake JWT with a given exp claim."""
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload_bytes = json.dumps({"exp": exp, "sub": "test"}).encode()
|
||||
payload = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode()
|
||||
return f"{header}.{payload}.fakesig"
|
||||
|
||||
|
||||
class _MockAuthResponse:
|
||||
def __init__(self, jwt: str) -> None:
|
||||
self.jwt = jwt
|
||||
|
||||
|
||||
class _MockHolder:
|
||||
"""Minimal mock providing aclient() for testing JwtAuthProvider."""
|
||||
|
||||
def __init__(self, response_jwt: str, *, fail: bool = False) -> None:
|
||||
service = MagicMock()
|
||||
if fail:
|
||||
service.auth_token = AsyncMock(side_effect=Exception("network error"))
|
||||
else:
|
||||
service.auth_token = AsyncMock(return_value=_MockAuthResponse(response_jwt))
|
||||
client = MagicMock()
|
||||
client.service = service
|
||||
cm = MagicMock()
|
||||
cm.__enter__ = MagicMock(return_value=client)
|
||||
cm.__exit__ = MagicMock(return_value=None)
|
||||
self._cm = cm
|
||||
|
||||
def aclient(self):
|
||||
return self._cm
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _jwt_expiry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_jwt_expiry_parses_valid():
|
||||
exp = time.time() + 3600
|
||||
assert abs(_jwt_expiry(_make_jwt(exp)) - exp) < 1
|
||||
|
||||
|
||||
def test_jwt_expiry_raises_on_invalid():
|
||||
with pytest.raises(Exception):
|
||||
_jwt_expiry("not.a.jwt")
|
||||
|
||||
|
||||
def test_jwt_expiry_raises_on_missing_exp():
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').rstrip(b"=").decode()
|
||||
payload = base64.urlsafe_b64encode(b'{"sub":"x"}').rstrip(b"=").decode()
|
||||
with pytest.raises(Exception):
|
||||
_jwt_expiry(f"{header}.{payload}.sig")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AuthTokenProvider hierarchy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_provider_resolves_key():
|
||||
auth = ApiKeyAuthProvider(api_key="tml-test-key")
|
||||
assert await auth.get_token() == "tml-test-key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credential_cmd_provider_runs_command():
|
||||
auth = CredentialCmdAuthProvider("echo test-credential")
|
||||
assert await auth.get_token() == "test-credential"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_auth_provider_fallback_to_cmd(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.delenv("TINKER_API_KEY", raising=False)
|
||||
monkeypatch.setenv("TINKER_CREDENTIAL_CMD", "echo fallback-cred")
|
||||
auth = resolve_auth_provider(api_key=None, enforce_cmd=False)
|
||||
assert isinstance(auth, CredentialCmdAuthProvider)
|
||||
assert await auth.get_token() == "fallback-cred"
|
||||
|
||||
|
||||
def test_credential_cmd_provider_raises_with_empty_cmd():
|
||||
with pytest.raises(TinkerError, match="dynamic credentials"):
|
||||
CredentialCmdAuthProvider("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JwtAuthProvider.init
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_fetches_jwt_and_stores_it():
|
||||
exp = time.time() + 7200
|
||||
jwt = _make_jwt(exp)
|
||||
holder = _MockHolder(jwt)
|
||||
provider = JwtAuthProvider(holder.aclient)
|
||||
|
||||
await provider.init()
|
||||
|
||||
assert await provider.get_token() == jwt
|
||||
holder._cm.__enter__.return_value.service.auth_token.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_raises_on_fetch_failure():
|
||||
holder = _MockHolder("some-jwt", fail=True)
|
||||
provider = JwtAuthProvider(holder.aclient)
|
||||
|
||||
with pytest.raises(Exception, match="network error"):
|
||||
await provider.init()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JwtAuthProvider._fetch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_returns_and_stores_token():
|
||||
exp = time.time() + 7200
|
||||
jwt = _make_jwt(exp)
|
||||
holder = _MockHolder(jwt)
|
||||
provider = JwtAuthProvider(holder.aclient)
|
||||
|
||||
result = await provider._fetch()
|
||||
|
||||
assert result == jwt
|
||||
assert await provider.get_token() == jwt
|
||||
|
|
@ -21,6 +21,12 @@ from tinker import types
|
|||
from tinker._client import AsyncTinker
|
||||
from tinker._exceptions import APIConnectionError, APIStatusError
|
||||
from tinker._version import __version__ as tinker_sdk_version
|
||||
from tinker.lib._auth_token_provider import (
|
||||
ApiKeyAuthProvider,
|
||||
AuthTokenProvider,
|
||||
resolve_auth_provider,
|
||||
)
|
||||
from tinker.lib._jwt_auth import JwtAuthProvider
|
||||
from tinker.lib.async_tinker_provider import AsyncTinkerProvider
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
|
|
@ -180,17 +186,63 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
project_id: str | None = None,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
api_key: str | None = None,
|
||||
_client_config: dict[str, str | int | bool] | None = None,
|
||||
_jwt_auth_seed: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._constructor_kwargs = kwargs
|
||||
self._api_key = api_key
|
||||
self._constructor_kwargs = dict(kwargs)
|
||||
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
|
||||
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
|
||||
self._sample_backoff_until: float | None = None
|
||||
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
|
||||
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
|
||||
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
|
||||
self._training_client_lock: threading.Lock = threading.Lock()
|
||||
self._telemetry: Telemetry | None = None
|
||||
|
||||
# Fetch server-side client config before any server contact so that
|
||||
# flags are available for subsequent setup steps. Shadow holders
|
||||
# receive the config via kwargs to avoid a redundant fetch (and
|
||||
# potential deadlock on the event loop thread).
|
||||
if _client_config is not None:
|
||||
self._client_config = types.ClientConfigResponse.model_validate(_client_config)
|
||||
else:
|
||||
self._assert_not_on_event_loop("fetch client config")
|
||||
config_auth = resolve_auth_provider(api_key, enforce_cmd=False)
|
||||
self._client_config = self.run_coroutine_threadsafe(
|
||||
self._fetch_client_config(config_auth)
|
||||
).result()
|
||||
|
||||
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(
|
||||
self._client_config.sample_dispatch_bytes_semaphore_size
|
||||
)
|
||||
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(
|
||||
self._client_config.inflight_response_bytes_semaphore_size
|
||||
)
|
||||
|
||||
if not self._client_config.pjwt_auth_enabled:
|
||||
# Without JWT exchange, only API keys are accepted by the server.
|
||||
# Replace any cmd-based provider with a plain API key provider.
|
||||
self._default_auth = ApiKeyAuthProvider(api_key=api_key)
|
||||
else:
|
||||
# Create a dedicated pool for JWT exchange with the appropriate
|
||||
# credential provider. The lambda captures the pool so it stays alive.
|
||||
use_cmd = self._client_config.credential_default_source == "credential_cmd"
|
||||
auth_pool_auth = resolve_auth_provider(self._api_key, use_cmd)
|
||||
auth_kwargs = {**self._constructor_kwargs, "_auth": auth_pool_auth}
|
||||
auth_pool = ClientConnectionPool(self.get_loop(), 1, auth_kwargs)
|
||||
auth_aclient = lambda: auth_pool.aclient() # noqa: E731
|
||||
self._default_auth = JwtAuthProvider(auth_aclient, seed_token=_jwt_auth_seed)
|
||||
if _jwt_auth_seed:
|
||||
# Shadow holder: start refresh in background, don't block.
|
||||
self.run_coroutine_threadsafe(self._default_auth.init())
|
||||
else:
|
||||
# Primary holder: must have a valid JWT before proceeding.
|
||||
self._assert_not_on_event_loop("exchange JWT")
|
||||
self.run_coroutine_threadsafe(
|
||||
self.execute_with_retries(self._default_auth.init)
|
||||
).result()
|
||||
|
||||
if session_id is not None:
|
||||
# Shadow mode: reuse existing session, can't create new clients
|
||||
|
|
@ -199,14 +251,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
self._sampling_client_counter: int | None = None
|
||||
else:
|
||||
# Normal mode: create new session.
|
||||
# This blocks on .result() — must NOT be called from the event
|
||||
# loop thread (e.g. inside the sidecar subprocess). Shadow
|
||||
# holders (session_id is not None) skip this path.
|
||||
if self._loop.is_running() and _current_loop() is self._loop:
|
||||
raise RuntimeError(
|
||||
"Cannot create a new session from the event loop thread. "
|
||||
"Use session_id= to create a shadow holder instead."
|
||||
)
|
||||
self._assert_not_on_event_loop("create a new session")
|
||||
self._session_id = self.run_coroutine_threadsafe(
|
||||
self._create_session(user_metadata=user_metadata, project_id=project_id)
|
||||
).result()
|
||||
|
|
@ -230,6 +275,26 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
"""Get or create a shadow holder from the singleton cache."""
|
||||
return _shadow_holder_singleton.get_or_create(session_id, kwargs)
|
||||
|
||||
def _assert_not_on_event_loop(self, action: str) -> None:
|
||||
"""Raise if called from the event loop thread (would deadlock on .result())."""
|
||||
if self._loop.is_running() and _current_loop() is self._loop:
|
||||
raise RuntimeError(
|
||||
f"Cannot {action} from the event loop thread. "
|
||||
"Use session_id= to create a shadow holder instead."
|
||||
)
|
||||
|
||||
@property
|
||||
def shadow_kwargs(self) -> dict[str, Any]:
|
||||
"""Constructor kwargs for shadow holders, including cached server config and JWT seed."""
|
||||
result = {
|
||||
**self._constructor_kwargs,
|
||||
"api_key": self._api_key,
|
||||
"_client_config": self._client_config.model_dump(),
|
||||
}
|
||||
if isinstance(self._default_auth, JwtAuthProvider):
|
||||
result["_jwt_auth_seed"] = self._default_auth._token
|
||||
return result
|
||||
|
||||
@asynccontextmanager
|
||||
async def _sample_dispatch_count_rate_limit(self):
|
||||
async with self._sample_dispatch_semaphore:
|
||||
|
|
@ -316,6 +381,23 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
"""Start the session heartbeat task."""
|
||||
return asyncio.create_task(self._session_heartbeat(self._session_id))
|
||||
|
||||
async def _fetch_client_config(self, auth: AuthTokenProvider) -> types.ClientConfigResponse:
|
||||
"""Call /api/v1/client/config and return server feature flags.
|
||||
|
||||
Creates a one-off connection pool with the given auth. Retries
|
||||
transient failures via execute_with_retries.
|
||||
"""
|
||||
kwargs = {**self._constructor_kwargs, "_auth": auth}
|
||||
pool = ClientConnectionPool(self.get_loop(), 1, kwargs)
|
||||
|
||||
async def _once() -> types.ClientConfigResponse:
|
||||
with pool.aclient() as client:
|
||||
return await client.service.client_config(
|
||||
request=types.ClientConfigRequest(sdk_version=tinker_sdk_version)
|
||||
)
|
||||
|
||||
return await self.execute_with_retries(_once)
|
||||
|
||||
async def _create_session(
|
||||
self,
|
||||
user_metadata: dict[str, str] | None = None,
|
||||
|
|
@ -350,8 +432,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
|
|||
if client_pool_type == ClientConnectionPoolType.TRAIN
|
||||
else MAX_REQUESTS_PER_HTTPX_CLIENT
|
||||
)
|
||||
kwargs = {**self._constructor_kwargs, "_auth": self._default_auth}
|
||||
self._client_pools[client_pool_type] = ClientConnectionPool(
|
||||
self.get_loop(), max_requests_per_client, self._constructor_kwargs
|
||||
self.get_loop(), max_requests_per_client, kwargs
|
||||
)
|
||||
return self._client_pools[client_pool_type]
|
||||
|
||||
|
|
|
|||
96
src/tinker/lib/internal_client_holder_test.py
Normal file
96
src/tinker/lib/internal_client_holder_test.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""Tests for InternalClientHolder helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tinker.lib._auth_token_provider import AuthTokenProvider
|
||||
from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder
|
||||
from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse
|
||||
|
||||
|
||||
class _MockHolder:
|
||||
"""Minimal stand-in for testing _fetch_client_config."""
|
||||
|
||||
def __init__(self, response: _ClientConfigResponse | Exception) -> None:
|
||||
service = MagicMock()
|
||||
if isinstance(response, Exception):
|
||||
service.client_config = AsyncMock(side_effect=response)
|
||||
else:
|
||||
service.client_config = AsyncMock(return_value=response)
|
||||
client = MagicMock()
|
||||
client.service = service
|
||||
cm = MagicMock()
|
||||
cm.__enter__ = MagicMock(return_value=client)
|
||||
cm.__exit__ = MagicMock(return_value=None)
|
||||
self._cm = cm
|
||||
|
||||
self._constructor_kwargs: dict[str, Any] = {}
|
||||
self._default_auth = MagicMock(spec=AuthTokenProvider)
|
||||
self._loop = asyncio.get_event_loop()
|
||||
|
||||
def get_loop(self) -> asyncio.AbstractEventLoop:
|
||||
return self._loop
|
||||
|
||||
async def execute_with_retries(self, func: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
# Bind the real method so the pool it creates uses our mock client
|
||||
_fetch_client_config = InternalClientHolder._fetch_client_config
|
||||
|
||||
|
||||
def _patch_pool(monkeypatch: pytest.MonkeyPatch, holder: _MockHolder) -> None:
|
||||
monkeypatch.setattr(ClientConnectionPool, "aclient", lambda self: holder._cm)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _fetch_client_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_client_config_returns_flags_from_server(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=True))
|
||||
_patch_pool(monkeypatch, holder)
|
||||
result = await InternalClientHolder._fetch_client_config(holder, holder._default_auth) # type: ignore[arg-type]
|
||||
assert result.pjwt_auth_enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_client_config_returns_defaults_when_server_disables(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=False))
|
||||
_patch_pool(monkeypatch, holder)
|
||||
result = await InternalClientHolder._fetch_client_config(holder, holder._default_auth) # type: ignore[arg-type]
|
||||
assert result.pjwt_auth_enabled is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_client_config_raises_on_network_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
holder = _MockHolder(Exception("connection refused"))
|
||||
_patch_pool(monkeypatch, holder)
|
||||
with pytest.raises(Exception, match="connection refused"):
|
||||
await InternalClientHolder._fetch_client_config(holder, holder._default_auth) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_client_config_passes_sdk_version(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
from tinker._version import __version__ as tinker_sdk_version
|
||||
|
||||
holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=False))
|
||||
_patch_pool(monkeypatch, holder)
|
||||
await InternalClientHolder._fetch_client_config(holder, holder._default_auth) # type: ignore[arg-type]
|
||||
|
||||
call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args
|
||||
assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version
|
||||
|
|
@ -451,7 +451,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
_SamplingClientPickleState(
|
||||
session_id=self.holder.get_session_id(),
|
||||
sampling_session_id=self._sampling_session_id,
|
||||
constructor_kwargs=self.holder._constructor_kwargs,
|
||||
constructor_kwargs=self.holder.shadow_kwargs,
|
||||
subprocess_sampling=self._sampling_client_sidecar_handle is not None,
|
||||
),
|
||||
),
|
||||
|
|
|
|||
|
|
@ -230,10 +230,26 @@ class ServiceClient(TelemetryProvider):
|
|||
user_metadata,
|
||||
).result_async()
|
||||
|
||||
def _get_rest_client_for_weights(self, weights_access_token: str | None = None) -> RestClient:
|
||||
"""Get a rest client for weights info lookups.
|
||||
|
||||
If weights_access_token is provided, creates a separate ServiceClient
|
||||
authenticated with that token.
|
||||
"""
|
||||
if weights_access_token is not None:
|
||||
token_client = ServiceClient(
|
||||
api_key=weights_access_token, **self.holder._constructor_kwargs
|
||||
)
|
||||
return token_client.create_rest_client()
|
||||
return self.create_rest_client()
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_training_client_from_state(
|
||||
self, path: str, user_metadata: dict[str, str] | None = None
|
||||
self,
|
||||
path: str,
|
||||
user_metadata: dict[str, str] | None = None,
|
||||
weights_access_token: str | None = None,
|
||||
) -> TrainingClient:
|
||||
"""Create a TrainingClient from saved model weights.
|
||||
|
||||
|
|
@ -243,6 +259,7 @@ class ServiceClient(TelemetryProvider):
|
|||
Args:
|
||||
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
- `user_metadata`: Optional metadata to attach to the new training run
|
||||
- `weights_access_token`: Optional access token for loading checkpoints under a different account.
|
||||
|
||||
Returns:
|
||||
- `TrainingClient` loaded with the specified weights
|
||||
|
|
@ -256,7 +273,7 @@ class ServiceClient(TelemetryProvider):
|
|||
# Continue training from the loaded state
|
||||
```
|
||||
"""
|
||||
rest_client = self.create_rest_client()
|
||||
rest_client = self._get_rest_client_for_weights(weights_access_token)
|
||||
# Use weights info endpoint which allows access to models with public checkpoints
|
||||
weights_info = rest_client.get_weights_info_by_tinker_path(path).result()
|
||||
|
||||
|
|
@ -271,15 +288,18 @@ class ServiceClient(TelemetryProvider):
|
|||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
training_client.load_state(path).result()
|
||||
training_client.load_state(path, weights_access_token=weights_access_token).result()
|
||||
return training_client
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_training_client_from_state_async(
|
||||
self, path: str, user_metadata: dict[str, str] | None = None
|
||||
self,
|
||||
path: str,
|
||||
user_metadata: dict[str, str] | None = None,
|
||||
weights_access_token: str | None = None,
|
||||
) -> TrainingClient:
|
||||
"""Async version of create_training_client_from_state."""
|
||||
rest_client = self.create_rest_client()
|
||||
rest_client = self._get_rest_client_for_weights(weights_access_token)
|
||||
# Use weights info endpoint which allows access to models with public checkpoints
|
||||
weights_info = await rest_client.get_weights_info_by_tinker_path(path)
|
||||
|
||||
|
|
@ -297,14 +317,19 @@ class ServiceClient(TelemetryProvider):
|
|||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
load_future = await training_client.load_state_async(path)
|
||||
load_future = await training_client.load_state_async(
|
||||
path, weights_access_token=weights_access_token
|
||||
)
|
||||
await load_future.result_async()
|
||||
return training_client
|
||||
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_training_client_from_state_with_optimizer(
|
||||
self, path: str, user_metadata: dict[str, str] | None = None
|
||||
self,
|
||||
path: str,
|
||||
user_metadata: dict[str, str] | None = None,
|
||||
weights_access_token: str | None = None,
|
||||
) -> TrainingClient:
|
||||
"""Create a TrainingClient from saved model weights and optimizer state.
|
||||
|
||||
|
|
@ -315,6 +340,7 @@ class ServiceClient(TelemetryProvider):
|
|||
Args:
|
||||
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
- `user_metadata`: Optional metadata to attach to the new training run
|
||||
- `weights_access_token`: Optional access token for loading checkpoints under a different account.
|
||||
|
||||
Returns:
|
||||
- `TrainingClient` loaded with the specified weights and optimizer state
|
||||
|
|
@ -328,7 +354,7 @@ class ServiceClient(TelemetryProvider):
|
|||
# Continue training with restored optimizer momentum
|
||||
```
|
||||
"""
|
||||
rest_client = self.create_rest_client()
|
||||
rest_client = self._get_rest_client_for_weights(weights_access_token)
|
||||
# Use weights info endpoint which allows access to models with public checkpoints
|
||||
weights_info = rest_client.get_weights_info_by_tinker_path(path).result()
|
||||
|
||||
|
|
@ -343,15 +369,20 @@ class ServiceClient(TelemetryProvider):
|
|||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
training_client.load_state_with_optimizer(path).result()
|
||||
training_client.load_state_with_optimizer(
|
||||
path, weights_access_token=weights_access_token
|
||||
).result()
|
||||
return training_client
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_training_client_from_state_with_optimizer_async(
|
||||
self, path: str, user_metadata: dict[str, str] | None = None
|
||||
self,
|
||||
path: str,
|
||||
user_metadata: dict[str, str] | None = None,
|
||||
weights_access_token: str | None = None,
|
||||
) -> TrainingClient:
|
||||
"""Async version of create_training_client_from_state_with_optimizer."""
|
||||
rest_client = self.create_rest_client()
|
||||
rest_client = self._get_rest_client_for_weights(weights_access_token)
|
||||
# Use weights info endpoint which allows access to models with public checkpoints
|
||||
weights_info = await rest_client.get_weights_info_by_tinker_path(path)
|
||||
|
||||
|
|
@ -369,7 +400,9 @@ class ServiceClient(TelemetryProvider):
|
|||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
load_future = await training_client.load_state_with_optimizer_async(path)
|
||||
load_future = await training_client.load_state_with_optimizer_async(
|
||||
path, weights_access_token=weights_access_token
|
||||
)
|
||||
await load_future.result_async()
|
||||
return training_client
|
||||
|
||||
|
|
|
|||
|
|
@ -6,10 +6,12 @@ import asyncio
|
|||
import logging
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Tuple
|
||||
|
||||
from tinker import types
|
||||
from tinker._exceptions import ConflictError
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
|
||||
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
||||
|
|
@ -40,6 +42,7 @@ if TYPE_CHECKING:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# FwdBwdChunkSize
|
||||
MAX_CHUNK_LEN = 1024
|
||||
MAX_CHUNK_BYTES_COUNT = 5000000
|
||||
|
|
@ -603,13 +606,44 @@ class TrainingClient(TelemetryProvider):
|
|||
seq_id=request_id + 1,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.save(
|
||||
request=request,
|
||||
try:
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.save(
|
||||
request=request,
|
||||
max_retries=0,
|
||||
)
|
||||
except ConflictError:
|
||||
# 409 means a checkpoint with this name already exists.
|
||||
# This is common when retrying after a transient network
|
||||
# error — the first attempt saved the checkpoint but the
|
||||
# response was lost. Treat as success: the checkpoint IS
|
||||
# saved, and crashing a long training run is worse than
|
||||
# returning a synthetic response.
|
||||
logger.info(
|
||||
"Checkpoint '%s' already exists (409 Conflict); "
|
||||
"treating as success — the checkpoint is saved.",
|
||||
name,
|
||||
)
|
||||
if telemetry := self.holder.get_telemetry():
|
||||
telemetry.log(
|
||||
"training_client.save_state.conflict_resolved",
|
||||
event_data={
|
||||
"checkpoint_name": name,
|
||||
"model_id": self._guaranteed_model_id(),
|
||||
},
|
||||
severity="INFO",
|
||||
)
|
||||
return None
|
||||
|
||||
async with self._take_turn(request_id):
|
||||
future = await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
# _send_request returns None on 409 conflict (checkpoint already
|
||||
# saved), or an UntypedAPIFuture on success.
|
||||
if future is None:
|
||||
model_id = self._guaranteed_model_id()
|
||||
return types.SaveWeightsResponse(path=f"tinker://{model_id}/weights/{name}")
|
||||
|
||||
return await _APIFuture(
|
||||
types.SaveWeightsResponse,
|
||||
self.holder,
|
||||
|
|
@ -629,7 +663,11 @@ class TrainingClient(TelemetryProvider):
|
|||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _load_state_impl(
|
||||
self, request_id: int, path: str, optimizer: bool
|
||||
self,
|
||||
request_id: int,
|
||||
path: str,
|
||||
optimizer: bool,
|
||||
weights_access_token: str | None = None,
|
||||
) -> types.LoadWeightsResponse:
|
||||
start_time = time.time()
|
||||
|
||||
|
|
@ -639,6 +677,7 @@ class TrainingClient(TelemetryProvider):
|
|||
path=path,
|
||||
seq_id=request_id + 1,
|
||||
optimizer=optimizer,
|
||||
weights_access_token=weights_access_token,
|
||||
)
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.load(
|
||||
|
|
@ -657,7 +696,9 @@ class TrainingClient(TelemetryProvider):
|
|||
)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
|
||||
def load_state(
|
||||
self, path: str, weights_access_token: str | None = None
|
||||
) -> APIFuture[types.LoadWeightsResponse]:
|
||||
"""Load model weights from a saved checkpoint.
|
||||
|
||||
This loads only the model weights, not optimizer state (e.g., Adam momentum).
|
||||
|
|
@ -665,6 +706,7 @@ class TrainingClient(TelemetryProvider):
|
|||
|
||||
Args:
|
||||
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
- `weights_access_token`: Optional access token for loading checkpoints under a different account.
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the load response
|
||||
|
|
@ -678,18 +720,27 @@ class TrainingClient(TelemetryProvider):
|
|||
```
|
||||
"""
|
||||
request_id = self._get_request_id()
|
||||
return self.holder.run_coroutine_threadsafe(self._load_state_impl(request_id, path, False))
|
||||
return self.holder.run_coroutine_threadsafe(
|
||||
self._load_state_impl(
|
||||
request_id, path, False, weights_access_token=weights_access_token
|
||||
)
|
||||
)
|
||||
|
||||
async def load_state_async(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
|
||||
async def load_state_async(
|
||||
self, path: str, weights_access_token: str | None = None
|
||||
) -> APIFuture[types.LoadWeightsResponse]:
|
||||
"""Async version of load_state."""
|
||||
return self.load_state(path)
|
||||
return self.load_state(path, weights_access_token=weights_access_token)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def load_state_with_optimizer(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
|
||||
def load_state_with_optimizer(
|
||||
self, path: str, weights_access_token: str | None = None
|
||||
) -> APIFuture[types.LoadWeightsResponse]:
|
||||
"""Load model weights and optimizer state from a checkpoint.
|
||||
|
||||
Args:
|
||||
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
||||
- `weights_access_token`: Optional access token for loading checkpoints under a different account.
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the load response
|
||||
|
|
@ -705,13 +756,15 @@ class TrainingClient(TelemetryProvider):
|
|||
```
|
||||
"""
|
||||
request_id = self._get_request_id()
|
||||
return self.holder.run_coroutine_threadsafe(self._load_state_impl(request_id, path, True))
|
||||
return self.holder.run_coroutine_threadsafe(
|
||||
self._load_state_impl(request_id, path, True, weights_access_token=weights_access_token)
|
||||
)
|
||||
|
||||
async def load_state_with_optimizer_async(
|
||||
self, path: str
|
||||
self, path: str, weights_access_token: str | None = None
|
||||
) -> APIFuture[types.LoadWeightsResponse]:
|
||||
"""Async version of load_state_with_optimizer."""
|
||||
return self.load_state_with_optimizer(path)
|
||||
return self.load_state_with_optimizer(path, weights_access_token=weights_access_token)
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def _save_weights_for_sampler_impl(
|
||||
|
|
@ -739,13 +792,46 @@ class TrainingClient(TelemetryProvider):
|
|||
sampling_session_seq_id=sampling_session_seq_id,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.save_for_sampler(
|
||||
request=request,
|
||||
try:
|
||||
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
|
||||
return await client.weights.save_for_sampler(
|
||||
request=request,
|
||||
max_retries=0,
|
||||
)
|
||||
except ConflictError:
|
||||
if name is None:
|
||||
# Unnamed saves use server-generated unique paths;
|
||||
# 409 should be impossible. Re-raise as a real error.
|
||||
raise
|
||||
# See save_state for full rationale on treating 409 as success.
|
||||
logger.info(
|
||||
"Sampler checkpoint '%s' already exists (409 Conflict); "
|
||||
"treating as success — the checkpoint is saved.",
|
||||
name,
|
||||
)
|
||||
if telemetry := self.holder.get_telemetry():
|
||||
telemetry.log(
|
||||
"training_client.save_weights_for_sampler.conflict_resolved",
|
||||
event_data={
|
||||
"checkpoint_name": name,
|
||||
"model_id": self._guaranteed_model_id(),
|
||||
},
|
||||
severity="INFO",
|
||||
)
|
||||
return None
|
||||
|
||||
async with self._take_turn(request_id):
|
||||
future = await self.holder.execute_with_retries(_send_request)
|
||||
|
||||
# _send_request returns None on 409 conflict (checkpoint already
|
||||
# saved), or an UntypedAPIFuture on success.
|
||||
if future is None:
|
||||
assert name is not None
|
||||
model_id = self._guaranteed_model_id()
|
||||
return types.SaveWeightsForSamplerResponseInternal(
|
||||
path=f"tinker://{model_id}/sampler_weights/{name}"
|
||||
)
|
||||
|
||||
return await _APIFuture(
|
||||
types.SaveWeightsForSamplerResponseInternal,
|
||||
self.holder,
|
||||
|
|
@ -906,7 +992,7 @@ class TrainingClient(TelemetryProvider):
|
|||
"""Save current weights and create a SamplingClient for inference.
|
||||
|
||||
Args:
|
||||
- `name`: Optional name for the saved weights (currently ignored for ephemeral saves)
|
||||
- `name`: Deprecated, has no effect. Will be removed in a future release.
|
||||
- `retry_config`: Optional configuration for retrying failed requests
|
||||
|
||||
Returns:
|
||||
|
|
@ -923,8 +1009,17 @@ class TrainingClient(TelemetryProvider):
|
|||
result = sampling_client.sample(prompt, 1, params).result()
|
||||
```
|
||||
"""
|
||||
# Ignore name argument for ephemeral save weights for sampler
|
||||
_ = name
|
||||
if name is not None:
|
||||
warnings.warn(
|
||||
"The 'name' parameter of save_weights_and_get_sampling_client() is deprecated "
|
||||
"and has no effect — checkpoints are always ephemeral. "
|
||||
"This parameter will be removed in a future release. "
|
||||
"Remove the 'name' argument from your call. "
|
||||
"If you need a persistent checkpoint, use "
|
||||
"save_weights_for_sampler(name=...) + create_sampling_client(model_path=...) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.save_weights_and_get_sampling_client_submit(retry_config).result()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
|
|
@ -932,8 +1027,17 @@ class TrainingClient(TelemetryProvider):
|
|||
self, name: str | None = None, retry_config: RetryConfig | None = None
|
||||
) -> SamplingClient:
|
||||
"""Async version of save_weights_and_get_sampling_client."""
|
||||
# Ignore name argument for ephemeral save weights for sampler
|
||||
_ = name
|
||||
if name is not None:
|
||||
warnings.warn(
|
||||
"The 'name' parameter of save_weights_and_get_sampling_client_async() is deprecated "
|
||||
"and has no effect — checkpoints are always ephemeral. "
|
||||
"This parameter will be removed in a future release. "
|
||||
"Remove the 'name' argument from your call. "
|
||||
"If you need a persistent checkpoint, use "
|
||||
"save_weights_for_sampler(name=...) + create_sampling_client(model_path=...) instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return await self.save_weights_and_get_sampling_client_submit(retry_config)
|
||||
|
||||
def get_telemetry(self) -> Telemetry | None:
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class AsyncFuturesResource(AsyncAPIResource):
|
|||
FutureRetrieveResponse,
|
||||
await self._post(
|
||||
"/api/v1/retrieve_future",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=cast(
|
||||
Any, FutureRetrieveResponse
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class AsyncModelsResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/create_model",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
@ -106,7 +106,7 @@ class AsyncModelsResource(AsyncAPIResource):
|
|||
|
||||
result = await self._post(
|
||||
"/api/v1/get_info",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=GetInfoResponse,
|
||||
)
|
||||
|
|
@ -159,7 +159,7 @@ class AsyncModelsResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/unload_model",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class AsyncSamplingResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/asample",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ from .._base_client import make_request_options
|
|||
from .._compat import model_dump
|
||||
from .._resource import AsyncAPIResource
|
||||
from .._types import NOT_GIVEN, Body, Headers, NotGiven, Query
|
||||
from ..types.auth_token_response import AuthTokenResponse
|
||||
from ..types.client_config_request import ClientConfigRequest
|
||||
from ..types.client_config_response import ClientConfigResponse
|
||||
from ..types.create_sampling_session_request import CreateSamplingSessionRequest
|
||||
from ..types.create_sampling_session_response import CreateSamplingSessionResponse
|
||||
from ..types.create_session_request import CreateSessionRequest
|
||||
|
|
@ -63,6 +66,54 @@ class AsyncServiceResource(AsyncAPIResource):
|
|||
cast_to=HealthResponse,
|
||||
)
|
||||
|
||||
async def auth_token(
|
||||
self,
|
||||
*,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int | NotGiven = NOT_GIVEN,
|
||||
) -> AuthTokenResponse:
|
||||
"""Exchange the current credential for a short-lived JWT."""
|
||||
options = make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
)
|
||||
if max_retries is not NOT_GIVEN:
|
||||
options["max_retries"] = max_retries
|
||||
|
||||
return await self._post(
|
||||
"/api/v1/auth/token",
|
||||
body={},
|
||||
options=options,
|
||||
cast_to=AuthTokenResponse,
|
||||
)
|
||||
|
||||
async def client_config(
|
||||
self,
|
||||
*,
|
||||
request: ClientConfigRequest,
|
||||
extra_headers: Headers | None = None,
|
||||
extra_query: Query | None = None,
|
||||
extra_body: Body | None = None,
|
||||
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
|
||||
) -> ClientConfigResponse:
|
||||
"""Fetch server-side feature flags for this client."""
|
||||
return await self._post(
|
||||
"/api/v1/client/config",
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=make_request_options(
|
||||
extra_headers=extra_headers,
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
timeout=timeout,
|
||||
),
|
||||
cast_to=ClientConfigResponse,
|
||||
)
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -104,7 +155,7 @@ class AsyncServiceResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/create_session",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=CreateSessionResponse,
|
||||
)
|
||||
|
|
@ -148,7 +199,7 @@ class AsyncServiceResource(AsyncAPIResource):
|
|||
request = SessionHeartbeatRequest(session_id=session_id)
|
||||
return await self._post(
|
||||
"/api/v1/session_heartbeat",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=SessionHeartbeatResponse,
|
||||
)
|
||||
|
|
@ -190,7 +241,7 @@ class AsyncServiceResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/create_sampling_session",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=CreateSamplingSessionResponse,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class AsyncTelemetryResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/telemetry",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=TelemetryResponse,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class AsyncTrainingResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/forward",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
@ -102,7 +102,7 @@ class AsyncTrainingResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/forward_backward",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
@ -148,7 +148,7 @@ class AsyncTrainingResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/optim_step",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class AsyncWeightsResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/load_weights",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
@ -107,7 +107,7 @@ class AsyncWeightsResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/save_weights",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
@ -153,7 +153,7 @@ class AsyncWeightsResource(AsyncAPIResource):
|
|||
|
||||
return await self._post(
|
||||
"/api/v1/save_weights_for_sampler",
|
||||
body=model_dump(request, exclude_unset=True, mode="json"),
|
||||
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
|
||||
options=options,
|
||||
cast_to=UntypedAPIFuture,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .auth_token_response import AuthTokenResponse as AuthTokenResponse
|
||||
from .checkpoint import (
|
||||
Checkpoint as Checkpoint,
|
||||
)
|
||||
|
|
@ -13,6 +14,8 @@ from .checkpoint_archive_url_response import (
|
|||
CheckpointArchiveUrlResponse as CheckpointArchiveUrlResponse,
|
||||
)
|
||||
from .checkpoints_list_response import CheckpointsListResponse as CheckpointsListResponse
|
||||
from .client_config_request import ClientConfigRequest as ClientConfigRequest
|
||||
from .client_config_response import ClientConfigResponse as ClientConfigResponse
|
||||
from .create_model_request import CreateModelRequest as CreateModelRequest
|
||||
from .create_model_response import CreateModelResponse as CreateModelResponse
|
||||
from .create_sampling_session_request import (
|
||||
|
|
|
|||
9
src/tinker/types/auth_token_response.py
Normal file
9
src/tinker/types/auth_token_response.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .._models import BaseModel
|
||||
|
||||
__all__ = ["AuthTokenResponse"]
|
||||
|
||||
|
||||
class AuthTokenResponse(BaseModel):
|
||||
jwt: str
|
||||
10
src/tinker/types/client_config_request.py
Normal file
10
src/tinker/types/client_config_request.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .._models import StrictBase
|
||||
|
||||
__all__ = ["ClientConfigRequest"]
|
||||
|
||||
|
||||
class ClientConfigRequest(StrictBase):
|
||||
sdk_version: str
|
||||
"""The SDK version string for flag resolution."""
|
||||
18
src/tinker/types/client_config_response.py
Normal file
18
src/tinker/types/client_config_response.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .._models import BaseModel
|
||||
|
||||
__all__ = ["ClientConfigResponse"]
|
||||
|
||||
|
||||
class ClientConfigResponse(BaseModel):
|
||||
"""Server-side feature flags resolved for this caller.
|
||||
|
||||
Uses BaseModel (extra="ignore") so new flags from the server are
|
||||
silently dropped until the SDK adds fields for them.
|
||||
"""
|
||||
|
||||
pjwt_auth_enabled: bool = False
|
||||
credential_default_source: str = "api_key"
|
||||
sample_dispatch_bytes_semaphore_size: int = 10 * 1024 * 1024
|
||||
inflight_response_bytes_semaphore_size: int = 50 * 1024 * 1024
|
||||
|
|
@ -46,6 +46,9 @@ class Datum(StrictBase):
|
|||
def _maybe_convert_array(cls, key: str, value: Any) -> Any:
|
||||
"""Convert torch.Tensor, numpy array, or numeric lists to TensorData if needed."""
|
||||
if _HAVE_TORCH and isinstance(value, torch.Tensor):
|
||||
# Auto-sparsify 2-D target_tokens and weights to reduce wire payload
|
||||
if key in _sparse_eligible_keys and value.ndim == 2:
|
||||
return TensorData.from_torch_sparse(value)
|
||||
return TensorData.from_torch(value)
|
||||
elif isinstance(value, np.ndarray):
|
||||
return TensorData.from_numpy(value)
|
||||
|
|
@ -81,3 +84,5 @@ _key_to_type = {
|
|||
"clip_low_threshold": "float32",
|
||||
"clip_high_threshold": "float32",
|
||||
}
|
||||
|
||||
_sparse_eligible_keys = {"target_tokens", "weights"}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@ class SupportedModel(BaseModel):
|
|||
model_name: Optional[str] = None
|
||||
"""The name of the supported model."""
|
||||
|
||||
max_context_length: Optional[int] = None
|
||||
"""The maximum context length (in tokens) supported by this model."""
|
||||
|
||||
|
||||
class GetServerCapabilitiesResponse(BaseModel):
|
||||
"""Response containing the server's supported models and capabilities."""
|
||||
|
|
|
|||
|
|
@ -20,6 +20,9 @@ class LoadWeightsRequest(StrictBase):
|
|||
|
||||
seq_id: Optional[int] = None
|
||||
|
||||
weights_access_token: Optional[str] = None
|
||||
"""Optional access token for loading checkpoints under a different account."""
|
||||
|
||||
type: Literal["load_weights"] = "load_weights"
|
||||
|
||||
if PYDANTIC_V2:
|
||||
|
|
|
|||
|
|
@ -33,6 +33,17 @@ class TensorData(StrictBase):
|
|||
provided, and is generally inferred as a 1D tensor.
|
||||
"""
|
||||
|
||||
sparse_crow_indices: Optional[List[int]] = None
|
||||
"""Optional CSR compressed row pointers. When set, this tensor is sparse CSR:
|
||||
- data contains only the non-zero values (flattened)
|
||||
- sparse_crow_indices contains the row pointers (length = nrows + 1)
|
||||
- sparse_col_indices contains the column indices (length = nnz)
|
||||
- shape is required and specifies the dense shape
|
||||
"""
|
||||
|
||||
sparse_col_indices: Optional[List[int]] = None
|
||||
"""Optional CSR column indices. Must be set together with sparse_crow_indices."""
|
||||
|
||||
@classmethod
|
||||
def from_numpy(cls, array: npt.NDArray[Any]) -> "TensorData":
|
||||
return cls(
|
||||
|
|
@ -49,8 +60,41 @@ class TensorData(StrictBase):
|
|||
shape=list(tensor.shape),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_torch_sparse(cls, tensor: "torch.Tensor") -> "TensorData":
|
||||
"""Create a sparse CSR TensorData from a dense 2-D torch tensor.
|
||||
|
||||
Automatically detects sparsity and encodes as CSR when it saves space.
|
||||
Falls back to dense if the tensor is 1-D or mostly non-zero.
|
||||
"""
|
||||
if not _HAVE_TORCH:
|
||||
raise ImportError("PyTorch is not installed.")
|
||||
|
||||
if tensor.ndim != 2:
|
||||
return cls.from_torch(tensor)
|
||||
|
||||
# Only use sparse if it actually saves space
|
||||
# Dense: nrows * ncols values
|
||||
# CSR: (nrows + 1) crow_indices + nnz col_indices + nnz values
|
||||
nnz = tensor.count_nonzero().item()
|
||||
dense_size = tensor.shape[0] * tensor.shape[1]
|
||||
csr_size = (tensor.shape[0] + 1) + 2 * nnz
|
||||
if csr_size >= dense_size:
|
||||
return cls.from_torch(tensor)
|
||||
|
||||
sparse_csr = tensor.to_sparse_csr()
|
||||
return cls(
|
||||
data=sparse_csr.values().tolist(),
|
||||
dtype=_convert_torch_dtype_to_tensor(tensor.dtype),
|
||||
shape=list(tensor.shape),
|
||||
sparse_crow_indices=sparse_csr.crow_indices().tolist(),
|
||||
sparse_col_indices=sparse_csr.col_indices().tolist(),
|
||||
)
|
||||
|
||||
def to_numpy(self) -> npt.NDArray[Any]:
|
||||
"""Convert TensorData to numpy array."""
|
||||
if self.sparse_crow_indices is not None:
|
||||
return self.to_torch().numpy()
|
||||
numpy_dtype = _convert_tensor_dtype_to_numpy(self.dtype)
|
||||
arr = np.array(self.data, dtype=numpy_dtype)
|
||||
if self.shape is not None:
|
||||
|
|
@ -63,6 +107,17 @@ class TensorData(StrictBase):
|
|||
raise ImportError("PyTorch is not installed. Cannot convert to torch tensor.")
|
||||
|
||||
torch_dtype = _convert_tensor_dtype_to_torch(self.dtype)
|
||||
|
||||
if self.sparse_crow_indices is not None:
|
||||
assert self.sparse_col_indices is not None, (
|
||||
"sparse_col_indices required with sparse_crow_indices"
|
||||
)
|
||||
assert self.shape is not None, "shape is required for sparse tensors"
|
||||
crow = torch.tensor(self.sparse_crow_indices, dtype=torch.int64)
|
||||
col = torch.tensor(self.sparse_col_indices, dtype=torch.int64)
|
||||
values = torch.tensor(self.data, dtype=torch_dtype)
|
||||
return torch.sparse_csr_tensor(crow, col, values, self.shape).to_dense()
|
||||
|
||||
tensor = torch.tensor(self.data, dtype=torch_dtype)
|
||||
if self.shape is not None:
|
||||
tensor = tensor.reshape(self.shape)
|
||||
|
|
|
|||
79
tests/test_conflict_recovery.py
Normal file
79
tests/test_conflict_recovery.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Tests for 409 ConflictError recovery in checkpoint save operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from tinker._exceptions import ConflictError
|
||||
from tinker.lib.public_interfaces.training_client import TrainingClient
|
||||
|
||||
|
||||
def _make_conflict_error() -> ConflictError:
|
||||
"""Create a ConflictError for testing."""
|
||||
request = httpx.Request("POST", "http://test/api/v1/save_weights")
|
||||
response = httpx.Response(409, request=request)
|
||||
return ConflictError("conflict", response=response, body=None)
|
||||
|
||||
|
||||
def _make_mock_holder() -> Mock:
|
||||
"""Create a mock InternalClientHolder whose weights.save() raises ConflictError."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.weights.save = AsyncMock(side_effect=_make_conflict_error())
|
||||
mock_client.weights.save_for_sampler = AsyncMock(side_effect=_make_conflict_error())
|
||||
|
||||
@contextmanager
|
||||
def fake_aclient(*args: Any, **kwargs: Any):
|
||||
yield mock_client
|
||||
|
||||
async def fake_execute_with_retries(fn: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
holder = Mock()
|
||||
holder.aclient = fake_aclient
|
||||
holder.get_telemetry = Mock(return_value=None)
|
||||
holder.execute_with_retries = fake_execute_with_retries
|
||||
holder.get_loop = Mock(side_effect=lambda: asyncio.get_event_loop())
|
||||
|
||||
def fake_run_coroutine_threadsafe(coro: Any) -> Any:
|
||||
return asyncio.ensure_future(coro)
|
||||
|
||||
holder.run_coroutine_threadsafe = fake_run_coroutine_threadsafe
|
||||
return holder
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_state_returns_synthetic_path_on_conflict() -> None:
|
||||
"""save_state catches 409 and returns SaveWeightsResponse with synthetic path."""
|
||||
holder = _make_mock_holder()
|
||||
client = TrainingClient(holder, model_seq_id=0, model_id="model-123")
|
||||
|
||||
result = await client.save_state("ckpt-001")
|
||||
assert result.path == "tinker://model-123/weights/ckpt-001"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_weights_for_sampler_returns_synthetic_path_on_conflict() -> None:
|
||||
"""save_weights_for_sampler catches 409 and returns response with synthetic path."""
|
||||
holder = _make_mock_holder()
|
||||
holder._sampling_client_counter = 0
|
||||
client = TrainingClient(holder, model_seq_id=0, model_id="model-789")
|
||||
|
||||
result = await client.save_weights_for_sampler("ckpt-001")
|
||||
assert result.path == "tinker://model-789/sampler_weights/ckpt-001"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_weights_for_sampler_unnamed_reraises_conflict() -> None:
|
||||
"""409 on unnamed sampler save (name=None) should re-raise, not swallow."""
|
||||
holder = _make_mock_holder()
|
||||
holder._sampling_client_counter = 0
|
||||
client = TrainingClient(holder, model_seq_id=0, model_id="model-000")
|
||||
|
||||
with pytest.raises(ConflictError):
|
||||
await client.save_weights_for_sampler(None)
|
||||
Loading…
Add table
Add a link
Reference in a new issue