Sync contents

This commit is contained in:
Andrii Grynenko 2026-04-14 00:00:48 +00:00
parent 07bd3c2dd3
commit 30517b667f
33 changed files with 1272 additions and 371 deletions

View file

@ -1,4 +1,4 @@
{
"last_synced_sha": "db025e90079a19c36090a13aa88e4b2494d5a502",
"last_sync_time": "2026-03-19T02:39:30.785199"
"last_synced_sha": "d117d1692821faa297ea5d2ee7e4dc21b5c8bd0a",
"last_sync_time": "2026-04-14T00:00:48.831738"
}

View file

@ -1,6 +1,6 @@
[project]
name = "tinker"
version = "0.16.1"
version = "0.18.0"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"

View file

@ -428,7 +428,9 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
retries_taken: int = 0,
) -> httpx.Request:
if log.isEnabledFor(logging.DEBUG):
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
log.debug(
"Request options: %s", model_dump(options, exclude_unset=False, exclude_none=True)
)
kwargs: dict[str, Any] = {}

View file

@ -12,7 +12,7 @@ from ._base_client import (
AsyncAPIClient,
)
from ._compat import cached_property
from ._exceptions import APIStatusError, TinkerError
from ._exceptions import APIStatusError
from ._qs import Querystring
from ._streaming import AsyncStream as AsyncStream
from ._types import (
@ -26,6 +26,7 @@ from ._types import (
)
from ._utils import get_async_library, is_given
from ._version import __version__
from .lib._auth_token_provider import ApiKeyAuthProvider, AuthTokenProvider
if TYPE_CHECKING:
from .resources import futures, telemetry
@ -47,9 +48,6 @@ __all__ = [
class AsyncTinker(AsyncAPIClient):
# client options
api_key: str
def __init__(
self,
*,
@ -72,20 +70,16 @@ class AsyncTinker(AsyncAPIClient):
# outlining your use-case to help us decide if it should be
# part of our public interface in the future.
_strict_response_validation: bool = False,
_auth: AuthTokenProvider | None = None,
) -> None:
"""Construct a new async AsyncTinker client instance.
This automatically infers the `api_key` argument from the `TINKER_API_KEY` environment variable if it is not provided.
"""
if api_key is None:
api_key = os.environ.get("TINKER_API_KEY")
if api_key is None:
raise TinkerError(
"The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable"
)
if not api_key.startswith("tml-"):
raise TinkerError("The api_key must start with the 'tml-' prefix")
self.api_key = api_key
if _auth is not None:
self._auth = _auth
else:
self._auth = ApiKeyAuthProvider(api_key=api_key)
if base_url is None:
base_url = os.environ.get("TINKER_BASE_URL")
@ -158,9 +152,8 @@ class AsyncTinker(AsyncAPIClient):
@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"X-API-Key": api_key}
def custom_auth(self) -> AuthTokenProvider:
return self._auth
@property
@override
@ -174,7 +167,6 @@ class AsyncTinker(AsyncAPIClient):
def copy(
self,
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
@ -212,7 +204,7 @@ class AsyncTinker(AsyncAPIClient):
http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
_auth=self._auth,
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,

View file

@ -136,6 +136,7 @@ def model_dump(
exclude: IncEx | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
mode: Literal["json", "python"] = "python",
) -> dict[str, Any]:
@ -145,6 +146,7 @@ def model_dump(
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
# warnings are not supported in Pydantic v1
warnings=warnings if PYDANTIC_V2 else True,
)
@ -154,6 +156,7 @@ def model_dump(
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
),
)

View file

@ -21,6 +21,7 @@ from typing import (
import anyio
import httpx
import orjson
import pydantic
from typing_extensions import Awaitable, ParamSpec, get_origin, override
@ -351,7 +352,7 @@ class APIResponse(BaseAPIResponse[R]):
def json(self) -> object:
"""Read and decode the JSON response content."""
self.read()
return self.http_response.json()
return orjson.loads(self.http_response.content)
def close(self) -> None:
"""Close the response and release the connection.
@ -451,7 +452,7 @@ class AsyncAPIResponse(BaseAPIResponse[R]):
async def json(self) -> object:
"""Read and decode the JSON response content."""
await self.read()
return self.http_response.json()
return orjson.loads(self.http_response.content)
async def close(self) -> None:
"""Close the response and release the connection.

View file

@ -215,7 +215,7 @@ def _transform_recursive(
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, mode="json")
return model_dump(data, exclude_unset=False, exclude_none=True, mode="json")
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
@ -382,7 +382,7 @@ async def _async_transform_recursive(
return data
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True, mode="json")
return model_dump(data, exclude_unset=False, exclude_none=True, mode="json")
annotated_type = _get_annotated_type(annotation)
if annotated_type is None:

285
src/tinker/cli/AGENTS.md Normal file
View file

@ -0,0 +1,285 @@
# Tinker CLI Design Documentation
## Overview
The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance.
## Key Design Decisions
### 1. Lazy Import Strategy with Click
**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level.
**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands.
**Implementation**:
- Main `__init__.py` only imports `click` and `lazy_group`
- Command modules are loaded only when invoked via `LazyGroup`
- Output formatting imports `rich` only when table output is needed
- JSON module imported only when JSON output is requested
- Version information loaded from `_version.py` only when `tinker version` is used
### 2. Click Framework with LazyGroup
**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading.
**Rationale**:
- Click provides cleaner command structure with decorators
- Better subcommand isolation - each command file is self-contained
- Automatic help generation with better formatting
- Built-in type conversion and validation
- LazyGroup enables fast startup by deferring imports
**LazyGroup Implementation**:
```python
class LazyGroup(click.Group):
def __init__(self, *args, lazy_subcommands=None, **kwargs):
# Map of command name to "module.path:command_name"
self.lazy_subcommands = lazy_subcommands or {}
def get_command(self, ctx, cmd_name):
if cmd_name in self.lazy_subcommands:
# Import only when command is actually invoked
import_path = self.lazy_subcommands[cmd_name]
module_name, attr_name = import_path.rsplit(":", 1)
mod = importlib.import_module(module_name)
return getattr(mod, attr_name)
```
### 3. Hierarchical Command Structure
**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`.
**Rationale**:
- Provides a consistent, predictable interface
- Groups related functionality together
- Makes the CLI extensible for future commands
- Follows common CLI patterns (like `git`, `docker`, etc.)
**Examples**:
- `tinker version` - Show CLI and SDK version
- `tinker run list` - List all training runs
- `tinker run info <run-id>` - Show details of a specific run
- `tinker checkpoint list` - List all checkpoints
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
### 4. Output System with Inheritance
**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class.
**Rationale**:
- Enforces consistent interface across all commands
- Encapsulates output logic with the command that generates it
- Makes it easy to support multiple output formats (table, JSON)
- Keeps related code together in the same module
**Implementation**:
- `OutputBase` in `output.py` defines the contract
- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`)
- Base class handles format selection and rendering
### 5. Self-Contained Command Modules
**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point.
**Rationale**:
- Modular architecture - commands can be developed independently
- Clear separation of concerns
- Easy to add new commands without modifying core files
- Consistent pattern across all commands
**Command Structure**:
```python
# Each command file follows this pattern:
@click.group() # or @click.command() for simple commands
def cli():
"""Command description."""
pass
@cli.command() # For subcommands
def list():
"""Subcommand implementation."""
pass
```
### 6. Centralized Client Management
**Decision**: All SDK client creation and error handling is centralized in `client.py`.
**Rationale**:
- Single place to handle authentication and connection errors
- Consistent error messages across all commands
- Reusable error handling decorator
- Clean separation of concerns
### 7. Rich Tables for Human-Readable Output
**Decision**: Use the `rich` library for table formatting, kept as an optional dependency.
**Rationale**:
- Provides beautiful, formatted tables with colors and borders
- Handles column width adjustment automatically
- Supports both dark and light terminal themes
- Optional dependency keeps the core package lightweight
### 8. Unix-Style Default Output
**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output.
**Rationale**:
- Follows Unix philosophy
- Tables are better for human consumption
- JSON is better for scripting and automation
- Single flag switches between formats consistently
## Performance Optimizations
1. **LazyGroup for deferred imports** - Commands only loaded when invoked
2. **No heavy imports at module level** - Only Click imported initially
3. **Lazy loading** of all SDK dependencies
4. **Progress indicators** that clear themselves
5. **Efficient data fetching** - fetch all data by default instead of pagination
## Error Handling Strategy
1. **User-friendly messages** - Technical errors are translated to helpful messages
2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes
3. **Graceful degradation** - Continue operation when possible
4. **Detailed error info** - Show details when available, traceback only in TTY
### TinkerCliError Exception Pattern
All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`:
```python
from ..exceptions import TinkerCliError
# Instead of:
print(f"Error: Something went wrong", file=sys.stderr)
sys.exit(1)
# Use:
raise TinkerCliError(
"Something went wrong",
"Optional details or help text",
exit_code=1 # Optional, defaults to 1
)
```
**Benefits:**
- Better testability (can catch exceptions in tests)
- Centralized error formatting in `__main__.py`
- Consistent exit codes across the CLI
- Stack traces preserved for debugging
**Important Notes:**
- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification
- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages
- The main error handler in `__main__.py` handles printing to stderr and exiting
## Future Extensibility
The architecture supports easy addition of:
### New Commands
- Create new module in `commands/` directory
- Define output classes in the same module if needed
- Add command to lazy_subcommands in `__init__.py`
### New Subcommands
- Add new Click command decorator to existing command module
- Define corresponding output class if needed
- Subcommands automatically discovered by Click
### New Output Formats
- Override `print()` method in `OutputBase`
- Or add new format handling to base class
## Testing Guidelines
1. **Startup time**: `time tinker --help` should be <50ms
2. **Import verification**: Check that modules aren't imported unnecessarily
3. **Output formats**: Test both table and JSON output
4. **Error cases**: Test with missing auth, invalid IDs, network errors
5. **Empty results**: Ensure graceful handling of no data
## Module Structure
```
cli/
├── __init__.py # Main entry with LazyGroup configuration
├── __main__.py # Module execution support
├── lazy_group.py # LazyGroup implementation for lazy loading
├── output.py # OutputBase class and formatting utilities
├── client.py # SDK client creation and error handling
├── commands/
│ ├── __init__.py # Command module marker
│ ├── version.py # Version command
│ ├── run.py # Run commands and output classes
│ └── checkpoint.py # Checkpoint commands and output classes
└── AGENTS.md # This documentation
```
## Command Examples
```bash
# Show version
tinker version
# List all training runs
tinker run list
# Show run details
tinker run info run-abc123
# List all checkpoints
tinker checkpoint list
# List checkpoints for specific run
tinker checkpoint list run-abc123
# Show checkpoint details
tinker checkpoint info ckpt-xyz789
# Upload checkpoint to Hugging Face Hub
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
# JSON output
tinker --format json run list
tinker --format json checkpoint list
```
## Dependencies
### Required
- Python 3.11+
- tinker SDK (main package)
- click>=8.0.0 (CLI framework)
### Optional
- `rich` - For table formatting (installed with `pip install tinker[cli]`)
## Migration from Argparse to Click
### Key Changes:
1. **Command Definition**: Decorators instead of `parser.add_argument()`
2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch
3. **Context Passing**: Click's context system for sharing format option
4. **Error Handling**: Click handles exits and error formatting
5. **Help Generation**: Automatic from docstrings and decorators
### Benefits:
- Cleaner, more Pythonic code
- Better command organization
- Built-in testing utilities
- Easier to extend with plugins
- More consistent behavior across commands
## Maintenance Notes
1. **Keep imports lazy** - Use LazyGroup for all commands
2. **Test startup time** - Regularly verify fast startup is maintained
3. **Follow Click patterns** - Use decorators and context properly
4. **Document changes** - Update this file when making architectural changes
5. **Maintain consistency** - All commands should follow the same structure

View file

@ -1,285 +1 @@
# Tinker CLI Design Documentation
## Overview
The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance.
## Key Design Decisions
### 1. Lazy Import Strategy with Click
**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level.
**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands.
**Implementation**:
- Main `__init__.py` only imports `click` and `lazy_group`
- Command modules are loaded only when invoked via `LazyGroup`
- Output formatting imports `rich` only when table output is needed
- JSON module imported only when JSON output is requested
- Version information loaded from `_version.py` only when `tinker version` is used
### 2. Click Framework with LazyGroup
**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading.
**Rationale**:
- Click provides cleaner command structure with decorators
- Better subcommand isolation - each command file is self-contained
- Automatic help generation with better formatting
- Built-in type conversion and validation
- LazyGroup enables fast startup by deferring imports
**LazyGroup Implementation**:
```python
class LazyGroup(click.Group):
def __init__(self, *args, lazy_subcommands=None, **kwargs):
# Map of command name to "module.path:command_name"
self.lazy_subcommands = lazy_subcommands or {}
def get_command(self, ctx, cmd_name):
if cmd_name in self.lazy_subcommands:
# Import only when command is actually invoked
import_path = self.lazy_subcommands[cmd_name]
module_name, attr_name = import_path.rsplit(":", 1)
mod = importlib.import_module(module_name)
return getattr(mod, attr_name)
```
### 3. Hierarchical Command Structure
**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`.
**Rationale**:
- Provides a consistent, predictable interface
- Groups related functionality together
- Makes the CLI extensible for future commands
- Follows common CLI patterns (like `git`, `docker`, etc.)
**Examples**:
- `tinker version` - Show CLI and SDK version
- `tinker run list` - List all training runs
- `tinker run info <run-id>` - Show details of a specific run
- `tinker checkpoint list` - List all checkpoints
- `tinker checkpoint info <checkpoint-id>` - Show checkpoint details
- `tinker checkpoint push-hf <checkpoint-path>` - Upload a checkpoint to Hugging Face Hub
### 4. Output System with Inheritance
**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class.
**Rationale**:
- Enforces consistent interface across all commands
- Encapsulates output logic with the command that generates it
- Makes it easy to support multiple output formats (table, JSON)
- Keeps related code together in the same module
**Implementation**:
- `OutputBase` in `output.py` defines the contract
- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`)
- Base class handles format selection and rendering
### 5. Self-Contained Command Modules
**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point.
**Rationale**:
- Modular architecture - commands can be developed independently
- Clear separation of concerns
- Easy to add new commands without modifying core files
- Consistent pattern across all commands
**Command Structure**:
```python
# Each command file follows this pattern:
@click.group() # or @click.command() for simple commands
def cli():
"""Command description."""
pass
@cli.command() # For subcommands
def list():
"""Subcommand implementation."""
pass
```
### 6. Centralized Client Management
**Decision**: All SDK client creation and error handling is centralized in `client.py`.
**Rationale**:
- Single place to handle authentication and connection errors
- Consistent error messages across all commands
- Reusable error handling decorator
- Clean separation of concerns
### 7. Rich Tables for Human-Readable Output
**Decision**: Use the `rich` library for table formatting, kept as an optional dependency.
**Rationale**:
- Provides beautiful, formatted tables with colors and borders
- Handles column width adjustment automatically
- Supports both dark and light terminal themes
- Optional dependency keeps the core package lightweight
### 8. Unix-Style Default Output
**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output.
**Rationale**:
- Follows Unix philosophy
- Tables are better for human consumption
- JSON is better for scripting and automation
- Single flag switches between formats consistently
## Performance Optimizations
1. **LazyGroup for deferred imports** - Commands only loaded when invoked
2. **No heavy imports at module level** - Only Click imported initially
3. **Lazy loading** of all SDK dependencies
4. **Progress indicators** that clear themselves
5. **Efficient data fetching** - fetch all data by default instead of pagination
## Error Handling Strategy
1. **User-friendly messages** - Technical errors are translated to helpful messages
2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes
3. **Graceful degradation** - Continue operation when possible
4. **Detailed error info** - Show details when available, traceback only in TTY
### TinkerCliError Exception Pattern
All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`:
```python
from ..exceptions import TinkerCliError
# Instead of:
print(f"Error: Something went wrong", file=sys.stderr)
sys.exit(1)
# Use:
raise TinkerCliError(
"Something went wrong",
"Optional details or help text",
exit_code=1 # Optional, defaults to 1
)
```
**Benefits:**
- Better testability (can catch exceptions in tests)
- Centralized error formatting in `__main__.py`
- Consistent exit codes across the CLI
- Stack traces preserved for debugging
**Important Notes:**
- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification
- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages
- The main error handler in `__main__.py` handles printing to stderr and exiting
## Future Extensibility
The architecture supports easy addition of:
### New Commands
- Create new module in `commands/` directory
- Define output classes in the same module if needed
- Add command to lazy_subcommands in `__init__.py`
### New Subcommands
- Add new Click command decorator to existing command module
- Define corresponding output class if needed
- Subcommands automatically discovered by Click
### New Output Formats
- Override `print()` method in `OutputBase`
- Or add new format handling to base class
## Testing Guidelines
1. **Startup time**: `time tinker --help` should be <50ms
2. **Import verification**: Check that modules aren't imported unnecessarily
3. **Output formats**: Test both table and JSON output
4. **Error cases**: Test with missing auth, invalid IDs, network errors
5. **Empty results**: Ensure graceful handling of no data
## Module Structure
```
cli/
├── __init__.py # Main entry with LazyGroup configuration
├── __main__.py # Module execution support
├── lazy_group.py # LazyGroup implementation for lazy loading
├── output.py # OutputBase class and formatting utilities
├── client.py # SDK client creation and error handling
├── commands/
│ ├── __init__.py # Command module marker
│ ├── version.py # Version command
│ ├── run.py # Run commands and output classes
│ └── checkpoint.py # Checkpoint commands and output classes
└── CLAUDE.md # This documentation
```
## Command Examples
```bash
# Show version
tinker version
# List all training runs
tinker run list
# Show run details
tinker run info run-abc123
# List all checkpoints
tinker checkpoint list
# List checkpoints for specific run
tinker checkpoint list run-abc123
# Show checkpoint details
tinker checkpoint info ckpt-xyz789
# Upload checkpoint to Hugging Face Hub
tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter
# JSON output
tinker --format json run list
tinker --format json checkpoint list
```
## Dependencies
### Required
- Python 3.11+
- tinker SDK (main package)
- click>=8.0.0 (CLI framework)
### Optional
- `rich` - For table formatting (installed with `pip install tinker[cli]`)
## Migration from Argparse to Click
### Key Changes:
1. **Command Definition**: Decorators instead of `parser.add_argument()`
2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch
3. **Context Passing**: Click's context system for sharing format option
4. **Error Handling**: Click handles exits and error formatting
5. **Help Generation**: Automatic from docstrings and decorators
### Benefits:
- Cleaner, more Pythonic code
- Better command organization
- Built-in testing utilities
- Easier to extend with plugins
- More consistent behavior across commands
## Maintenance Notes
1. **Keep imports lazy** - Use LazyGroup for all commands
2. **Test startup time** - Regularly verify fast startup is maintained
3. **Follow Click patterns** - Use decorators and context properly
4. **Document changes** - Update this file when making architectural changes
5. **Maintain consistency** - All commands should follow the same structure
@AGENTS.md

View file

@ -0,0 +1,104 @@
"""Authentication credential management for the Tinker SDK.
Provides composable credential providers that plug into httpx's async auth flow:
- AuthTokenProvider: abstract base (httpx.Auth) subclasses implement get_token()
- ApiKeyAuthProvider: resolves from api_key arg or TINKER_API_KEY env var
- CredentialCmdAuthProvider: runs a command on every call for fresh credentials
- resolve_auth_provider(): factory that picks the right provider
"""
from __future__ import annotations
import abc
import asyncio
import os
from collections.abc import AsyncGenerator
import httpx
from tinker._exceptions import TinkerError
class AuthTokenProvider(httpx.Auth):
    """Base class for credential providers that plug into httpx's auth hooks.

    Concrete subclasses supply the credential via ``get_token()``; the async
    auth flow attaches it to each outgoing request.
    """

    @abc.abstractmethod
    async def get_token(self) -> str | None: ...

    async def async_auth_flow(
        self, request: httpx.Request
    ) -> AsyncGenerator[httpx.Request, httpx.Response]:
        # Attach the credential (when one is available) before the request
        # goes on the wire.
        credential = await self.get_token()
        if credential:
            request.headers["X-API-Key"] = credential
        yield request
class ApiKeyAuthProvider(AuthTokenProvider):
    """Static credential provider.

    Resolves the key from the ``api_key`` constructor argument or, when that
    is absent, from the ``TINKER_API_KEY`` environment variable, and returns
    it unchanged on every ``get_token()`` call.

    Raises:
        TinkerError: if no key can be resolved, or the key does not carry a
            recognized prefix.
    """

    # Accepted formats: "tml-..." static API keys, or JWTs — a base64url
    # encoding of a JSON header ('{"...') always starts with "eyJ".
    _VALID_PREFIXES = ("tml-", "eyJ")

    def __init__(self, api_key: str | None = None) -> None:
        resolved = api_key or os.environ.get("TINKER_API_KEY")
        if not resolved:
            raise TinkerError(
                "The api_key client option must be set either by passing api_key to the client"
                " or by setting the TINKER_API_KEY environment variable"
            )
        # startswith with a tuple replaces the two chained checks; the error
        # message now matches the actual rule (the old text claimed only
        # 'tml-' was valid even though JWTs were accepted).
        if not resolved.startswith(self._VALID_PREFIXES):
            raise TinkerError(
                "The api_key must start with the 'tml-' prefix or be a JWT starting with 'eyJ'"
            )
        self._token = resolved

    async def get_token(self) -> str | None:
        # The key is static; no refresh logic is needed.
        return self._token
class CredentialCmdAuthProvider(AuthTokenProvider):
    """Runs the configured credential command on every ``get_token()`` call.

    Always produces a fresh credential (e.g. short-lived bearer tokens).
    Uses an async subprocess so the event loop is never blocked while the
    command runs.

    Raises:
        TinkerError: if the command is empty, exits non-zero, or prints
            nothing to stdout.
    """

    def __init__(self, cmd: str) -> None:
        if not cmd:
            raise TinkerError(
                "Your organization requires dynamic credentials — set TINKER_CREDENTIAL_CMD"
                " to a command that prints a valid credential."
            )
        self._cmd = cmd

    async def get_token(self) -> str | None:
        proc = await asyncio.create_subprocess_shell(
            self._cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )
        stdout, stderr = await proc.communicate()
        # A failing command must not be mistaken for a valid credential:
        # previously the exit status and stderr were silently discarded.
        if proc.returncode != 0:
            raise TinkerError(
                f"TINKER_CREDENTIAL_CMD exited with status {proc.returncode}:"
                f" {stderr.decode().strip()}"
            )
        credential = stdout.decode().strip()
        if not credential:
            raise TinkerError("TINKER_CREDENTIAL_CMD returned an empty credential.")
        return credential
def resolve_auth_provider(api_key: str | None, enforce_cmd: bool) -> AuthTokenProvider:
    """Pick the auth provider appropriate for the available credentials.

    - ``enforce_cmd=True``: uses TINKER_CREDENTIAL_CMD, unless the api_key is
      already a JWT (dynamic credential), in which case it is used directly.
    - ``enforce_cmd=False``: tries the api_key first, falls back to
      TINKER_CREDENTIAL_CMD.
    """
    cmd = os.environ.get("TINKER_CREDENTIAL_CMD", "")
    key = api_key if api_key else os.environ.get("TINKER_API_KEY", "")

    # A JWT supplied as the api_key is already a dynamic credential, so it
    # wins even when the credential command is enforced.
    if key.startswith("eyJ"):
        return ApiKeyAuthProvider(api_key=key)

    if enforce_cmd:
        return CredentialCmdAuthProvider(cmd)

    try:
        return ApiKeyAuthProvider(api_key=api_key)
    except TinkerError:
        if not cmd:
            raise
        return CredentialCmdAuthProvider(cmd)

View file

@ -0,0 +1,90 @@
"""JWT authentication for Tinker SDK.
Internal to the SDK; not part of the public API.
When the server sets pjwt_auth_enabled, the SDK exchanges the caller's
credential for a short-lived JWT minted by the Tinker server. The JWT is
cached and refreshed in the background before it expires, so callers always
send a valid token without any per-request overhead.
"""
from __future__ import annotations
import asyncio
import base64
import json
import logging
import time
from collections.abc import Callable
from contextlib import AbstractContextManager
from tinker.lib._auth_token_provider import AuthTokenProvider
logger = logging.getLogger(__name__)
_REFRESH_BEFORE_EXPIRY_SECS = 300 # refresh 5 min before expiry
_RETRY_DELAY_SECS = 60
def _jwt_expiry(jwt: str) -> float:
"""Return the exp claim of a JWT as a Unix timestamp."""
try:
payload = jwt.split(".")[1]
payload += "=" * (-len(payload) % 4)
return float(json.loads(base64.urlsafe_b64decode(payload))["exp"])
except Exception as e:
raise ValueError(f"Failed to parse JWT expiry: {e}") from e
class JwtAuthProvider(AuthTokenProvider):
    """AuthTokenProvider that exchanges a credential for a short-lived JWT.

    After init(), get_token() returns the current JWT. A background task
    refreshes the JWT before it expires.
    """

    def __init__(
        self,
        aclient_fn: Callable[[], AbstractContextManager],
        seed_token: str | None = None,
    ) -> None:
        # Current JWT; empty until init() fetches one (or a seed is given).
        self._token = seed_token or ""
        self._aclient_fn = aclient_fn
        # Background refresh task; created by init(). Initialized here so the
        # attribute always exists even before init() runs.
        self._refresh_task: asyncio.Task | None = None

    async def get_token(self) -> str | None:
        return self._token

    async def init(self) -> None:
        """Fetch a JWT (unless seeded) then start the background refresh loop.

        When seed_token was provided, skips the initial fetch and starts
        refreshing from the seed — useful for shadow holders that already
        have a valid JWT from the primary holder.
        """
        token = self._token if self._token else await self._fetch()
        self._refresh_task = asyncio.create_task(self._refresh_loop(token))

    async def _fetch(self) -> str:
        """Exchange the current credential for a JWT via /api/v1/auth/token."""
        # NOTE(review): aclient_fn is annotated as a *sync* context manager
        # whose client exposes async methods — confirm this matches callers.
        with self._aclient_fn() as client:
            response = await client.service.auth_token()
            self._token = response.jwt
            return response.jwt

    @staticmethod
    def _next_delay(token: str) -> float:
        """Seconds until *token* should be refreshed, floored at the retry delay."""
        try:
            return max(
                _RETRY_DELAY_SECS,
                _jwt_expiry(token) - time.time() - _REFRESH_BEFORE_EXPIRY_SECS,
            )
        except ValueError:
            logger.debug("Failed to parse JWT expiry, retrying in %ds", _RETRY_DELAY_SECS)
            return _RETRY_DELAY_SECS

    async def _refresh_loop(self, token: str) -> None:
        delay = self._next_delay(token)
        while True:
            try:
                await asyncio.sleep(delay)
                token = await self._fetch()
                delay = self._next_delay(token)
            except asyncio.CancelledError:
                return
            except Exception as e:
                logger.debug("JWT refresh failed, retrying in %ds: %s", _RETRY_DELAY_SECS, e)
                # Fix: previously the loop recomputed the delay from the stale
                # token's expiry after a failure, so the actual wait could be
                # far longer than the logged retry interval. Retry promptly.
                delay = _RETRY_DELAY_SECS

View file

@ -0,0 +1,156 @@
"""Tests for JWT authentication helpers."""
from __future__ import annotations
import base64
import json
import time
from unittest.mock import AsyncMock, MagicMock
import pytest
from tinker._exceptions import TinkerError
from tinker.lib._auth_token_provider import (
ApiKeyAuthProvider,
CredentialCmdAuthProvider,
resolve_auth_provider,
)
from tinker.lib._jwt_auth import (
JwtAuthProvider,
_jwt_expiry,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_jwt(exp: float) -> str:
"""Build a minimal fake JWT with a given exp claim."""
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
payload_bytes = json.dumps({"exp": exp, "sub": "test"}).encode()
payload = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode()
return f"{header}.{payload}.fakesig"
class _MockAuthResponse:
def __init__(self, jwt: str) -> None:
self.jwt = jwt
class _MockHolder:
"""Minimal mock providing aclient() for testing JwtAuthProvider."""
def __init__(self, response_jwt: str, *, fail: bool = False) -> None:
service = MagicMock()
if fail:
service.auth_token = AsyncMock(side_effect=Exception("network error"))
else:
service.auth_token = AsyncMock(return_value=_MockAuthResponse(response_jwt))
client = MagicMock()
client.service = service
cm = MagicMock()
cm.__enter__ = MagicMock(return_value=client)
cm.__exit__ = MagicMock(return_value=None)
self._cm = cm
def aclient(self):
return self._cm
# ---------------------------------------------------------------------------
# _jwt_expiry
# ---------------------------------------------------------------------------
def test_jwt_expiry_parses_valid():
    """A well-formed token's exp claim round-trips through _jwt_expiry."""
    expected = time.time() + 3600
    assert abs(_jwt_expiry(_make_jwt(expected)) - expected) < 1
def test_jwt_expiry_raises_on_invalid():
    """Garbage input must raise rather than return a bogus expiry."""
    with pytest.raises(Exception):
        _jwt_expiry("not.a.jwt")
def test_jwt_expiry_raises_on_missing_exp():
    """A token whose payload lacks an exp claim must raise."""
    segments = (
        base64.urlsafe_b64encode(b'{"alg":"RS256"}').rstrip(b"=").decode(),
        base64.urlsafe_b64encode(b'{"sub":"x"}').rstrip(b"=").decode(),
        "sig",
    )
    with pytest.raises(Exception):
        _jwt_expiry(".".join(segments))
# ---------------------------------------------------------------------------
# AuthTokenProvider hierarchy
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_api_key_provider_resolves_key():
    """ApiKeyAuthProvider.get_token returns the configured key verbatim."""
    provider = ApiKeyAuthProvider(api_key="tml-test-key")
    token = await provider.get_token()
    assert token == "tml-test-key"
@pytest.mark.asyncio
async def test_credential_cmd_provider_runs_command():
    """CredentialCmdAuthProvider returns the command's stdout as the token."""
    provider = CredentialCmdAuthProvider("echo test-credential")
    assert (await provider.get_token()) == "test-credential"
@pytest.mark.asyncio
async def test_resolve_auth_provider_fallback_to_cmd(monkeypatch: pytest.MonkeyPatch):
    """Without an API key, resolution falls back to TINKER_CREDENTIAL_CMD."""
    monkeypatch.delenv("TINKER_API_KEY", raising=False)
    monkeypatch.setenv("TINKER_CREDENTIAL_CMD", "echo fallback-cred")
    provider = resolve_auth_provider(api_key=None, enforce_cmd=False)
    assert isinstance(provider, CredentialCmdAuthProvider)
    assert (await provider.get_token()) == "fallback-cred"
def test_credential_cmd_provider_raises_with_empty_cmd():
    """An empty credential command is a configuration error."""
    with pytest.raises(TinkerError, match="dynamic credentials"):
        CredentialCmdAuthProvider("")
# ---------------------------------------------------------------------------
# JwtAuthProvider.init
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_init_fetches_jwt_and_stores_it():
    """init() performs exactly one token exchange and caches the JWT."""
    jwt = _make_jwt(time.time() + 7200)
    holder = _MockHolder(jwt)
    provider = JwtAuthProvider(holder.aclient)
    await provider.init()
    assert (await provider.get_token()) == jwt
    holder._cm.__enter__.return_value.service.auth_token.assert_called_once()
@pytest.mark.asyncio
async def test_init_raises_on_fetch_failure():
    """A failed exchange propagates out of init() instead of being swallowed."""
    provider = JwtAuthProvider(_MockHolder("some-jwt", fail=True).aclient)
    with pytest.raises(Exception, match="network error"):
        await provider.init()
# ---------------------------------------------------------------------------
# JwtAuthProvider._fetch
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_fetch_returns_and_stores_token():
    """_fetch() both returns the new JWT and records it on the provider."""
    jwt = _make_jwt(time.time() + 7200)
    provider = JwtAuthProvider(_MockHolder(jwt).aclient)
    assert (await provider._fetch()) == jwt
    assert (await provider.get_token()) == jwt

View file

@ -21,6 +21,12 @@ from tinker import types
from tinker._client import AsyncTinker
from tinker._exceptions import APIConnectionError, APIStatusError
from tinker._version import __version__ as tinker_sdk_version
from tinker.lib._auth_token_provider import (
ApiKeyAuthProvider,
AuthTokenProvider,
resolve_auth_provider,
)
from tinker.lib._jwt_auth import JwtAuthProvider
from tinker.lib.async_tinker_provider import AsyncTinkerProvider
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
@ -180,17 +186,63 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
project_id: str | None = None,
*,
session_id: str | None = None,
api_key: str | None = None,
_client_config: dict[str, str | int | bool] | None = None,
_jwt_auth_seed: str | None = None,
**kwargs: Any,
) -> None:
self._constructor_kwargs = kwargs
self._api_key = api_key
self._constructor_kwargs = dict(kwargs)
self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop()
self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {}
self._sample_backoff_until: float | None = None
self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400)
self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10)
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024)
self._training_client_lock: threading.Lock = threading.Lock()
self._telemetry: Telemetry | None = None
# Fetch server-side client config before any server contact so that
# flags are available for subsequent setup steps. Shadow holders
# receive the config via kwargs to avoid a redundant fetch (and
# potential deadlock on the event loop thread).
if _client_config is not None:
self._client_config = types.ClientConfigResponse.model_validate(_client_config)
else:
self._assert_not_on_event_loop("fetch client config")
config_auth = resolve_auth_provider(api_key, enforce_cmd=False)
self._client_config = self.run_coroutine_threadsafe(
self._fetch_client_config(config_auth)
).result()
self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(
self._client_config.sample_dispatch_bytes_semaphore_size
)
self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(
self._client_config.inflight_response_bytes_semaphore_size
)
if not self._client_config.pjwt_auth_enabled:
# Without JWT exchange, only API keys are accepted by the server.
# Replace any cmd-based provider with a plain API key provider.
self._default_auth = ApiKeyAuthProvider(api_key=api_key)
else:
# Create a dedicated pool for JWT exchange with the appropriate
# credential provider. The lambda captures the pool so it stays alive.
use_cmd = self._client_config.credential_default_source == "credential_cmd"
auth_pool_auth = resolve_auth_provider(self._api_key, use_cmd)
auth_kwargs = {**self._constructor_kwargs, "_auth": auth_pool_auth}
auth_pool = ClientConnectionPool(self.get_loop(), 1, auth_kwargs)
auth_aclient = lambda: auth_pool.aclient() # noqa: E731
self._default_auth = JwtAuthProvider(auth_aclient, seed_token=_jwt_auth_seed)
if _jwt_auth_seed:
# Shadow holder: start refresh in background, don't block.
self.run_coroutine_threadsafe(self._default_auth.init())
else:
# Primary holder: must have a valid JWT before proceeding.
self._assert_not_on_event_loop("exchange JWT")
self.run_coroutine_threadsafe(
self.execute_with_retries(self._default_auth.init)
).result()
if session_id is not None:
# Shadow mode: reuse existing session, can't create new clients
@ -199,14 +251,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
self._sampling_client_counter: int | None = None
else:
# Normal mode: create new session.
# This blocks on .result() — must NOT be called from the event
# loop thread (e.g. inside the sidecar subprocess). Shadow
# holders (session_id is not None) skip this path.
if self._loop.is_running() and _current_loop() is self._loop:
raise RuntimeError(
"Cannot create a new session from the event loop thread. "
"Use session_id= to create a shadow holder instead."
)
self._assert_not_on_event_loop("create a new session")
self._session_id = self.run_coroutine_threadsafe(
self._create_session(user_metadata=user_metadata, project_id=project_id)
).result()
@ -230,6 +275,26 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
"""Get or create a shadow holder from the singleton cache."""
return _shadow_holder_singleton.get_or_create(session_id, kwargs)
def _assert_not_on_event_loop(self, action: str) -> None:
"""Raise if called from the event loop thread (would deadlock on .result())."""
if self._loop.is_running() and _current_loop() is self._loop:
raise RuntimeError(
f"Cannot {action} from the event loop thread. "
"Use session_id= to create a shadow holder instead."
)
@property
def shadow_kwargs(self) -> dict[str, Any]:
"""Constructor kwargs for shadow holders, including cached server config and JWT seed."""
result = {
**self._constructor_kwargs,
"api_key": self._api_key,
"_client_config": self._client_config.model_dump(),
}
if isinstance(self._default_auth, JwtAuthProvider):
result["_jwt_auth_seed"] = self._default_auth._token
return result
@asynccontextmanager
async def _sample_dispatch_count_rate_limit(self):
async with self._sample_dispatch_semaphore:
@ -316,6 +381,23 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
"""Start the session heartbeat task."""
return asyncio.create_task(self._session_heartbeat(self._session_id))
async def _fetch_client_config(self, auth: AuthTokenProvider) -> types.ClientConfigResponse:
"""Call /api/v1/client/config and return server feature flags.
Creates a one-off connection pool with the given auth. Retries
transient failures via execute_with_retries.
"""
kwargs = {**self._constructor_kwargs, "_auth": auth}
pool = ClientConnectionPool(self.get_loop(), 1, kwargs)
async def _once() -> types.ClientConfigResponse:
with pool.aclient() as client:
return await client.service.client_config(
request=types.ClientConfigRequest(sdk_version=tinker_sdk_version)
)
return await self.execute_with_retries(_once)
async def _create_session(
self,
user_metadata: dict[str, str] | None = None,
@ -350,8 +432,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider):
if client_pool_type == ClientConnectionPoolType.TRAIN
else MAX_REQUESTS_PER_HTTPX_CLIENT
)
kwargs = {**self._constructor_kwargs, "_auth": self._default_auth}
self._client_pools[client_pool_type] = ClientConnectionPool(
self.get_loop(), max_requests_per_client, self._constructor_kwargs
self.get_loop(), max_requests_per_client, kwargs
)
return self._client_pools[client_pool_type]

