diff --git a/.sync_state b/.sync_state index 4930328..70ffc85 100644 --- a/.sync_state +++ b/.sync_state @@ -1,4 +1,4 @@ { - "last_synced_sha": "db025e90079a19c36090a13aa88e4b2494d5a502", - "last_sync_time": "2026-03-19T02:39:30.785199" + "last_synced_sha": "d117d1692821faa297ea5d2ee7e4dc21b5c8bd0a", + "last_sync_time": "2026-04-14T00:00:48.831738" } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index c9a8cc7..bc3516e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "tinker" -version = "0.16.1" +version = "0.18.0" description = "The official Python SDK for the tinker API" readme = "README.md" license = "Apache-2.0" diff --git a/src/tinker/_base_client.py b/src/tinker/_base_client.py index 2990ea0..3480b4b 100644 --- a/src/tinker/_base_client.py +++ b/src/tinker/_base_client.py @@ -428,7 +428,9 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]): retries_taken: int = 0, ) -> httpx.Request: if log.isEnabledFor(logging.DEBUG): - log.debug("Request options: %s", model_dump(options, exclude_unset=True)) + log.debug( + "Request options: %s", model_dump(options, exclude_unset=False, exclude_none=True) + ) kwargs: dict[str, Any] = {} diff --git a/src/tinker/_client.py b/src/tinker/_client.py index 50bbff8..25d5005 100644 --- a/src/tinker/_client.py +++ b/src/tinker/_client.py @@ -12,7 +12,7 @@ from ._base_client import ( AsyncAPIClient, ) from ._compat import cached_property -from ._exceptions import APIStatusError, TinkerError +from ._exceptions import APIStatusError from ._qs import Querystring from ._streaming import AsyncStream as AsyncStream from ._types import ( @@ -26,6 +26,7 @@ from ._types import ( ) from ._utils import get_async_library, is_given from ._version import __version__ +from .lib._auth_token_provider import ApiKeyAuthProvider, AuthTokenProvider if TYPE_CHECKING: from .resources import futures, telemetry @@ -47,9 +48,6 @@ __all__ = [ class AsyncTinker(AsyncAPIClient): - # client options - api_key: str - def __init__( self, *, @@ -72,20 +70,16 @@ class AsyncTinker(AsyncAPIClient): # outlining your use-case to help us decide if it should be # part of our public interface in the future. _strict_response_validation: bool = False, + _auth: AuthTokenProvider | None = None, ) -> None: """Construct a new async AsyncTinker client instance. This automatically infers the `api_key` argument from the `TINKER_API_KEY` environment variable if it is not provided. """ - if api_key is None: - api_key = os.environ.get("TINKER_API_KEY") - if api_key is None: - raise TinkerError( - "The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable" - ) - if not api_key.startswith("tml-"): - raise TinkerError("The api_key must start with the 'tml-' prefix") - self.api_key = api_key + if _auth is not None: + self._auth = _auth + else: + self._auth = ApiKeyAuthProvider(api_key=api_key) if base_url is None: base_url = os.environ.get("TINKER_BASE_URL") @@ -158,9 +152,8 @@ class AsyncTinker(AsyncAPIClient): @property @override - def auth_headers(self) -> dict[str, str]: - api_key = self.api_key - return {"X-API-Key": api_key} + def custom_auth(self) -> AuthTokenProvider: + return self._auth @property @override @@ -174,7 +167,6 @@ class AsyncTinker(AsyncAPIClient): def copy( self, *, - api_key: str | None = None, base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.AsyncClient | None = None, @@ -212,7 +204,7 @@ class AsyncTinker(AsyncAPIClient): http_client = http_client or self._client return self.__class__( - api_key=api_key or self.api_key, + _auth=self._auth, base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, diff --git a/src/tinker/_compat.py b/src/tinker/_compat.py index 7414246..d4617e9 100644 --- a/src/tinker/_compat.py +++ b/src/tinker/_compat.py @@ -136,6 +136,7 @@ def model_dump( exclude: IncEx | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, + exclude_none: bool = False, warnings: bool = True, mode: Literal["json", "python"] = "python", ) -> dict[str, Any]: @@ -145,6 +146,7 @@ def model_dump( exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, + exclude_none=exclude_none, # warnings are not supported in Pydantic v1 warnings=warnings if PYDANTIC_V2 else True, ) @@ -154,6 +156,7 @@ def model_dump( exclude=exclude, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, + exclude_none=exclude_none, ), ) diff --git a/src/tinker/_response.py b/src/tinker/_response.py index ddfc5f3..6ac64ef 100644 --- a/src/tinker/_response.py +++ b/src/tinker/_response.py @@ -21,6 +21,7 @@ from typing import ( import anyio import httpx +import orjson import pydantic from typing_extensions import Awaitable, ParamSpec, get_origin, override @@ -351,7 +352,7 @@ class APIResponse(BaseAPIResponse[R]): def json(self) -> object: """Read and decode the JSON response content.""" self.read() - return self.http_response.json() + return orjson.loads(self.http_response.content) def close(self) -> None: """Close the response and release the connection. @@ -451,7 +452,7 @@ class AsyncAPIResponse(BaseAPIResponse[R]): async def json(self) -> object: """Read and decode the JSON response content.""" await self.read() - return self.http_response.json() + return orjson.loads(self.http_response.content) async def close(self) -> None: """Close the response and release the connection. diff --git a/src/tinker/_utils/_transform.py b/src/tinker/_utils/_transform.py index ba4c2e5..f06e90a 100644 --- a/src/tinker/_utils/_transform.py +++ b/src/tinker/_utils/_transform.py @@ -215,7 +215,7 @@ def _transform_recursive( return data if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True, mode="json") + return model_dump(data, exclude_unset=False, exclude_none=True, mode="json") annotated_type = _get_annotated_type(annotation) if annotated_type is None: @@ -382,7 +382,7 @@ async def _async_transform_recursive( return data if isinstance(data, pydantic.BaseModel): - return model_dump(data, exclude_unset=True, mode="json") + return model_dump(data, exclude_unset=False, exclude_none=True, mode="json") annotated_type = _get_annotated_type(annotation) if annotated_type is None: diff --git a/src/tinker/cli/AGENTS.md b/src/tinker/cli/AGENTS.md new file mode 100644 index 0000000..29655c6 --- /dev/null +++ b/src/tinker/cli/AGENTS.md @@ -0,0 +1,285 @@ +# Tinker CLI Design Documentation + +## Overview + +The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance. + +## Key Design Decisions + +### 1. Lazy Import Strategy with Click + +**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level. + +**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands. + +**Implementation**: +- Main `__init__.py` only imports `click` and `lazy_group` +- Command modules are loaded only when invoked via `LazyGroup` +- Output formatting imports `rich` only when table output is needed +- JSON module imported only when JSON output is requested +- Version information loaded from `_version.py` only when `tinker version` is used + +### 2. Click Framework with LazyGroup + +**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading. + +**Rationale**: +- Click provides cleaner command structure with decorators +- Better subcommand isolation - each command file is self-contained +- Automatic help generation with better formatting +- Built-in type conversion and validation +- LazyGroup enables fast startup by deferring imports + +**LazyGroup Implementation**: +```python +class LazyGroup(click.Group): + def __init__(self, *args, lazy_subcommands=None, **kwargs): + # Map of command name to "module.path:command_name" + self.lazy_subcommands = lazy_subcommands or {} + + def get_command(self, ctx, cmd_name): + if cmd_name in self.lazy_subcommands: + # Import only when command is actually invoked + import_path = self.lazy_subcommands[cmd_name] + module_name, attr_name = import_path.rsplit(":", 1) + mod = importlib.import_module(module_name) + return getattr(mod, attr_name) +``` + +### 3. Hierarchical Command Structure + +**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`. + +**Rationale**: +- Provides a consistent, predictable interface +- Groups related functionality together +- Makes the CLI extensible for future commands +- Follows common CLI patterns (like `git`, `docker`, etc.) + +**Examples**: +- `tinker version` - Show CLI and SDK version +- `tinker run list` - List all training runs +- `tinker run info ` - Show details of a specific run +- `tinker checkpoint list` - List all checkpoints +- `tinker checkpoint info ` - Show checkpoint details +- `tinker checkpoint push-hf ` - Upload a checkpoint to Hugging Face Hub + +### 4. Output System with Inheritance + +**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class. + +**Rationale**: +- Enforces consistent interface across all commands +- Encapsulates output logic with the command that generates it +- Makes it easy to support multiple output formats (table, JSON) +- Keeps related code together in the same module + +**Implementation**: +- `OutputBase` in `output.py` defines the contract +- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`) +- Base class handles format selection and rendering + +### 5. Self-Contained Command Modules + +**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point. + +**Rationale**: +- Modular architecture - commands can be developed independently +- Clear separation of concerns +- Easy to add new commands without modifying core files +- Consistent pattern across all commands + +**Command Structure**: +```python +# Each command file follows this pattern: +@click.group() # or @click.command() for simple commands +def cli(): + """Command description.""" + pass + +@cli.command() # For subcommands +def list(): + """Subcommand implementation.""" + pass +``` + +### 6. Centralized Client Management + +**Decision**: All SDK client creation and error handling is centralized in `client.py`. + +**Rationale**: +- Single place to handle authentication and connection errors +- Consistent error messages across all commands +- Reusable error handling decorator +- Clean separation of concerns + +### 7. Rich Tables for Human-Readable Output + +**Decision**: Use the `rich` library for table formatting, kept as an optional dependency. + +**Rationale**: +- Provides beautiful, formatted tables with colors and borders +- Handles column width adjustment automatically +- Supports both dark and light terminal themes +- Optional dependency keeps the core package lightweight + +### 8. Unix-Style Default Output + +**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output. + +**Rationale**: +- Follows Unix philosophy +- Tables are better for human consumption +- JSON is better for scripting and automation +- Single flag switches between formats consistently + +## Performance Optimizations + +1. **LazyGroup for deferred imports** - Commands only loaded when invoked +2. **No heavy imports at module level** - Only Click imported initially +3. **Lazy loading** of all SDK dependencies +4. **Progress indicators** that clear themselves +5. **Efficient data fetching** - fetch all data by default instead of pagination + +## Error Handling Strategy + +1. **User-friendly messages** - Technical errors are translated to helpful messages +2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes +3. **Graceful degradation** - Continue operation when possible +4. **Detailed error info** - Show details when available, traceback only in TTY + +### TinkerCliError Exception Pattern + +All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`: + +```python +from ..exceptions import TinkerCliError + +# Instead of: +print(f"Error: Something went wrong", file=sys.stderr) +sys.exit(1) + +# Use: +raise TinkerCliError( + "Something went wrong", + "Optional details or help text", + exit_code=1 # Optional, defaults to 1 +) +``` + +**Benefits:** +- Better testability (can catch exceptions in tests) +- Centralized error formatting in `__main__.py` +- Consistent exit codes across the CLI +- Stack traces preserved for debugging + +**Important Notes:** +- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification +- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages +- The main error handler in `__main__.py` handles printing to stderr and exiting + +## Future Extensibility + +The architecture supports easy addition of: + +### New Commands +- Create new module in `commands/` directory +- Define output classes in the same module if needed +- Add command to lazy_subcommands in `__init__.py` + +### New Subcommands +- Add new Click command decorator to existing command module +- Define corresponding output class if needed +- Subcommands automatically discovered by Click + +### New Output Formats +- Override `print()` method in `OutputBase` +- Or add new format handling to base class + +## Testing Guidelines + +1. **Startup time**: `time tinker --help` should be <50ms +2. **Import verification**: Check that modules aren't imported unnecessarily +3. **Output formats**: Test both table and JSON output +4. **Error cases**: Test with missing auth, invalid IDs, network errors +5. **Empty results**: Ensure graceful handling of no data + +## Module Structure + +``` +cli/ +├── __init__.py # Main entry with LazyGroup configuration +├── __main__.py # Module execution support +├── lazy_group.py # LazyGroup implementation for lazy loading +├── output.py # OutputBase class and formatting utilities +├── client.py # SDK client creation and error handling +├── commands/ +│ ├── __init__.py # Command module marker +│ ├── version.py # Version command +│ ├── run.py # Run commands and output classes +│ └── checkpoint.py # Checkpoint commands and output classes +└── CLAUDE.md # This documentation +``` + +## Command Examples + +```bash +# Show version +tinker version + +# List all training runs +tinker run list + +# Show run details +tinker run info run-abc123 + +# List all checkpoints +tinker checkpoint list + +# List checkpoints for specific run +tinker checkpoint list run-abc123 + +# Show checkpoint details +tinker checkpoint info ckpt-xyz789 + +# Upload checkpoint to Hugging Face Hub +tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter + +# JSON output +tinker --format json run list +tinker --format json checkpoint list +``` + +## Dependencies + +### Required +- Python 3.11+ +- tinker SDK (main package) +- click>=8.0.0 (CLI framework) + +### Optional +- `rich` - For table formatting (installed with `pip install tinker[cli]`) + +## Migration from Argparse to Click + +### Key Changes: +1. **Command Definition**: Decorators instead of `parser.add_argument()` +2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch +3. **Context Passing**: Click's context system for sharing format option +4. **Error Handling**: Click handles exits and error formatting +5. **Help Generation**: Automatic from docstrings and decorators + +### Benefits: +- Cleaner, more Pythonic code +- Better command organization +- Built-in testing utilities +- Easier to extend with plugins +- More consistent behavior across commands + +## Maintenance Notes + +1. **Keep imports lazy** - Use LazyGroup for all commands +2. **Test startup time** - Regularly verify fast startup is maintained +3. **Follow Click patterns** - Use decorators and context properly +4. **Document changes** - Update this file when making architectural changes +5. **Maintain consistency** - All commands should follow the same structure diff --git a/src/tinker/cli/CLAUDE.md b/src/tinker/cli/CLAUDE.md index 29655c6..43c994c 100644 --- a/src/tinker/cli/CLAUDE.md +++ b/src/tinker/cli/CLAUDE.md @@ -1,285 +1 @@ -# Tinker CLI Design Documentation - -## Overview - -The Tinker CLI is a command-line interface for the Tinker SDK, designed with a focus on fast startup times, modular architecture, and user-friendly output formats. The CLI uses Click framework with custom lazy loading to maintain performance. - -## Key Design Decisions - -### 1. Lazy Import Strategy with Click - -**Decision**: Use Click framework with a custom `LazyGroup` class for lazy loading. Only Click is imported at the module level. - -**Rationale**: This ensures that `tinker --help` is lightning fast (<50ms startup time). Users shouldn't have to wait for heavy imports when they just want to see available commands. - -**Implementation**: -- Main `__init__.py` only imports `click` and `lazy_group` -- Command modules are loaded only when invoked via `LazyGroup` -- Output formatting imports `rich` only when table output is needed -- JSON module imported only when JSON output is requested -- Version information loaded from `_version.py` only when `tinker version` is used - -### 2. Click Framework with LazyGroup - -**Decision**: Migrated from argparse to Click, implementing a custom `LazyGroup` class that extends Click's Group to support lazy loading. - -**Rationale**: -- Click provides cleaner command structure with decorators -- Better subcommand isolation - each command file is self-contained -- Automatic help generation with better formatting -- Built-in type conversion and validation -- LazyGroup enables fast startup by deferring imports - -**LazyGroup Implementation**: -```python -class LazyGroup(click.Group): - def __init__(self, *args, lazy_subcommands=None, **kwargs): - # Map of command name to "module.path:command_name" - self.lazy_subcommands = lazy_subcommands or {} - - def get_command(self, ctx, cmd_name): - if cmd_name in self.lazy_subcommands: - # Import only when command is actually invoked - import_path = self.lazy_subcommands[cmd_name] - module_name, attr_name = import_path.rsplit(":", 1) - mod = importlib.import_module(module_name) - return getattr(mod, attr_name) -``` - -### 3. Hierarchical Command Structure - -**Decision**: Commands are organized hierarchically with main commands and subcommands (e.g., `tinker run list`, `tinker checkpoint info`), plus standalone commands like `tinker version`. - -**Rationale**: -- Provides a consistent, predictable interface -- Groups related functionality together -- Makes the CLI extensible for future commands -- Follows common CLI patterns (like `git`, `docker`, etc.) - -**Examples**: -- `tinker version` - Show CLI and SDK version -- `tinker run list` - List all training runs -- `tinker run info ` - Show details of a specific run -- `tinker checkpoint list` - List all checkpoints -- `tinker checkpoint info ` - Show checkpoint details -- `tinker checkpoint push-hf ` - Upload a checkpoint to Hugging Face Hub - -### 4. Output System with Inheritance - -**Decision**: Use an abstract base class (`OutputBase`) that all command outputs inherit from. Each command defines its own output class. - -**Rationale**: -- Enforces consistent interface across all commands -- Encapsulates output logic with the command that generates it -- Makes it easy to support multiple output formats (table, JSON) -- Keeps related code together in the same module - -**Implementation**: -- `OutputBase` in `output.py` defines the contract -- Each command module contains its own output classes (e.g., `RunListOutput`, `RunInfoOutput`) -- Base class handles format selection and rendering - -### 5. Self-Contained Command Modules - -**Decision**: Each command is a self-contained Click command/group in its own file with a `cli` entry point. - -**Rationale**: -- Modular architecture - commands can be developed independently -- Clear separation of concerns -- Easy to add new commands without modifying core files -- Consistent pattern across all commands - -**Command Structure**: -```python -# Each command file follows this pattern: -@click.group() # or @click.command() for simple commands -def cli(): - """Command description.""" - pass - -@cli.command() # For subcommands -def list(): - """Subcommand implementation.""" - pass -``` - -### 6. Centralized Client Management - -**Decision**: All SDK client creation and error handling is centralized in `client.py`. - -**Rationale**: -- Single place to handle authentication and connection errors -- Consistent error messages across all commands -- Reusable error handling decorator -- Clean separation of concerns - -### 7. Rich Tables for Human-Readable Output - -**Decision**: Use the `rich` library for table formatting, kept as an optional dependency. - -**Rationale**: -- Provides beautiful, formatted tables with colors and borders -- Handles column width adjustment automatically -- Supports both dark and light terminal themes -- Optional dependency keeps the core package lightweight - -### 8. Unix-Style Default Output - -**Decision**: Default output is human-readable tables, with `--format json` flag for machine-readable output. - -**Rationale**: -- Follows Unix philosophy -- Tables are better for human consumption -- JSON is better for scripting and automation -- Single flag switches between formats consistently - -## Performance Optimizations - -1. **LazyGroup for deferred imports** - Commands only loaded when invoked -2. **No heavy imports at module level** - Only Click imported initially -3. **Lazy loading** of all SDK dependencies -4. **Progress indicators** that clear themselves -5. **Efficient data fetching** - fetch all data by default instead of pagination - -## Error Handling Strategy - -1. **User-friendly messages** - Technical errors are translated to helpful messages -2. **Proper exit codes** - Uses TinkerCliError for consistent exit codes -3. **Graceful degradation** - Continue operation when possible -4. **Detailed error info** - Show details when available, traceback only in TTY - -### TinkerCliError Exception Pattern - -All CLI errors should raise `TinkerCliError` instead of calling `sys.exit()`: - -```python -from ..exceptions import TinkerCliError - -# Instead of: -print(f"Error: Something went wrong", file=sys.stderr) -sys.exit(1) - -# Use: -raise TinkerCliError( - "Something went wrong", - "Optional details or help text", - exit_code=1 # Optional, defaults to 1 -) -``` - -**Benefits:** -- Better testability (can catch exceptions in tests) -- Centralized error formatting in `__main__.py` -- Consistent exit codes across the CLI -- Stack traces preserved for debugging - -**Important Notes:** -- The `handle_api_errors` decorator automatically re-raises `TinkerCliError` without modification -- Always catch and convert specific exceptions to `TinkerCliError` with helpful messages -- The main error handler in `__main__.py` handles printing to stderr and exiting - -## Future Extensibility - -The architecture supports easy addition of: - -### New Commands -- Create new module in `commands/` directory -- Define output classes in the same module if needed -- Add command to lazy_subcommands in `__init__.py` - -### New Subcommands -- Add new Click command decorator to existing command module -- Define corresponding output class if needed -- Subcommands automatically discovered by Click - -### New Output Formats -- Override `print()` method in `OutputBase` -- Or add new format handling to base class - -## Testing Guidelines - -1. **Startup time**: `time tinker --help` should be <50ms -2. **Import verification**: Check that modules aren't imported unnecessarily -3. **Output formats**: Test both table and JSON output -4. **Error cases**: Test with missing auth, invalid IDs, network errors -5. **Empty results**: Ensure graceful handling of no data - -## Module Structure - -``` -cli/ -├── __init__.py # Main entry with LazyGroup configuration -├── __main__.py # Module execution support -├── lazy_group.py # LazyGroup implementation for lazy loading -├── output.py # OutputBase class and formatting utilities -├── client.py # SDK client creation and error handling -├── commands/ -│ ├── __init__.py # Command module marker -│ ├── version.py # Version command -│ ├── run.py # Run commands and output classes -│ └── checkpoint.py # Checkpoint commands and output classes -└── CLAUDE.md # This documentation -``` - -## Command Examples - -```bash -# Show version -tinker version - -# List all training runs -tinker run list - -# Show run details -tinker run info run-abc123 - -# List all checkpoints -tinker checkpoint list - -# List checkpoints for specific run -tinker checkpoint list run-abc123 - -# Show checkpoint details -tinker checkpoint info ckpt-xyz789 - -# Upload checkpoint to Hugging Face Hub -tinker checkpoint push-hf tinker://run-abc123/sampler_weights/000040 --repo username/my-lora-adapter - -# JSON output -tinker --format json run list -tinker --format json checkpoint list -``` - -## Dependencies - -### Required -- Python 3.11+ -- tinker SDK (main package) -- click>=8.0.0 (CLI framework) - -### Optional -- `rich` - For table formatting (installed with `pip install tinker[cli]`) - -## Migration from Argparse to Click - -### Key Changes: -1. **Command Definition**: Decorators instead of `parser.add_argument()` -2. **Lazy Loading**: Custom `LazyGroup` instead of manual dispatch -3. **Context Passing**: Click's context system for sharing format option -4. **Error Handling**: Click handles exits and error formatting -5. **Help Generation**: Automatic from docstrings and decorators - -### Benefits: -- Cleaner, more Pythonic code -- Better command organization -- Built-in testing utilities -- Easier to extend with plugins -- More consistent behavior across commands - -## Maintenance Notes - -1. **Keep imports lazy** - Use LazyGroup for all commands -2. **Test startup time** - Regularly verify fast startup is maintained -3. **Follow Click patterns** - Use decorators and context properly -4. **Document changes** - Update this file when making architectural changes -5. **Maintain consistency** - All commands should follow the same structure +@AGENTS.md diff --git a/src/tinker/lib/_auth_token_provider.py b/src/tinker/lib/_auth_token_provider.py new file mode 100644 index 0000000..ef1a3a4 --- /dev/null +++ b/src/tinker/lib/_auth_token_provider.py @@ -0,0 +1,104 @@ +"""Authentication credential management for the Tinker SDK. + +Provides composable credential providers that plug into httpx's async auth flow: +- AuthTokenProvider: abstract base (httpx.Auth) — subclasses implement get_token() +- ApiKeyAuthProvider: resolves from api_key arg or TINKER_API_KEY env var +- CredentialCmdAuthProvider: runs a command on every call for fresh credentials +- resolve_auth_provider(): factory that picks the right provider +""" + +from __future__ import annotations + +import abc +import asyncio +import os +from collections.abc import AsyncGenerator + +import httpx + +from tinker._exceptions import TinkerError + + +class AuthTokenProvider(httpx.Auth): + """Abstract base auth provider. Subclasses implement get_token().""" + + @abc.abstractmethod + async def get_token(self) -> str | None: ... + + async def async_auth_flow( + self, request: httpx.Request + ) -> AsyncGenerator[httpx.Request, httpx.Response]: + token = await self.get_token() + if token: + request.headers["X-API-Key"] = token + yield request + + +class ApiKeyAuthProvider(AuthTokenProvider): + """Resolves api_key from constructor arg or TINKER_API_KEY env var.""" + + def __init__(self, api_key: str | None = None) -> None: + resolved = api_key or os.environ.get("TINKER_API_KEY") + if not resolved: + raise TinkerError( + "The api_key client option must be set either by passing api_key to the client" + " or by setting the TINKER_API_KEY environment variable" + ) + if not resolved.startswith("tml-") and not resolved.startswith("eyJ"): + raise TinkerError("The api_key must start with the 'tml-' prefix") + self._token = resolved + + async def get_token(self) -> str | None: + return self._token + + +class CredentialCmdAuthProvider(AuthTokenProvider): + """Runs TINKER_CREDENTIAL_CMD on every get_token() call. + + Always produces a fresh credential (e.g. short-lived bearer tokens). + Uses async subprocess to avoid blocking the event loop. + """ + + def __init__(self, cmd: str) -> None: + if not cmd: + raise TinkerError( + "Your organization requires dynamic credentials — set TINKER_CREDENTIAL_CMD" + " to a command that prints a valid credential." + ) + self._cmd = cmd + + async def get_token(self) -> str | None: + proc = await asyncio.create_subprocess_shell( + self._cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + stdout, _ = await proc.communicate() + credential = stdout.decode().strip() + if not credential: + raise TinkerError("TINKER_CREDENTIAL_CMD returned an empty credential.") + return credential + + +def resolve_auth_provider(api_key: str | None, enforce_cmd: bool) -> AuthTokenProvider: + """Construct the appropriate auth provider based on available credentials. + + - enforce_cmd=True: uses TINKER_CREDENTIAL_CMD, unless the api_key is + already a JWT (dynamic credential) — in which case it's used directly. + - enforce_cmd=False: tries api_key first, falls back to TINKER_CREDENTIAL_CMD + """ + credential_cmd = os.environ.get("TINKER_CREDENTIAL_CMD", "") + + # A JWT passed as api_key is already a dynamic credential — use it + # directly even when credential_cmd is enforced. + resolved = api_key or os.environ.get("TINKER_API_KEY", "") + if resolved and resolved.startswith("eyJ"): + return ApiKeyAuthProvider(api_key=resolved) + + if enforce_cmd: + return CredentialCmdAuthProvider(credential_cmd) + + try: + return ApiKeyAuthProvider(api_key=api_key) + except TinkerError: + if credential_cmd: + return CredentialCmdAuthProvider(credential_cmd) + raise diff --git a/src/tinker/lib/_jwt_auth.py b/src/tinker/lib/_jwt_auth.py new file mode 100644 index 0000000..0bf8c3b --- /dev/null +++ b/src/tinker/lib/_jwt_auth.py @@ -0,0 +1,90 @@ +"""JWT authentication for Tinker SDK. + +Internal to the SDK; not part of the public API. + +When the server sets pjwt_auth_enabled, the SDK exchanges the caller's +credential for a short-lived JWT minted by the Tinker server. The JWT is +cached and refreshed in the background before it expires, so callers always +send a valid token without any per-request overhead. +""" + +from __future__ import annotations + +import asyncio +import base64 +import json +import logging +import time +from collections.abc import Callable +from contextlib import AbstractContextManager + +from tinker.lib._auth_token_provider import AuthTokenProvider + +logger = logging.getLogger(__name__) + +_REFRESH_BEFORE_EXPIRY_SECS = 300 # refresh 5 min before expiry +_RETRY_DELAY_SECS = 60 + + +def _jwt_expiry(jwt: str) -> float: + """Return the exp claim of a JWT as a Unix timestamp.""" + try: + payload = jwt.split(".")[1] + payload += "=" * (-len(payload) % 4) + return float(json.loads(base64.urlsafe_b64decode(payload))["exp"]) + except Exception as e: + raise ValueError(f"Failed to parse JWT expiry: {e}") from e + + +class JwtAuthProvider(AuthTokenProvider): + """AuthTokenProvider that exchanges a credential for a short-lived JWT. + + After init(), get_token() returns the current JWT. A background task + refreshes the JWT before it expires. + """ + + def __init__( + self, + aclient_fn: Callable[[], AbstractContextManager], + seed_token: str | None = None, + ) -> None: + self._token = seed_token or "" + self._aclient_fn = aclient_fn + + async def get_token(self) -> str | None: + return self._token + + async def init(self) -> None: + """Fetch a JWT (unless seeded) then start the background refresh loop. + + When seed_token was provided, skips the initial fetch and starts + refreshing from the seed — useful for shadow holders that already + have a valid JWT from the primary holder. + """ + token = self._token if self._token else await self._fetch() + self._refresh_task = asyncio.create_task(self._refresh_loop(token)) + + async def _fetch(self) -> str: + """Exchange the current credential for a JWT via /api/v1/auth/token.""" + with self._aclient_fn() as client: + response = await client.service.auth_token() + self._token = response.jwt + return response.jwt + + async def _refresh_loop(self, token: str) -> None: + while True: + try: + delay = max( + _RETRY_DELAY_SECS, + _jwt_expiry(token) - time.time() - _REFRESH_BEFORE_EXPIRY_SECS, + ) + except ValueError: + logger.debug("Failed to parse JWT expiry, retrying in %ds", _RETRY_DELAY_SECS) + delay = _RETRY_DELAY_SECS + try: + await asyncio.sleep(delay) + token = await self._fetch() + except asyncio.CancelledError: + return + except Exception as e: + logger.debug("JWT refresh failed, retrying in %ds: %s", _RETRY_DELAY_SECS, e) diff --git a/src/tinker/lib/_jwt_auth_test.py b/src/tinker/lib/_jwt_auth_test.py new file mode 100644 index 0000000..e26f155 --- /dev/null +++ b/src/tinker/lib/_jwt_auth_test.py @@ -0,0 +1,156 @@ +"""Tests for JWT authentication helpers.""" + +from __future__ import annotations + +import base64 +import json +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from tinker._exceptions import TinkerError +from tinker.lib._auth_token_provider import ( + ApiKeyAuthProvider, + CredentialCmdAuthProvider, + resolve_auth_provider, +) +from tinker.lib._jwt_auth import ( + JwtAuthProvider, + _jwt_expiry, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_jwt(exp: float) -> str: + """Build a minimal fake JWT with a given exp claim.""" + header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode() + payload_bytes = json.dumps({"exp": exp, "sub": "test"}).encode() + payload = base64.urlsafe_b64encode(payload_bytes).rstrip(b"=").decode() + return f"{header}.{payload}.fakesig" + + +class _MockAuthResponse: + def __init__(self, jwt: str) -> None: + self.jwt = jwt + + +class _MockHolder: + """Minimal mock providing aclient() for testing JwtAuthProvider.""" + + def __init__(self, response_jwt: str, *, fail: bool = False) -> None: + service = MagicMock() + if fail: + service.auth_token = AsyncMock(side_effect=Exception("network error")) + else: + service.auth_token = AsyncMock(return_value=_MockAuthResponse(response_jwt)) + client = MagicMock() + client.service = service + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=client) + cm.__exit__ = MagicMock(return_value=None) + self._cm = cm + + def aclient(self): + return self._cm + + +# --------------------------------------------------------------------------- +# _jwt_expiry +# --------------------------------------------------------------------------- + + +def test_jwt_expiry_parses_valid(): + exp = time.time() + 3600 + assert abs(_jwt_expiry(_make_jwt(exp)) - exp) < 1 + + +def test_jwt_expiry_raises_on_invalid(): + with pytest.raises(Exception): + _jwt_expiry("not.a.jwt") + + +def test_jwt_expiry_raises_on_missing_exp(): + header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').rstrip(b"=").decode() + payload = base64.urlsafe_b64encode(b'{"sub":"x"}').rstrip(b"=").decode() + with pytest.raises(Exception): + _jwt_expiry(f"{header}.{payload}.sig") + + +# --------------------------------------------------------------------------- +# AuthTokenProvider hierarchy +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_api_key_provider_resolves_key(): + auth = ApiKeyAuthProvider(api_key="tml-test-key") + assert await auth.get_token() == "tml-test-key" + + +@pytest.mark.asyncio +async def test_credential_cmd_provider_runs_command(): + auth = CredentialCmdAuthProvider("echo test-credential") + assert await auth.get_token() == "test-credential" + + +@pytest.mark.asyncio +async def test_resolve_auth_provider_fallback_to_cmd(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("TINKER_API_KEY", raising=False) + monkeypatch.setenv("TINKER_CREDENTIAL_CMD", "echo fallback-cred") + auth = resolve_auth_provider(api_key=None, enforce_cmd=False) + assert isinstance(auth, CredentialCmdAuthProvider) + assert await auth.get_token() == "fallback-cred" + + +def test_credential_cmd_provider_raises_with_empty_cmd(): + with pytest.raises(TinkerError, match="dynamic credentials"): + CredentialCmdAuthProvider("") + + +# --------------------------------------------------------------------------- +# JwtAuthProvider.init +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_init_fetches_jwt_and_stores_it(): + exp = time.time() + 7200 + jwt = _make_jwt(exp) + holder = _MockHolder(jwt) + provider = JwtAuthProvider(holder.aclient) + + await provider.init() + + assert await provider.get_token() == jwt + holder._cm.__enter__.return_value.service.auth_token.assert_called_once() + + +@pytest.mark.asyncio +async def test_init_raises_on_fetch_failure(): + holder = _MockHolder("some-jwt", fail=True) + provider = JwtAuthProvider(holder.aclient) + + with pytest.raises(Exception, match="network error"): + await provider.init() + + +# --------------------------------------------------------------------------- +# JwtAuthProvider._fetch +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_fetch_returns_and_stores_token(): + exp = time.time() + 7200 + jwt = _make_jwt(exp) + holder = _MockHolder(jwt) + provider = JwtAuthProvider(holder.aclient) + + result = await provider._fetch() + + assert result == jwt + assert await provider.get_token() == jwt diff --git a/src/tinker/lib/internal_client_holder.py b/src/tinker/lib/internal_client_holder.py index 80407c3..94ef5cd 100644 --- a/src/tinker/lib/internal_client_holder.py +++ b/src/tinker/lib/internal_client_holder.py @@ -21,6 +21,12 @@ from tinker import types from tinker._client import AsyncTinker from tinker._exceptions import APIConnectionError, APIStatusError from tinker._version import __version__ as tinker_sdk_version +from tinker.lib._auth_token_provider import ( + ApiKeyAuthProvider, + AuthTokenProvider, + resolve_auth_provider, +) +from tinker.lib._jwt_auth import JwtAuthProvider from tinker.lib.async_tinker_provider import AsyncTinkerProvider from tinker.lib.client_connection_pool_type import ClientConnectionPoolType from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture @@ -180,17 +186,63 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): project_id: str | None = None, *, session_id: str | None = None, + api_key: str | None = None, + _client_config: dict[str, str | int | bool] | None = None, + _jwt_auth_seed: str | None = None, **kwargs: Any, ) -> None: - self._constructor_kwargs = kwargs + self._api_key = api_key + self._constructor_kwargs = dict(kwargs) self._loop: asyncio.AbstractEventLoop = _internal_client_holder_thread_singleton.get_loop() self._client_pools: dict[ClientConnectionPoolType, ClientConnectionPool] = {} self._sample_backoff_until: float | None = None self._sample_dispatch_semaphore: asyncio.Semaphore = asyncio.Semaphore(400) self._sample_dispatch_throttled_semaphore: asyncio.Semaphore = asyncio.Semaphore(10) - self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024) - self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore(5 * 1024 * 1024) self._training_client_lock: threading.Lock = threading.Lock() + self._telemetry: Telemetry | None = None + + # Fetch server-side client config before any server contact so that + # flags are available for subsequent setup steps. Shadow holders + # receive the config via kwargs to avoid a redundant fetch (and + # potential deadlock on the event loop thread). + if _client_config is not None: + self._client_config = types.ClientConfigResponse.model_validate(_client_config) + else: + self._assert_not_on_event_loop("fetch client config") + config_auth = resolve_auth_provider(api_key, enforce_cmd=False) + self._client_config = self.run_coroutine_threadsafe( + self._fetch_client_config(config_auth) + ).result() + + self._sample_dispatch_bytes_semaphore: BytesSemaphore = BytesSemaphore( + self._client_config.sample_dispatch_bytes_semaphore_size + ) + self._inflight_response_bytes_semaphore: BytesSemaphore = BytesSemaphore( + self._client_config.inflight_response_bytes_semaphore_size + ) + + if not self._client_config.pjwt_auth_enabled: + # Without JWT exchange, only API keys are accepted by the server. + # Replace any cmd-based provider with a plain API key provider. + self._default_auth = ApiKeyAuthProvider(api_key=api_key) + else: + # Create a dedicated pool for JWT exchange with the appropriate + # credential provider. The lambda captures the pool so it stays alive. + use_cmd = self._client_config.credential_default_source == "credential_cmd" + auth_pool_auth = resolve_auth_provider(self._api_key, use_cmd) + auth_kwargs = {**self._constructor_kwargs, "_auth": auth_pool_auth} + auth_pool = ClientConnectionPool(self.get_loop(), 1, auth_kwargs) + auth_aclient = lambda: auth_pool.aclient() # noqa: E731 + self._default_auth = JwtAuthProvider(auth_aclient, seed_token=_jwt_auth_seed) + if _jwt_auth_seed: + # Shadow holder: start refresh in background, don't block. + self.run_coroutine_threadsafe(self._default_auth.init()) + else: + # Primary holder: must have a valid JWT before proceeding. + self._assert_not_on_event_loop("exchange JWT") + self.run_coroutine_threadsafe( + self.execute_with_retries(self._default_auth.init) + ).result() if session_id is not None: # Shadow mode: reuse existing session, can't create new clients @@ -199,14 +251,7 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): self._sampling_client_counter: int | None = None else: # Normal mode: create new session. - # This blocks on .result() — must NOT be called from the event - # loop thread (e.g. inside the sidecar subprocess). Shadow - # holders (session_id is not None) skip this path. - if self._loop.is_running() and _current_loop() is self._loop: - raise RuntimeError( - "Cannot create a new session from the event loop thread. " - "Use session_id= to create a shadow holder instead." - ) + self._assert_not_on_event_loop("create a new session") self._session_id = self.run_coroutine_threadsafe( self._create_session(user_metadata=user_metadata, project_id=project_id) ).result() @@ -230,6 +275,26 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): """Get or create a shadow holder from the singleton cache.""" return _shadow_holder_singleton.get_or_create(session_id, kwargs) + def _assert_not_on_event_loop(self, action: str) -> None: + """Raise if called from the event loop thread (would deadlock on .result()).""" + if self._loop.is_running() and _current_loop() is self._loop: + raise RuntimeError( + f"Cannot {action} from the event loop thread. " + "Use session_id= to create a shadow holder instead." + ) + + @property + def shadow_kwargs(self) -> dict[str, Any]: + """Constructor kwargs for shadow holders, including cached server config and JWT seed.""" + result = { + **self._constructor_kwargs, + "api_key": self._api_key, + "_client_config": self._client_config.model_dump(), + } + if isinstance(self._default_auth, JwtAuthProvider): + result["_jwt_auth_seed"] = self._default_auth._token + return result + @asynccontextmanager async def _sample_dispatch_count_rate_limit(self): async with self._sample_dispatch_semaphore: @@ -316,6 +381,23 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): """Start the session heartbeat task.""" return asyncio.create_task(self._session_heartbeat(self._session_id)) + async def _fetch_client_config(self, auth: AuthTokenProvider) -> types.ClientConfigResponse: + """Call /api/v1/client/config and return server feature flags. + + Creates a one-off connection pool with the given auth. Retries + transient failures via execute_with_retries. + """ + kwargs = {**self._constructor_kwargs, "_auth": auth} + pool = ClientConnectionPool(self.get_loop(), 1, kwargs) + + async def _once() -> types.ClientConfigResponse: + with pool.aclient() as client: + return await client.service.client_config( + request=types.ClientConfigRequest(sdk_version=tinker_sdk_version) + ) + + return await self.execute_with_retries(_once) + async def _create_session( self, user_metadata: dict[str, str] | None = None, @@ -350,8 +432,9 @@ class InternalClientHolder(AsyncTinkerProvider, TelemetryProvider): if client_pool_type == ClientConnectionPoolType.TRAIN else MAX_REQUESTS_PER_HTTPX_CLIENT ) + kwargs = {**self._constructor_kwargs, "_auth": self._default_auth} self._client_pools[client_pool_type] = ClientConnectionPool( - self.get_loop(), max_requests_per_client, self._constructor_kwargs + self.get_loop(), max_requests_per_client, kwargs ) return self._client_pools[client_pool_type] diff --git a/src/tinker/lib/internal_client_holder_test.py b/src/tinker/lib/internal_client_holder_test.py new file mode 100644 index 0000000..b6bdfc9 --- /dev/null +++ b/src/tinker/lib/internal_client_holder_test.py @@ -0,0 +1,96 @@ +"""Tests for InternalClientHolder helpers.""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from tinker.lib._auth_token_provider import AuthTokenProvider +from tinker.lib.internal_client_holder import ClientConnectionPool, InternalClientHolder +from tinker.types.client_config_response import ClientConfigResponse as _ClientConfigResponse + + +class _MockHolder: + """Minimal stand-in for testing _fetch_client_config.""" + + def __init__(self, response: _ClientConfigResponse | Exception) -> None: + service = MagicMock() + if isinstance(response, Exception): + service.client_config = AsyncMock(side_effect=response) + else: + service.client_config = AsyncMock(return_value=response) + client = MagicMock() + client.service = service + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=client) + cm.__exit__ = MagicMock(return_value=None) + self._cm = cm + + self._constructor_kwargs: dict[str, Any] = {} + self._default_auth = MagicMock(spec=AuthTokenProvider) + self._loop = asyncio.get_event_loop() + + def get_loop(self) -> asyncio.AbstractEventLoop: + return self._loop + + async def execute_with_retries(self, func: Any, *args: Any, **kwargs: Any) -> Any: + return await func(*args, **kwargs) + + # Bind the real method so the pool it creates uses our mock client + _fetch_client_config = InternalClientHolder._fetch_client_config + + +def _patch_pool(monkeypatch: pytest.MonkeyPatch, holder: _MockHolder) -> None: + monkeypatch.setattr(ClientConnectionPool, "aclient", lambda self: holder._cm) + + +# --------------------------------------------------------------------------- +# _fetch_client_config +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_fetch_client_config_returns_flags_from_server( + monkeypatch: pytest.MonkeyPatch, +) -> None: + holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=True)) + _patch_pool(monkeypatch, holder) + result = await InternalClientHolder._fetch_client_config(holder, holder._default_auth) # type: ignore[arg-type] + assert result.pjwt_auth_enabled is True + + +@pytest.mark.asyncio +async def test_fetch_client_config_returns_defaults_when_server_disables( + monkeypatch: pytest.MonkeyPatch, +) -> None: + holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=False)) + _patch_pool(monkeypatch, holder) + result = await InternalClientHolder._fetch_client_config(holder, holder._default_auth) # type: ignore[arg-type] + assert result.pjwt_auth_enabled is False + + +@pytest.mark.asyncio +async def test_fetch_client_config_raises_on_network_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + holder = _MockHolder(Exception("connection refused")) + _patch_pool(monkeypatch, holder) + with pytest.raises(Exception, match="connection refused"): + await InternalClientHolder._fetch_client_config(holder, holder._default_auth) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_fetch_client_config_passes_sdk_version( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from tinker._version import __version__ as tinker_sdk_version + + holder = _MockHolder(_ClientConfigResponse(pjwt_auth_enabled=False)) + _patch_pool(monkeypatch, holder) + await InternalClientHolder._fetch_client_config(holder, holder._default_auth) # type: ignore[arg-type] + + call_kwargs = holder._cm.__enter__.return_value.service.client_config.call_args + assert call_kwargs.kwargs["request"].sdk_version == tinker_sdk_version diff --git a/src/tinker/lib/public_interfaces/sampling_client.py b/src/tinker/lib/public_interfaces/sampling_client.py index 0f3355c..8a2ff01 100644 --- a/src/tinker/lib/public_interfaces/sampling_client.py +++ b/src/tinker/lib/public_interfaces/sampling_client.py @@ -451,7 +451,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver): _SamplingClientPickleState( session_id=self.holder.get_session_id(), sampling_session_id=self._sampling_session_id, - constructor_kwargs=self.holder._constructor_kwargs, + constructor_kwargs=self.holder.shadow_kwargs, subprocess_sampling=self._sampling_client_sidecar_handle is not None, ), ), diff --git a/src/tinker/lib/public_interfaces/service_client.py b/src/tinker/lib/public_interfaces/service_client.py index bc66215..b0b3c3e 100644 --- a/src/tinker/lib/public_interfaces/service_client.py +++ b/src/tinker/lib/public_interfaces/service_client.py @@ -230,10 +230,26 @@ class ServiceClient(TelemetryProvider): user_metadata, ).result_async() + def _get_rest_client_for_weights(self, weights_access_token: str | None = None) -> RestClient: + """Get a rest client for weights info lookups. + + If weights_access_token is provided, creates a separate ServiceClient + authenticated with that token. + """ + if weights_access_token is not None: + token_client = ServiceClient( + api_key=weights_access_token, **self.holder._constructor_kwargs + ) + return token_client.create_rest_client() + return self.create_rest_client() + @sync_only @capture_exceptions(fatal=True) def create_training_client_from_state( - self, path: str, user_metadata: dict[str, str] | None = None + self, + path: str, + user_metadata: dict[str, str] | None = None, + weights_access_token: str | None = None, ) -> TrainingClient: """Create a TrainingClient from saved model weights. @@ -243,6 +259,7 @@ class ServiceClient(TelemetryProvider): Args: - `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") - `user_metadata`: Optional metadata to attach to the new training run + - `weights_access_token`: Optional access token for loading checkpoints under a different account. Returns: - `TrainingClient` loaded with the specified weights @@ -256,7 +273,7 @@ class ServiceClient(TelemetryProvider): # Continue training from the loaded state ``` """ - rest_client = self.create_rest_client() + rest_client = self._get_rest_client_for_weights(weights_access_token) # Use weights info endpoint which allows access to models with public checkpoints weights_info = rest_client.get_weights_info_by_tinker_path(path).result() @@ -271,15 +288,18 @@ class ServiceClient(TelemetryProvider): user_metadata=user_metadata, ) - training_client.load_state(path).result() + training_client.load_state(path, weights_access_token=weights_access_token).result() return training_client @capture_exceptions(fatal=True) async def create_training_client_from_state_async( - self, path: str, user_metadata: dict[str, str] | None = None + self, + path: str, + user_metadata: dict[str, str] | None = None, + weights_access_token: str | None = None, ) -> TrainingClient: """Async version of create_training_client_from_state.""" - rest_client = self.create_rest_client() + rest_client = self._get_rest_client_for_weights(weights_access_token) # Use weights info endpoint which allows access to models with public checkpoints weights_info = await rest_client.get_weights_info_by_tinker_path(path) @@ -297,14 +317,19 @@ class ServiceClient(TelemetryProvider): user_metadata=user_metadata, ) - load_future = await training_client.load_state_async(path) + load_future = await training_client.load_state_async( + path, weights_access_token=weights_access_token + ) await load_future.result_async() return training_client @sync_only @capture_exceptions(fatal=True) def create_training_client_from_state_with_optimizer( - self, path: str, user_metadata: dict[str, str] | None = None + self, + path: str, + user_metadata: dict[str, str] | None = None, + weights_access_token: str | None = None, ) -> TrainingClient: """Create a TrainingClient from saved model weights and optimizer state. @@ -315,6 +340,7 @@ class ServiceClient(TelemetryProvider): Args: - `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") - `user_metadata`: Optional metadata to attach to the new training run + - `weights_access_token`: Optional access token for loading checkpoints under a different account. Returns: - `TrainingClient` loaded with the specified weights and optimizer state @@ -328,7 +354,7 @@ class ServiceClient(TelemetryProvider): # Continue training with restored optimizer momentum ``` """ - rest_client = self.create_rest_client() + rest_client = self._get_rest_client_for_weights(weights_access_token) # Use weights info endpoint which allows access to models with public checkpoints weights_info = rest_client.get_weights_info_by_tinker_path(path).result() @@ -343,15 +369,20 @@ class ServiceClient(TelemetryProvider): user_metadata=user_metadata, ) - training_client.load_state_with_optimizer(path).result() + training_client.load_state_with_optimizer( + path, weights_access_token=weights_access_token + ).result() return training_client @capture_exceptions(fatal=True) async def create_training_client_from_state_with_optimizer_async( - self, path: str, user_metadata: dict[str, str] | None = None + self, + path: str, + user_metadata: dict[str, str] | None = None, + weights_access_token: str | None = None, ) -> TrainingClient: """Async version of create_training_client_from_state_with_optimizer.""" - rest_client = self.create_rest_client() + rest_client = self._get_rest_client_for_weights(weights_access_token) # Use weights info endpoint which allows access to models with public checkpoints weights_info = await rest_client.get_weights_info_by_tinker_path(path) @@ -369,7 +400,9 @@ class ServiceClient(TelemetryProvider): user_metadata=user_metadata, ) - load_future = await training_client.load_state_with_optimizer_async(path) + load_future = await training_client.load_state_with_optimizer_async( + path, weights_access_token=weights_access_token + ) await load_future.result_async() return training_client diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index b047e2d..cc812f8 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -6,10 +6,12 @@ import asyncio import logging import threading import time +import warnings from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Tuple from tinker import types +from tinker._exceptions import ConflictError from tinker.lib.client_connection_pool_type import ClientConnectionPoolType from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture from tinker.lib.telemetry import Telemetry, capture_exceptions @@ -40,6 +42,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) + # FwdBwdChunkSize MAX_CHUNK_LEN = 1024 MAX_CHUNK_BYTES_COUNT = 5000000 @@ -603,13 +606,44 @@ class TrainingClient(TelemetryProvider): seq_id=request_id + 1, ttl_seconds=ttl_seconds, ) - with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: - return await client.weights.save( - request=request, + try: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + return await client.weights.save( + request=request, + max_retries=0, + ) + except ConflictError: + # 409 means a checkpoint with this name already exists. + # This is common when retrying after a transient network + # error — the first attempt saved the checkpoint but the + # response was lost. Treat as success: the checkpoint IS + # saved, and crashing a long training run is worse than + # returning a synthetic response. + logger.info( + "Checkpoint '%s' already exists (409 Conflict); " + "treating as success — the checkpoint is saved.", + name, ) + if telemetry := self.holder.get_telemetry(): + telemetry.log( + "training_client.save_state.conflict_resolved", + event_data={ + "checkpoint_name": name, + "model_id": self._guaranteed_model_id(), + }, + severity="INFO", + ) + return None async with self._take_turn(request_id): future = await self.holder.execute_with_retries(_send_request) + + # _send_request returns None on 409 conflict (checkpoint already + # saved), or an UntypedAPIFuture on success. + if future is None: + model_id = self._guaranteed_model_id() + return types.SaveWeightsResponse(path=f"tinker://{model_id}/weights/{name}") + return await _APIFuture( types.SaveWeightsResponse, self.holder, @@ -629,7 +663,11 @@ class TrainingClient(TelemetryProvider): @capture_exceptions(fatal=True) async def _load_state_impl( - self, request_id: int, path: str, optimizer: bool + self, + request_id: int, + path: str, + optimizer: bool, + weights_access_token: str | None = None, ) -> types.LoadWeightsResponse: start_time = time.time() @@ -639,6 +677,7 @@ class TrainingClient(TelemetryProvider): path=path, seq_id=request_id + 1, optimizer=optimizer, + weights_access_token=weights_access_token, ) with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: return await client.weights.load( @@ -657,7 +696,9 @@ class TrainingClient(TelemetryProvider): ) @capture_exceptions(fatal=True) - def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]: + def load_state( + self, path: str, weights_access_token: str | None = None + ) -> APIFuture[types.LoadWeightsResponse]: """Load model weights from a saved checkpoint. This loads only the model weights, not optimizer state (e.g., Adam momentum). @@ -665,6 +706,7 @@ class TrainingClient(TelemetryProvider): Args: - `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") + - `weights_access_token`: Optional access token for loading checkpoints under a different account. Returns: - `APIFuture` containing the load response @@ -678,18 +720,27 @@ class TrainingClient(TelemetryProvider): ``` """ request_id = self._get_request_id() - return self.holder.run_coroutine_threadsafe(self._load_state_impl(request_id, path, False)) + return self.holder.run_coroutine_threadsafe( + self._load_state_impl( + request_id, path, False, weights_access_token=weights_access_token + ) + ) - async def load_state_async(self, path: str) -> APIFuture[types.LoadWeightsResponse]: + async def load_state_async( + self, path: str, weights_access_token: str | None = None + ) -> APIFuture[types.LoadWeightsResponse]: """Async version of load_state.""" - return self.load_state(path) + return self.load_state(path, weights_access_token=weights_access_token) @capture_exceptions(fatal=True) - def load_state_with_optimizer(self, path: str) -> APIFuture[types.LoadWeightsResponse]: + def load_state_with_optimizer( + self, path: str, weights_access_token: str | None = None + ) -> APIFuture[types.LoadWeightsResponse]: """Load model weights and optimizer state from a checkpoint. Args: - `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001") + - `weights_access_token`: Optional access token for loading checkpoints under a different account. Returns: - `APIFuture` containing the load response @@ -705,13 +756,15 @@ class TrainingClient(TelemetryProvider): ``` """ request_id = self._get_request_id() - return self.holder.run_coroutine_threadsafe(self._load_state_impl(request_id, path, True)) + return self.holder.run_coroutine_threadsafe( + self._load_state_impl(request_id, path, True, weights_access_token=weights_access_token) + ) async def load_state_with_optimizer_async( - self, path: str + self, path: str, weights_access_token: str | None = None ) -> APIFuture[types.LoadWeightsResponse]: """Async version of load_state_with_optimizer.""" - return self.load_state_with_optimizer(path) + return self.load_state_with_optimizer(path, weights_access_token=weights_access_token) @capture_exceptions(fatal=True) async def _save_weights_for_sampler_impl( @@ -739,13 +792,46 @@ class TrainingClient(TelemetryProvider): sampling_session_seq_id=sampling_session_seq_id, ttl_seconds=ttl_seconds, ) - with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: - return await client.weights.save_for_sampler( - request=request, + try: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + return await client.weights.save_for_sampler( + request=request, + max_retries=0, + ) + except ConflictError: + if name is None: + # Unnamed saves use server-generated unique paths; + # 409 should be impossible. Re-raise as a real error. + raise + # See save_state for full rationale on treating 409 as success. + logger.info( + "Sampler checkpoint '%s' already exists (409 Conflict); " + "treating as success — the checkpoint is saved.", + name, ) + if telemetry := self.holder.get_telemetry(): + telemetry.log( + "training_client.save_weights_for_sampler.conflict_resolved", + event_data={ + "checkpoint_name": name, + "model_id": self._guaranteed_model_id(), + }, + severity="INFO", + ) + return None async with self._take_turn(request_id): future = await self.holder.execute_with_retries(_send_request) + + # _send_request returns None on 409 conflict (checkpoint already + # saved), or an UntypedAPIFuture on success. + if future is None: + assert name is not None + model_id = self._guaranteed_model_id() + return types.SaveWeightsForSamplerResponseInternal( + path=f"tinker://{model_id}/sampler_weights/{name}" + ) + return await _APIFuture( types.SaveWeightsForSamplerResponseInternal, self.holder, @@ -906,7 +992,7 @@ class TrainingClient(TelemetryProvider): """Save current weights and create a SamplingClient for inference. Args: - - `name`: Optional name for the saved weights (currently ignored for ephemeral saves) + - `name`: Deprecated, has no effect. Will be removed in a future release. - `retry_config`: Optional configuration for retrying failed requests Returns: @@ -923,8 +1009,17 @@ class TrainingClient(TelemetryProvider): result = sampling_client.sample(prompt, 1, params).result() ``` """ - # Ignore name argument for ephemeral save weights for sampler - _ = name + if name is not None: + warnings.warn( + "The 'name' parameter of save_weights_and_get_sampling_client() is deprecated " + "and has no effect — checkpoints are always ephemeral. " + "This parameter will be removed in a future release. " + "Remove the 'name' argument from your call. " + "If you need a persistent checkpoint, use " + "save_weights_for_sampler(name=...) + create_sampling_client(model_path=...) instead.", + DeprecationWarning, + stacklevel=2, + ) return self.save_weights_and_get_sampling_client_submit(retry_config).result() @capture_exceptions(fatal=True) @@ -932,8 +1027,17 @@ class TrainingClient(TelemetryProvider): self, name: str | None = None, retry_config: RetryConfig | None = None ) -> SamplingClient: """Async version of save_weights_and_get_sampling_client.""" - # Ignore name argument for ephemeral save weights for sampler - _ = name + if name is not None: + warnings.warn( + "The 'name' parameter of save_weights_and_get_sampling_client_async() is deprecated " + "and has no effect — checkpoints are always ephemeral. " + "This parameter will be removed in a future release. " + "Remove the 'name' argument from your call. " + "If you need a persistent checkpoint, use " + "save_weights_for_sampler(name=...) + create_sampling_client(model_path=...) instead.", + DeprecationWarning, + stacklevel=2, + ) return await self.save_weights_and_get_sampling_client_submit(retry_config) def get_telemetry(self) -> Telemetry | None: diff --git a/src/tinker/resources/futures.py b/src/tinker/resources/futures.py index 0b089e9..3e25d57 100644 --- a/src/tinker/resources/futures.py +++ b/src/tinker/resources/futures.py @@ -73,7 +73,7 @@ class AsyncFuturesResource(AsyncAPIResource): FutureRetrieveResponse, await self._post( "/api/v1/retrieve_future", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=cast( Any, FutureRetrieveResponse diff --git a/src/tinker/resources/models.py b/src/tinker/resources/models.py index bfa80ad..861dbe1 100644 --- a/src/tinker/resources/models.py +++ b/src/tinker/resources/models.py @@ -60,7 +60,7 @@ class AsyncModelsResource(AsyncAPIResource): return await self._post( "/api/v1/create_model", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) @@ -106,7 +106,7 @@ class AsyncModelsResource(AsyncAPIResource): result = await self._post( "/api/v1/get_info", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=GetInfoResponse, ) @@ -159,7 +159,7 @@ class AsyncModelsResource(AsyncAPIResource): return await self._post( "/api/v1/unload_model", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) diff --git a/src/tinker/resources/sampling.py b/src/tinker/resources/sampling.py index f6a97a5..96dba41 100644 --- a/src/tinker/resources/sampling.py +++ b/src/tinker/resources/sampling.py @@ -56,7 +56,7 @@ class AsyncSamplingResource(AsyncAPIResource): return await self._post( "/api/v1/asample", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) diff --git a/src/tinker/resources/service.py b/src/tinker/resources/service.py index fb95c5f..8aace8b 100644 --- a/src/tinker/resources/service.py +++ b/src/tinker/resources/service.py @@ -6,6 +6,9 @@ from .._base_client import make_request_options from .._compat import model_dump from .._resource import AsyncAPIResource from .._types import NOT_GIVEN, Body, Headers, NotGiven, Query +from ..types.auth_token_response import AuthTokenResponse +from ..types.client_config_request import ClientConfigRequest +from ..types.client_config_response import ClientConfigResponse from ..types.create_sampling_session_request import CreateSamplingSessionRequest from ..types.create_sampling_session_response import CreateSamplingSessionResponse from ..types.create_session_request import CreateSessionRequest @@ -63,6 +66,54 @@ class AsyncServiceResource(AsyncAPIResource): cast_to=HealthResponse, ) + async def auth_token( + self, + *, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + max_retries: int | NotGiven = NOT_GIVEN, + ) -> AuthTokenResponse: + """Exchange the current credential for a short-lived JWT.""" + options = make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ) + if max_retries is not NOT_GIVEN: + options["max_retries"] = max_retries + + return await self._post( + "/api/v1/auth/token", + body={}, + options=options, + cast_to=AuthTokenResponse, + ) + + async def client_config( + self, + *, + request: ClientConfigRequest, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ClientConfigResponse: + """Fetch server-side feature flags for this client.""" + return await self._post( + "/api/v1/client/config", + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + ), + cast_to=ClientConfigResponse, + ) + async def create_session( self, *, @@ -104,7 +155,7 @@ class AsyncServiceResource(AsyncAPIResource): return await self._post( "/api/v1/create_session", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=CreateSessionResponse, ) @@ -148,7 +199,7 @@ class AsyncServiceResource(AsyncAPIResource): request = SessionHeartbeatRequest(session_id=session_id) return await self._post( "/api/v1/session_heartbeat", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=SessionHeartbeatResponse, ) @@ -190,7 +241,7 @@ class AsyncServiceResource(AsyncAPIResource): return await self._post( "/api/v1/create_sampling_session", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=CreateSamplingSessionResponse, ) diff --git a/src/tinker/resources/telemetry.py b/src/tinker/resources/telemetry.py index 0a95fc2..e16b1a9 100644 --- a/src/tinker/resources/telemetry.py +++ b/src/tinker/resources/telemetry.py @@ -67,7 +67,7 @@ class AsyncTelemetryResource(AsyncAPIResource): return await self._post( "/api/v1/telemetry", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=TelemetryResponse, ) diff --git a/src/tinker/resources/training.py b/src/tinker/resources/training.py index a6cf01d..3888f1d 100644 --- a/src/tinker/resources/training.py +++ b/src/tinker/resources/training.py @@ -56,7 +56,7 @@ class AsyncTrainingResource(AsyncAPIResource): return await self._post( "/api/v1/forward", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) @@ -102,7 +102,7 @@ class AsyncTrainingResource(AsyncAPIResource): return await self._post( "/api/v1/forward_backward", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) @@ -148,7 +148,7 @@ class AsyncTrainingResource(AsyncAPIResource): return await self._post( "/api/v1/optim_step", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) diff --git a/src/tinker/resources/weights.py b/src/tinker/resources/weights.py index e6cf130..8fae0c8 100644 --- a/src/tinker/resources/weights.py +++ b/src/tinker/resources/weights.py @@ -61,7 +61,7 @@ class AsyncWeightsResource(AsyncAPIResource): return await self._post( "/api/v1/load_weights", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) @@ -107,7 +107,7 @@ class AsyncWeightsResource(AsyncAPIResource): return await self._post( "/api/v1/save_weights", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) @@ -153,7 +153,7 @@ class AsyncWeightsResource(AsyncAPIResource): return await self._post( "/api/v1/save_weights_for_sampler", - body=model_dump(request, exclude_unset=True, mode="json"), + body=model_dump(request, exclude_unset=False, exclude_none=True, mode="json"), options=options, cast_to=UntypedAPIFuture, ) diff --git a/src/tinker/types/__init__.py b/src/tinker/types/__init__.py index 81b5a93..37a78a6 100644 --- a/src/tinker/types/__init__.py +++ b/src/tinker/types/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from .auth_token_response import AuthTokenResponse as AuthTokenResponse from .checkpoint import ( Checkpoint as Checkpoint, ) @@ -13,6 +14,8 @@ from .checkpoint_archive_url_response import ( CheckpointArchiveUrlResponse as CheckpointArchiveUrlResponse, ) from .checkpoints_list_response import CheckpointsListResponse as CheckpointsListResponse +from .client_config_request import ClientConfigRequest as ClientConfigRequest +from .client_config_response import ClientConfigResponse as ClientConfigResponse from .create_model_request import CreateModelRequest as CreateModelRequest from .create_model_response import CreateModelResponse as CreateModelResponse from .create_sampling_session_request import ( diff --git a/src/tinker/types/auth_token_response.py b/src/tinker/types/auth_token_response.py new file mode 100644 index 0000000..61c9475 --- /dev/null +++ b/src/tinker/types/auth_token_response.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +from .._models import BaseModel + +__all__ = ["AuthTokenResponse"] + + +class AuthTokenResponse(BaseModel): + jwt: str diff --git a/src/tinker/types/client_config_request.py b/src/tinker/types/client_config_request.py new file mode 100644 index 0000000..3692054 --- /dev/null +++ b/src/tinker/types/client_config_request.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from .._models import StrictBase + +__all__ = ["ClientConfigRequest"] + + +class ClientConfigRequest(StrictBase): + sdk_version: str + """The SDK version string for flag resolution.""" diff --git a/src/tinker/types/client_config_response.py b/src/tinker/types/client_config_response.py new file mode 100644 index 0000000..826f687 --- /dev/null +++ b/src/tinker/types/client_config_response.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from .._models import BaseModel + +__all__ = ["ClientConfigResponse"] + + +class ClientConfigResponse(BaseModel): + """Server-side feature flags resolved for this caller. + + Uses BaseModel (extra="ignore") so new flags from the server are + silently dropped until the SDK adds fields for them. + """ + + pjwt_auth_enabled: bool = False + credential_default_source: str = "api_key" + sample_dispatch_bytes_semaphore_size: int = 10 * 1024 * 1024 + inflight_response_bytes_semaphore_size: int = 50 * 1024 * 1024 diff --git a/src/tinker/types/datum.py b/src/tinker/types/datum.py index d6d772f..f928c81 100644 --- a/src/tinker/types/datum.py +++ b/src/tinker/types/datum.py @@ -46,6 +46,9 @@ class Datum(StrictBase): def _maybe_convert_array(cls, key: str, value: Any) -> Any: """Convert torch.Tensor, numpy array, or numeric lists to TensorData if needed.""" if _HAVE_TORCH and isinstance(value, torch.Tensor): + # Auto-sparsify 2-D target_tokens and weights to reduce wire payload + if key in _sparse_eligible_keys and value.ndim == 2: + return TensorData.from_torch_sparse(value) return TensorData.from_torch(value) elif isinstance(value, np.ndarray): return TensorData.from_numpy(value) @@ -81,3 +84,5 @@ _key_to_type = { "clip_low_threshold": "float32", "clip_high_threshold": "float32", } + +_sparse_eligible_keys = {"target_tokens", "weights"} diff --git a/src/tinker/types/get_server_capabilities_response.py b/src/tinker/types/get_server_capabilities_response.py index c69e9e0..32e4061 100644 --- a/src/tinker/types/get_server_capabilities_response.py +++ b/src/tinker/types/get_server_capabilities_response.py @@ -11,6 +11,9 @@ class SupportedModel(BaseModel): model_name: Optional[str] = None """The name of the supported model.""" + max_context_length: Optional[int] = None + """The maximum context length (in tokens) supported by this model.""" + class GetServerCapabilitiesResponse(BaseModel): """Response containing the server's supported models and capabilities.""" diff --git a/src/tinker/types/load_weights_request.py b/src/tinker/types/load_weights_request.py index 9caaccc..ac5e8b7 100644 --- a/src/tinker/types/load_weights_request.py +++ b/src/tinker/types/load_weights_request.py @@ -20,6 +20,9 @@ class LoadWeightsRequest(StrictBase): seq_id: Optional[int] = None + weights_access_token: Optional[str] = None + """Optional access token for loading checkpoints under a different account.""" + type: Literal["load_weights"] = "load_weights" if PYDANTIC_V2: diff --git a/src/tinker/types/tensor_data.py b/src/tinker/types/tensor_data.py index eab161b..54ad629 100644 --- a/src/tinker/types/tensor_data.py +++ b/src/tinker/types/tensor_data.py @@ -33,6 +33,17 @@ class TensorData(StrictBase): provided, and is generally inferred as a 1D tensor. """ + sparse_crow_indices: Optional[List[int]] = None + """Optional CSR compressed row pointers. When set, this tensor is sparse CSR: + - data contains only the non-zero values (flattened) + - sparse_crow_indices contains the row pointers (length = nrows + 1) + - sparse_col_indices contains the column indices (length = nnz) + - shape is required and specifies the dense shape + """ + + sparse_col_indices: Optional[List[int]] = None + """Optional CSR column indices. Must be set together with sparse_crow_indices.""" + @classmethod def from_numpy(cls, array: npt.NDArray[Any]) -> "TensorData": return cls( @@ -49,8 +60,41 @@ class TensorData(StrictBase): shape=list(tensor.shape), ) + @classmethod + def from_torch_sparse(cls, tensor: "torch.Tensor") -> "TensorData": + """Create a sparse CSR TensorData from a dense 2-D torch tensor. + + Automatically detects sparsity and encodes as CSR when it saves space. + Falls back to dense if the tensor is 1-D or mostly non-zero. + """ + if not _HAVE_TORCH: + raise ImportError("PyTorch is not installed.") + + if tensor.ndim != 2: + return cls.from_torch(tensor) + + # Only use sparse if it actually saves space + # Dense: nrows * ncols values + # CSR: (nrows + 1) crow_indices + nnz col_indices + nnz values + nnz = tensor.count_nonzero().item() + dense_size = tensor.shape[0] * tensor.shape[1] + csr_size = (tensor.shape[0] + 1) + 2 * nnz + if csr_size >= dense_size: + return cls.from_torch(tensor) + + sparse_csr = tensor.to_sparse_csr() + return cls( + data=sparse_csr.values().tolist(), + dtype=_convert_torch_dtype_to_tensor(tensor.dtype), + shape=list(tensor.shape), + sparse_crow_indices=sparse_csr.crow_indices().tolist(), + sparse_col_indices=sparse_csr.col_indices().tolist(), + ) + def to_numpy(self) -> npt.NDArray[Any]: """Convert TensorData to numpy array.""" + if self.sparse_crow_indices is not None: + return self.to_torch().numpy() numpy_dtype = _convert_tensor_dtype_to_numpy(self.dtype) arr = np.array(self.data, dtype=numpy_dtype) if self.shape is not None: @@ -63,6 +107,17 @@ class TensorData(StrictBase): raise ImportError("PyTorch is not installed. Cannot convert to torch tensor.") torch_dtype = _convert_tensor_dtype_to_torch(self.dtype) + + if self.sparse_crow_indices is not None: + assert self.sparse_col_indices is not None, ( + "sparse_col_indices required with sparse_crow_indices" + ) + assert self.shape is not None, "shape is required for sparse tensors" + crow = torch.tensor(self.sparse_crow_indices, dtype=torch.int64) + col = torch.tensor(self.sparse_col_indices, dtype=torch.int64) + values = torch.tensor(self.data, dtype=torch_dtype) + return torch.sparse_csr_tensor(crow, col, values, self.shape).to_dense() + tensor = torch.tensor(self.data, dtype=torch_dtype) if self.shape is not None: tensor = tensor.reshape(self.shape) diff --git a/tests/test_conflict_recovery.py b/tests/test_conflict_recovery.py new file mode 100644 index 0000000..d6f4791 --- /dev/null +++ b/tests/test_conflict_recovery.py @@ -0,0 +1,79 @@ +"""Tests for 409 ConflictError recovery in checkpoint save operations.""" + +from __future__ import annotations + +import asyncio +from contextlib import contextmanager +from typing import Any +from unittest.mock import AsyncMock, MagicMock, Mock + +import httpx +import pytest + +from tinker._exceptions import ConflictError +from tinker.lib.public_interfaces.training_client import TrainingClient + + +def _make_conflict_error() -> ConflictError: + """Create a ConflictError for testing.""" + request = httpx.Request("POST", "http://test/api/v1/save_weights") + response = httpx.Response(409, request=request) + return ConflictError("conflict", response=response, body=None) + + +def _make_mock_holder() -> Mock: + """Create a mock InternalClientHolder whose weights.save() raises ConflictError.""" + mock_client = MagicMock() + mock_client.weights.save = AsyncMock(side_effect=_make_conflict_error()) + mock_client.weights.save_for_sampler = AsyncMock(side_effect=_make_conflict_error()) + + @contextmanager + def fake_aclient(*args: Any, **kwargs: Any): + yield mock_client + + async def fake_execute_with_retries(fn: Any, *args: Any, **kwargs: Any) -> Any: + return await fn(*args, **kwargs) + + holder = Mock() + holder.aclient = fake_aclient + holder.get_telemetry = Mock(return_value=None) + holder.execute_with_retries = fake_execute_with_retries + holder.get_loop = Mock(side_effect=lambda: asyncio.get_event_loop()) + + def fake_run_coroutine_threadsafe(coro: Any) -> Any: + return asyncio.ensure_future(coro) + + holder.run_coroutine_threadsafe = fake_run_coroutine_threadsafe + return holder + + +@pytest.mark.asyncio +async def test_save_state_returns_synthetic_path_on_conflict() -> None: + """save_state catches 409 and returns SaveWeightsResponse with synthetic path.""" + holder = _make_mock_holder() + client = TrainingClient(holder, model_seq_id=0, model_id="model-123") + + result = await client.save_state("ckpt-001") + assert result.path == "tinker://model-123/weights/ckpt-001" + + +@pytest.mark.asyncio +async def test_save_weights_for_sampler_returns_synthetic_path_on_conflict() -> None: + """save_weights_for_sampler catches 409 and returns response with synthetic path.""" + holder = _make_mock_holder() + holder._sampling_client_counter = 0 + client = TrainingClient(holder, model_seq_id=0, model_id="model-789") + + result = await client.save_weights_for_sampler("ckpt-001") + assert result.path == "tinker://model-789/sampler_weights/ckpt-001" + + +@pytest.mark.asyncio +async def test_save_weights_for_sampler_unnamed_reraises_conflict() -> None: + """409 on unnamed sampler save (name=None) should re-raise, not swallow.""" + holder = _make_mock_holder() + holder._sampling_client_counter = 0 + client = TrainingClient(holder, model_seq_id=0, model_id="model-000") + + with pytest.raises(ConflictError): + await client.save_weights_for_sampler(None)