View file

@ -0,0 +1,96 @@
"""Tests for InternalClientHolder helpers."""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from tinker.lib._auth_token_provider import AuthTokenProvider
from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder
from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse
class _MockHolder:
    """Minimal stand-in for testing _fetch_client_config."""

    def __init__(self, response: _ClientConfigResponse | Exception) -> None:
        # client_config either raises the given exception or returns the
        # given response object.
        if isinstance(response, Exception):
            client_config = AsyncMock(side_effect=response)
        else:
            client_config = AsyncMock(return_value=response)
        service = MagicMock()
        service.client_config = client_config
        client = MagicMock()
        client.service = service
        context = MagicMock()
        context.__enter__ = MagicMock(return_value=client)
        context.__exit__ = MagicMock(return_value=None)
        self._cm = context
        self._constructor_kwargs: dict[str, Any] = {}
        self._default_auth = MagicMock(spec=AuthTokenProvider)
        self._loop = asyncio.get_event_loop()

    def get_loop(self) -> asyncio.AbstractEventLoop:
        return self._loop

    async def execute_with_retries(self, func: Any, *args: Any, **kwargs: Any) -> Any:
        # No retry logic needed in tests; invoke the wrapped call once.
        return await func(*args, **kwargs)

    # Bind the real method so the pool it creates uses our mock client
    _fetch_client_config = InternalClientHolder._fetch_client_config
def _patch_pool(monkeypatch: pytest.MonkeyPatch, holder: _MockHolder) -> None:
    """Route ClientConnectionPool.aclient() to the holder's mock context manager."""
    monkeypatch.setattr(ClientConnectionPool, "aclient", lambda _pool: holder._cm)
# ---------------------------------------------------------------------------
# _fetch_client_config
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_fetch_client_config_returns_flags_from_server(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Server-enabled flags come back verbatim through _fetch_client_config."""
    holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=True))
    _patch_pool(monkeypatch, holder)
    config = await InternalClientHolder._fetch_client_config(holder, holder._default_auth)  # type: ignore[arg-type]
    assert config.pjwt_auth_enabled is True
@pytest.mark.asyncio
async def test_fetch_client_config_returns_defaults_when_server_disables(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A server-side 'disabled' flag is propagated unchanged."""
    holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=False))
    _patch_pool(monkeypatch, holder)
    config = await InternalClientHolder._fetch_client_config(holder, holder._default_auth)  # type: ignore[arg-type]
    assert config.pjwt_auth_enabled is False
@pytest.mark.asyncio
async def test_fetch_client_config_raises_on_network_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A transport failure surfaces to the caller rather than being swallowed."""
    holder = _MockHolder(Exception("connection refused"))
    _patch_pool(monkeypatch, holder)
    with pytest.raises(Exception, match="connection refused"):
        await InternalClientHolder._fetch_client_config(holder, holder._default_auth)  # type: ignore[arg-type]
@pytest.mark.asyncio
async def test_fetch_client_config_passes_sdk_version(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The request body must carry this SDK's own version string."""
    from tinker._version import __version__ as tinker_sdk_version

    holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=False))
    _patch_pool(monkeypatch, holder)
    await InternalClientHolder._fetch_client_config(holder, holder._default_auth)  # type: ignore[arg-type]
    mock_call = holder._cm.__enter__.return_value.service.client_config.call_args
    assert mock_call.kwargs["request"].sdk_version == tinker_sdk_version

View file

@ -451,7 +451,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
_SamplingClientPickleState(
session_id=self.holder.get_session_id(),
sampling_session_id=self._sampling_session_id,
constructor_kwargs=self.holder._constructor_kwargs,
constructor_kwargs=self.holder.shadow_kwargs,
subprocess_sampling=self._sampling_client_sidecar_handle is not None,
),
),

View file

@ -230,10 +230,26 @@ class ServiceClient(TelemetryProvider):
user_metadata,
).result_async()
def _get_rest_client_for_weights(self, weights_access_token: str | None = None) -> RestClient:
"""Get a rest client for weights info lookups.
If weights_access_token is provided, creates a separate ServiceClient
authenticated with that token.
"""
if weights_access_token is not None:
token_client = ServiceClient(
api_key=weights_access_token, **self.holder._constructor_kwargs
)
return token_client.create_rest_client()
return self.create_rest_client()
@sync_only
@capture_exceptions(fatal=True)
def create_training_client_from_state(
self, path: str, user_metadata: dict[str, str] | None = None
self,
path: str,
user_metadata: dict[str, str] | None = None,
weights_access_token: str | None = None,
) -> TrainingClient:
"""Create a TrainingClient from saved model weights.
@ -243,6 +259,7 @@ class ServiceClient(TelemetryProvider):
Args:
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
- `user_metadata`: Optional metadata to attach to the new training run
- `weights_access_token`: Optional access token for loading checkpoints under a different account.
Returns:
- `TrainingClient` loaded with the specified weights
@ -256,7 +273,7 @@ class ServiceClient(TelemetryProvider):
# Continue training from the loaded state
```
"""
rest_client = self.create_rest_client()
rest_client = self._get_rest_client_for_weights(weights_access_token)
# Use weights info endpoint which allows access to models with public checkpoints
weights_info = rest_client.get_weights_info_by_tinker_path(path).result()
@ -271,15 +288,18 @@ class ServiceClient(TelemetryProvider):
user_metadata=user_metadata,
)
training_client.load_state(path).result()
training_client.load_state(path, weights_access_token=weights_access_token).result()
return training_client
@capture_exceptions(fatal=True)
async def create_training_client_from_state_async(
self, path: str, user_metadata: dict[str, str] | None = None
self,
path: str,
user_metadata: dict[str, str] | None = None,
weights_access_token: str | None = None,
) -> TrainingClient:
"""Async version of create_training_client_from_state."""
rest_client = self.create_rest_client()
rest_client = self._get_rest_client_for_weights(weights_access_token)
# Use weights info endpoint which allows access to models with public checkpoints
weights_info = await rest_client.get_weights_info_by_tinker_path(path)
@ -297,14 +317,19 @@ class ServiceClient(TelemetryProvider):
user_metadata=user_metadata,
)
load_future = await training_client.load_state_async(path)
load_future = await training_client.load_state_async(
path, weights_access_token=weights_access_token
)
await load_future.result_async()
return training_client
@sync_only
@capture_exceptions(fatal=True)
def create_training_client_from_state_with_optimizer(
self, path: str, user_metadata: dict[str, str] | None = None
self,
path: str,
user_metadata: dict[str, str] | None = None,
weights_access_token: str | None = None,
) -> TrainingClient:
"""Create a TrainingClient from saved model weights and optimizer state.
@ -315,6 +340,7 @@ class ServiceClient(TelemetryProvider):
Args:
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
- `user_metadata`: Optional metadata to attach to the new training run
- `weights_access_token`: Optional access token for loading checkpoints under a different account.
Returns:
- `TrainingClient` loaded with the specified weights and optimizer state
@ -328,7 +354,7 @@ class ServiceClient(TelemetryProvider):
# Continue training with restored optimizer momentum
```
"""
rest_client = self.create_rest_client()
rest_client = self._get_rest_client_for_weights(weights_access_token)
# Use weights info endpoint which allows access to models with public checkpoints
weights_info = rest_client.get_weights_info_by_tinker_path(path).result()
@ -343,15 +369,20 @@ class ServiceClient(TelemetryProvider):
user_metadata=user_metadata,
)
training_client.load_state_with_optimizer(path).result()
training_client.load_state_with_optimizer(
path, weights_access_token=weights_access_token
).result()
return training_client
@capture_exceptions(fatal=True)
async def create_training_client_from_state_with_optimizer_async(
self, path: str, user_metadata: dict[str, str] | None = None
self,
path: str,
user_metadata: dict[str, str] | None = None,
weights_access_token: str | None = None,
) -> TrainingClient:
"""Async version of create_training_client_from_state_with_optimizer."""
rest_client = self.create_rest_client()
rest_client = self._get_rest_client_for_weights(weights_access_token)
# Use weights info endpoint which allows access to models with public checkpoints
weights_info = await rest_client.get_weights_info_by_tinker_path(path)
@ -369,7 +400,9 @@ class ServiceClient(TelemetryProvider):
user_metadata=user_metadata,
)
load_future = await training_client.load_state_with_optimizer_async(path)
load_future = await training_client.load_state_with_optimizer_async(
path, weights_access_token=weights_access_token
)
await load_future.result_async()
return training_client

View file

@ -6,10 +6,12 @@ import asyncio
import logging
import threading
import time
import warnings
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Tuple
from tinker import types
from tinker._exceptions import ConflictError
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, capture_exceptions
@ -40,6 +42,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# FwdBwdChunkSize
MAX_CHUNK_LEN = 1024
MAX_CHUNK_BYTES_COUNT = 5000000
@ -603,13 +606,44 @@ class TrainingClient(TelemetryProvider):
seq_id=request_id + 1,
ttl_seconds=ttl_seconds,
)
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.save(
request=request,
try:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.save(
request=request,
max_retries=0,
)
except ConflictError:
# 409 means a checkpoint with this name already exists.
# This is common when retrying after a transient network
# error — the first attempt saved the checkpoint but the
# response was lost. Treat as success: the checkpoint IS
# saved, and crashing a long training run is worse than
# returning a synthetic response.
logger.info(
"Checkpoint '%s' already exists (409 Conflict); "
"treating as success — the checkpoint is saved.",
name,
)
if telemetry := self.holder.get_telemetry():
telemetry.log(
"training_client.save_state.conflict_resolved",
event_data={
"checkpoint_name": name,
"model_id": self._guaranteed_model_id(),
},
severity="INFO",
)
return None
async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
# _send_request returns None on 409 conflict (checkpoint already
# saved), or an UntypedAPIFuture on success.
if future is None:
model_id = self._guaranteed_model_id()
return types.SaveWeightsResponse(path=f"tinker://{model_id}/weights/{name}")
return await _APIFuture(
types.SaveWeightsResponse,
self.holder,
@ -629,7 +663,11 @@ class TrainingClient(TelemetryProvider):
@capture_exceptions(fatal=True)
async def _load_state_impl(
self, request_id: int, path: str, optimizer: bool
self,
request_id: int,
path: str,
optimizer: bool,
weights_access_token: str | None = None,
) -> types.LoadWeightsResponse:
start_time = time.time()
@ -639,6 +677,7 @@ class TrainingClient(TelemetryProvider):
path=path,
seq_id=request_id + 1,
optimizer=optimizer,
weights_access_token=weights_access_token,
)
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.load(
@ -657,7 +696,9 @@ class TrainingClient(TelemetryProvider):
)
@capture_exceptions(fatal=True)
def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
def load_state(
self, path: str, weights_access_token: str | None = None
) -> APIFuture[types.LoadWeightsResponse]:
"""Load model weights from a saved checkpoint.
This loads only the model weights, not optimizer state (e.g., Adam momentum).
@ -665,6 +706,7 @@ class TrainingClient(TelemetryProvider):
Args:
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
- `weights_access_token`: Optional access token for loading checkpoints under a different account.
Returns:
- `APIFuture` containing the load response
@ -678,18 +720,27 @@ class TrainingClient(TelemetryProvider):
```
"""
request_id = self._get_request_id()
return self.holder.run_coroutine_threadsafe(self._load_state_impl(request_id, path, False))
return self.holder.run_coroutine_threadsafe(
self._load_state_impl(
request_id, path, False, weights_access_token=weights_access_token
)
)
async def load_state_async(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
async def load_state_async(
self, path: str, weights_access_token: str | None = None
) -> APIFuture[types.LoadWeightsResponse]:
"""Async version of load_state."""
return self.load_state(path)
return self.load_state(path, weights_access_token=weights_access_token)
@capture_exceptions(fatal=True)
def load_state_with_optimizer(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
def load_state_with_optimizer(
self, path: str, weights_access_token: str | None = None
) -> APIFuture[types.LoadWeightsResponse]:
"""Load model weights and optimizer state from a checkpoint.
Args:
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
- `weights_access_token`: Optional access token for loading checkpoints under a different account.
Returns:
- `APIFuture` containing the load response
@ -705,13 +756,15 @@ class TrainingClient(TelemetryProvider):
```
"""
request_id = self._get_request_id()
return self.holder.run_coroutine_threadsafe(self._load_state_impl(request_id, path, True))
return self.holder.run_coroutine_threadsafe(
self._load_state_impl(request_id, path, True, weights_access_token=weights_access_token)
)
async def load_state_with_optimizer_async(
self, path: str
self, path: str, weights_access_token: str | None = None
) -> APIFuture[types.LoadWeightsResponse]:
"""Async version of load_state_with_optimizer."""
return self.load_state_with_optimizer(path)
return self.load_state_with_optimizer(path, weights_access_token=weights_access_token)
@capture_exceptions(fatal=True)
async def _save_weights_for_sampler_impl(
@ -739,13 +792,46 @@ class TrainingClient(TelemetryProvider):
sampling_session_seq_id=sampling_session_seq_id,
ttl_seconds=ttl_seconds,
)
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.save_for_sampler(
request=request,
try:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.save_for_sampler(
request=request,
max_retries=0,
)
except ConflictError:
if name is None:
# Unnamed saves use server-generated unique paths;
# 409 should be impossible. Re-raise as a real error.
raise
# See save_state for full rationale on treating 409 as success.
logger.info(
"Sampler checkpoint '%s' already exists (409 Conflict); "
"treating as success — the checkpoint is saved.",
name,
)
if telemetry := self.holder.get_telemetry():
telemetry.log(
"training_client.save_weights_for_sampler.conflict_resolved",
event_data={
"checkpoint_name": name,
"model_id": self._guaranteed_model_id(),
},
severity="INFO",
)
return None
async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
# _send_request returns None on 409 conflict (checkpoint already
# saved), or an UntypedAPIFuture on success.
if future is None:
assert name is not None
model_id = self._guaranteed_model_id()
return types.SaveWeightsForSamplerResponseInternal(
path=f"tinker://{model_id}/sampler_weights/{name}"
)
return await _APIFuture(
types.SaveWeightsForSamplerResponseInternal,
self.holder,
@ -906,7 +992,7 @@ class TrainingClient(TelemetryProvider):
"""Save current weights and create a SamplingClient for inference.
Args:
- `name`: Optional name for the saved weights (currently ignored for ephemeral saves)
- `name`: Deprecated, has no effect. Will be removed in a future release.
- `retry_config`: Optional configuration for retrying failed requests
Returns:
@ -923,8 +1009,17 @@ class TrainingClient(TelemetryProvider):
result = sampling_client.sample(prompt, 1, params).result()
```
"""
# Ignore name argument for ephemeral save weights for sampler
_ = name
if name is not None:
warnings.warn(
"The 'name' parameter of save_weights_and_get_sampling_client() is deprecated "
"and has no effect — checkpoints are always ephemeral. "
"This parameter will be removed in a future release. "
"Remove the 'name' argument from your call. "
"If you need a persistent checkpoint, use "
"save_weights_for_sampler(name=...) + create_sampling_client(model_path=...) instead.",
DeprecationWarning,
stacklevel=2,
)
return self.save_weights_and_get_sampling_client_submit(retry_config).result()
@capture_exceptions(fatal=True)
@ -932,8 +1027,17 @@ class TrainingClient(TelemetryProvider):
self, name: str | None = None, retry_config: RetryConfig | None = None
) -> SamplingClient:
"""Async version of save_weights_and_get_sampling_client."""
# Ignore name argument for ephemeral save weights for sampler
_ = name
if name is not None:
warnings.warn(
"The 'name' parameter of save_weights_and_get_sampling_client_async() is deprecated "
"and has no effect — checkpoints are always ephemeral. "
"This parameter will be removed in a future release. "
"Remove the 'name' argument from your call. "
"If you need a persistent checkpoint, use "
"save_weights_for_sampler(name=...) + create_sampling_client(model_path=...) instead.",
DeprecationWarning,
stacklevel=2,
)
return await self.save_weights_and_get_sampling_client_submit(retry_config)
def get_telemetry(self) -> Telemetry | None:

View file

@ -73,7 +73,7 @@ class AsyncFuturesResource(AsyncAPIResource):
FutureRetrieveResponse,
await self._post(
"/api/v1/retrieve_future",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=cast(
Any, FutureRetrieveResponse

View file

@ -60,7 +60,7 @@ class AsyncModelsResource(AsyncAPIResource):
return await self._post(
"/api/v1/create_model",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)
@ -106,7 +106,7 @@ class AsyncModelsResource(AsyncAPIResource):
result = await self._post(
"/api/v1/get_info",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=GetInfoResponse,
)
@ -159,7 +159,7 @@ class AsyncModelsResource(AsyncAPIResource):
return await self._post(
"/api/v1/unload_model",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)

View file

@ -56,7 +56,7 @@ class AsyncSamplingResource(AsyncAPIResource):
return await self._post(
"/api/v1/asample",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)

View file

@ -6,6 +6,9 @@ from .._base_client import make_request_options
from .._compat import model_dump
from .._resource import AsyncAPIResource
from .._types import NOT_GIVEN, Body, Headers, NotGiven, Query
from ..types.auth_token_response import AuthTokenResponse
from ..types.client_config_request import ClientConfigRequest
from ..types.client_config_response import ClientConfigResponse
from ..types.create_sampling_session_request import CreateSamplingSessionRequest
from ..types.create_sampling_session_response import CreateSamplingSessionResponse
from ..types.create_session_request import CreateSessionRequest
@ -63,6 +66,54 @@ class AsyncServiceResource(AsyncAPIResource):
cast_to=HealthResponse,
)
async def auth_token(
self,
*,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int | NotGiven = NOT_GIVEN,
) -> AuthTokenResponse:
"""Exchange the current credential for a short-lived JWT."""
options = make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
)
if max_retries is not NOT_GIVEN:
options["max_retries"] = max_retries
return await self._post(
"/api/v1/auth/token",
body={},
options=options,
cast_to=AuthTokenResponse,
)
async def client_config(
self,
*,
request: ClientConfigRequest,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ClientConfigResponse:
"""Fetch server-side feature flags for this client."""
return await self._post(
"/api/v1/client/config",
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
),
cast_to=ClientConfigResponse,
)
async def create_session(
self,
*,
@ -104,7 +155,7 @@ class AsyncServiceResource(AsyncAPIResource):
return await self._post(
"/api/v1/create_session",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=CreateSessionResponse,
)
@ -148,7 +199,7 @@ class AsyncServiceResource(AsyncAPIResource):
request = SessionHeartbeatRequest(session_id=session_id)
return await self._post(
"/api/v1/session_heartbeat",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=SessionHeartbeatResponse,
)
@ -190,7 +241,7 @@ class AsyncServiceResource(AsyncAPIResource):
return await self._post(
"/api/v1/create_sampling_session",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=CreateSamplingSessionResponse,
)

View file

@ -67,7 +67,7 @@ class AsyncTelemetryResource(AsyncAPIResource):
return await self._post(
"/api/v1/telemetry",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=TelemetryResponse,
)

View file

@ -56,7 +56,7 @@ class AsyncTrainingResource(AsyncAPIResource):
return await self._post(
"/api/v1/forward",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)
@ -102,7 +102,7 @@ class AsyncTrainingResource(AsyncAPIResource):
return await self._post(
"/api/v1/forward_backward",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)
@ -148,7 +148,7 @@ class AsyncTrainingResource(AsyncAPIResource):
return await self._post(
"/api/v1/optim_step",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)

View file

@ -61,7 +61,7 @@ class AsyncWeightsResource(AsyncAPIResource):
return await self._post(
"/api/v1/load_weights",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)
@ -107,7 +107,7 @@ class AsyncWeightsResource(AsyncAPIResource):
return await self._post(
"/api/v1/save_weights",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)
@ -153,7 +153,7 @@ class AsyncWeightsResource(AsyncAPIResource):
return await self._post(
"/api/v1/save_weights_for_sampler",
body=model_dump(request, exclude_unset=True, mode="json"),
body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"),
options=options,
cast_to=UntypedAPIFuture,
)

View file

@ -1,5 +1,6 @@
from __future__ import annotations
from .auth_token_response import AuthTokenResponse as AuthTokenResponse
from .checkpoint import (
Checkpoint as Checkpoint,
)
@ -13,6 +14,8 @@ from .checkpoint_archive_url_response import (
CheckpointArchiveUrlResponse as CheckpointArchiveUrlResponse,
)
from .checkpoints_list_response import CheckpointsListResponse as CheckpointsListResponse
from .client_config_request import ClientConfigRequest as ClientConfigRequest
from .client_config_response import ClientConfigResponse as ClientConfigResponse
from .create_model_request import CreateModelRequest as CreateModelRequest
from .create_model_response import CreateModelResponse as CreateModelResponse
from .create_sampling_session_request import (

View file

@ -0,0 +1,9 @@
from __future__ import annotations
from .._models import BaseModel
__all__ = ["AuthTokenResponse"]
class AuthTokenResponse(BaseModel):
    """Response from POST /api/v1/auth/token.

    Carries the short-lived JWT obtained by exchanging the caller's
    current credential (see the service ``auth_token`` endpoint).
    """

    # The short-lived JSON Web Token returned by the token exchange.
    jwt: str

View file

@ -0,0 +1,10 @@
from __future__ import annotations
from .._models import StrictBase
__all__ = ["ClientConfigRequest"]
class ClientConfigRequest(StrictBase):
    """Request body for POST /api/v1/client/config.

    Carries the client's SDK version so the server can resolve the
    feature flags appropriate for this SDK release.
    """

    sdk_version: str
    """The SDK version string for flag resolution."""

View file

@ -0,0 +1,18 @@
from __future__ import annotations
from .._models import BaseModel
__all__ = ["ClientConfigResponse"]
class ClientConfigResponse(BaseModel):
    """Server-side feature flags resolved for this caller.

    Uses BaseModel (extra="ignore") so new flags from the server are
    silently dropped until the SDK adds fields for them.
    """

    # Presumably gates the short-lived-JWT auth flow served at
    # /api/v1/auth/token — TODO confirm against server behavior.
    pjwt_auth_enabled: bool = False
    # Which credential source the client should prefer by default;
    # "api_key" unless the server says otherwise.
    credential_default_source: str = "api_key"
    # NOTE(review): names suggest byte budgets for client-side semaphores
    # (10 MiB dispatch / 50 MiB in-flight responses) — verify against the
    # consuming code.
    sample_dispatch_bytes_semaphore_size: int = 10 * 1024 * 1024
    inflight_response_bytes_semaphore_size: int = 50 * 1024 * 1024

View file

@ -46,6 +46,9 @@ class Datum(StrictBase):
def _maybe_convert_array(cls, key: str, value: Any) -> Any:
"""Convert torch.Tensor, numpy array, or numeric lists to TensorData if needed."""
if _HAVE_TORCH and isinstance(value, torch.Tensor):
# Auto-sparsify 2-D target_tokens and weights to reduce wire payload
if key in _sparse_eligible_keys and value.ndim == 2:
return TensorData.from_torch_sparse(value)
return TensorData.from_torch(value)
elif isinstance(value, np.ndarray):
return TensorData.from_numpy(value)
@ -81,3 +84,5 @@ _key_to_type = {
"clip_low_threshold": "float32",
"clip_high_threshold": "float32",
}
_sparse_eligible_keys = {"target_tokens", "weights"}

View file

@ -11,6 +11,9 @@ class SupportedModel(BaseModel):
model_name: Optional[str] = None
"""The name of the supported model."""
max_context_length: Optional[int] = None
"""The maximum context length (in tokens) supported by this model."""
class GetServerCapabilitiesResponse(BaseModel):
"""Response containing the server's supported models and capabilities."""

View file

@ -20,6 +20,9 @@ class LoadWeightsRequest(StrictBase):
seq_id: Optional[int] = None
weights_access_token: Optional[str] = None
"""Optional access token for loading checkpoints under a different account."""
type: Literal["load_weights"] = "load_weights"
if PYDANTIC_V2:

View file

@ -33,6 +33,17 @@ class TensorData(StrictBase):
provided, and is generally inferred as a 1D tensor.
"""
sparse_crow_indices: Optional[List[int]] = None
"""Optional CSR compressed row pointers. When set, this tensor is sparse CSR:
- data contains only the non-zero values (flattened)
- sparse_crow_indices contains the row pointers (length = nrows + 1)
- sparse_col_indices contains the column indices (length = nnz)
- shape is required and specifies the dense shape
"""
sparse_col_indices: Optional[List[int]] = None
"""Optional CSR column indices. Must be set together with sparse_crow_indices."""
@classmethod
def from_numpy(cls, array: npt.NDArray[Any]) -> "TensorData":
return cls(
@ -49,8 +60,41 @@ class TensorData(StrictBase):
shape=list(tensor.shape),
)
@classmethod
def from_torch_sparse(cls, tensor: "torch.Tensor") -> "TensorData":
    """Create a sparse CSR TensorData from a dense 2-D torch tensor.

    Automatically detects sparsity and encodes as CSR when it saves space.
    Falls back to dense if the tensor is 1-D or mostly non-zero.
    """
    if not _HAVE_TORCH:
        raise ImportError("PyTorch is not installed.")
    if tensor.ndim != 2:
        return cls.from_torch(tensor)
    nrows, ncols = tensor.shape
    # CSR stores (nrows + 1) row pointers plus one column index and one
    # value per non-zero; only encode sparse when that beats the dense
    # element count nrows * ncols.
    nonzero_count = tensor.count_nonzero().item()
    if (nrows + 1) + 2 * nonzero_count >= nrows * ncols:
        return cls.from_torch(tensor)
    csr = tensor.to_sparse_csr()
    return cls(
        data=csr.values().tolist(),
        dtype=_convert_torch_dtype_to_tensor(tensor.dtype),
        shape=list(tensor.shape),
        sparse_crow_indices=csr.crow_indices().tolist(),
        sparse_col_indices=csr.col_indices().tolist(),
    )
def to_numpy(self) -> npt.NDArray[Any]:
"""Convert TensorData to numpy array."""
if self.sparse_crow_indices is not None:
return self.to_torch().numpy()
numpy_dtype = _convert_tensor_dtype_to_numpy(self.dtype)
arr = np.array(self.data, dtype=numpy_dtype)
if self.shape is not None:
@ -63,6 +107,17 @@ class TensorData(StrictBase):
raise ImportError("PyTorch is not installed. Cannot convert to torch tensor.")
torch_dtype = _convert_tensor_dtype_to_torch(self.dtype)
if self.sparse_crow_indices is not None:
assert self.sparse_col_indices is not None, (
"sparse_col_indices required with sparse_crow_indices"
)
assert self.shape is not None, "shape is required for sparse tensors"
crow = torch.tensor(self.sparse_crow_indices, dtype=torch.int64)
col = torch.tensor(self.sparse_col_indices, dtype=torch.int64)
values = torch.tensor(self.data, dtype=torch_dtype)
return torch.sparse_csr_tensor(crow, col, values, self.shape).to_dense()
tensor = torch.tensor(self.data, dtype=torch_dtype)
if self.shape is not None:
tensor = tensor.reshape(self.shape)

View file

@ -0,0 +1,79 @@
"""Tests for 409 ConflictError recovery in checkpoint save operations."""
from __future__ import annotations
import asyncio
from contextlib import contextmanager
from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock
import httpx
import pytest
from tinker._exceptions import ConflictError
from tinker.lib.public_interfaces.training_client import TrainingClient
def _make_conflict_error() -> ConflictError:
    """Create a ConflictError for testing."""
    synthetic_response = httpx.Response(
        409, request=httpx.Request("POST", "http://test/api/v1/save_weights")
    )
    return ConflictError("conflict", response=synthetic_response, body=None)
def _make_mock_holder() -> Mock:
    """Create a mock InternalClientHolder whose weights.save() raises ConflictError."""
    inner_client = MagicMock()
    inner_client.weights.save = AsyncMock(side_effect=_make_conflict_error())
    inner_client.weights.save_for_sampler = AsyncMock(side_effect=_make_conflict_error())

    @contextmanager
    def stub_aclient(*args: Any, **kwargs: Any):
        # Hand back the canned client regardless of arguments.
        yield inner_client

    async def passthrough_retries(fn: Any, *args: Any, **kwargs: Any) -> Any:
        # No retry logic in tests: invoke the operation exactly once.
        return await fn(*args, **kwargs)

    def schedule_on_loop(coro: Any) -> Any:
        # Schedule on the current loop instead of hopping threads.
        return asyncio.ensure_future(coro)

    holder = Mock()
    holder.aclient = stub_aclient
    holder.execute_with_retries = passthrough_retries
    holder.get_telemetry = Mock(return_value=None)
    holder.get_loop = Mock(side_effect=lambda: asyncio.get_event_loop())
    holder.run_coroutine_threadsafe = schedule_on_loop
    return holder
@pytest.mark.asyncio
async def test_save_state_returns_synthetic_path_on_conflict() -> None:
    """save_state catches 409 and returns SaveWeightsResponse with synthetic path."""
    training_client = TrainingClient(_make_mock_holder(), model_seq_id=0, model_id="model-123")
    response = await training_client.save_state("ckpt-001")
    assert response.path == "tinker://model-123/weights/ckpt-001"
@pytest.mark.asyncio
async def test_save_weights_for_sampler_returns_synthetic_path_on_conflict() -> None:
    """save_weights_for_sampler catches 409 and returns response with synthetic path."""
    mock_holder = _make_mock_holder()
    mock_holder._sampling_client_counter = 0
    training_client = TrainingClient(mock_holder, model_seq_id=0, model_id="model-789")
    response = await training_client.save_weights_for_sampler("ckpt-001")
    assert response.path == "tinker://model-789/sampler_weights/ckpt-001"
@pytest.mark.asyncio
async def test_save_weights_for_sampler_unnamed_reraises_conflict() -> None:
    """409 on unnamed sampler save (name=None) should re-raise, not swallow."""
    mock_holder = _make_mock_holder()
    mock_holder._sampling_client_counter = 0
    training_client = TrainingClient(mock_holder, model_seq_id=0, model_id="model-000")
    with pytest.raises(ConflictError):
        await training_client.save_weights_for_sampler(None)