Sync contents

This commit is contained in:
Dylan Huang 2026-03-19 00:10:49 +00:00
parent 35a77e79fe
commit f3c0b1f179
17 changed files with 2069 additions and 632 deletions

View file

@ -1,4 +1,4 @@
{
"last_synced_sha": "b31ac99f31c26082a4076c3c79719034b2e0cab2",
"last_sync_time": "2026-03-09T00:49:41.558594"
"last_synced_sha": "2138c60c730c51b7bc19146a38320d631c38e0cc",
"last_sync_time": "2026-03-19T00:10:49.917056"
}

View file

@ -121,6 +121,38 @@ class InternalServerError(APIStatusError)
HTTP 500+: An error occurred on the server.
## `SidecarError` Objects
```python
class SidecarError(TinkerError)
```
Base exception for subprocess sidecar errors.
## `SidecarStartupError` Objects
```python
class SidecarStartupError(SidecarError)
```
Raised when the sidecar subprocess fails to start or times out.
## `SidecarDiedError` Objects
```python
class SidecarDiedError(SidecarError)
```
Raised when the sidecar subprocess exits unexpectedly while requests are pending.
## `SidecarIPCError` Objects
```python
class SidecarIPCError(SidecarError)
```
Raised when communication with the sidecar subprocess fails.
## `RequestFailedError` Objects
```python

View file

@ -20,6 +20,7 @@ Key methods:
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private
- set_checkpoint_ttl_from_tinker_path() - set or remove TTL on a checkpoint
Args:
- `holder`: Internal client managing HTTP connections and async operations
@ -39,7 +40,9 @@ for checkpoint in checkpoints.checkpoints:
```python
def get_training_run(
training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]
training_run_id: types.ModelID,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.TrainingRun]
```
Get training run info.
@ -61,7 +64,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}
```python
async def get_training_run_async(
training_run_id: types.ModelID) -> types.TrainingRun
training_run_id: types.ModelID,
access_scope: Literal["owned",
"accessible"] = "owned") -> types.TrainingRun
```
Async version of get_training_run.
@ -70,7 +75,9 @@ Async version of get_training_run.
```python
def get_training_run_by_tinker_path(
tinker_path: str) -> ConcurrentFuture[types.TrainingRun]
tinker_path: str,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.TrainingRun]
```
Get training run info.
@ -92,7 +99,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}
```python
async def get_training_run_by_tinker_path_async(
tinker_path: str) -> types.TrainingRun
tinker_path: str,
access_scope: Literal["owned",
"accessible"] = "owned") -> types.TrainingRun
```
Async version of get_training_run_by_tinker_path.
@ -123,8 +132,10 @@ print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}")
```python
def list_training_runs(
limit: int = 20,
offset: int = 0) -> ConcurrentFuture[types.TrainingRunsResponse]
limit: int = 20,
offset: int = 0,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.TrainingRunsResponse]
```
List training runs with pagination support.
@ -149,9 +160,11 @@ next_page = rest_client.list_training_runs(limit=50, offset=50)
#### `list_training_runs_async`
```python
async def list_training_runs_async(limit: int = 20,
offset: int = 0
) -> types.TrainingRunsResponse
async def list_training_runs_async(
limit: int = 20,
offset: int = 0,
access_scope: Literal["owned", "accessible"] = "owned"
) -> types.TrainingRunsResponse
```
Async version of list_training_runs.
@ -367,6 +380,46 @@ async def unpublish_checkpoint_from_tinker_path_async(
Async version of unpublish_checkpoint_from_tinker_path.
#### `set_checkpoint_ttl_from_tinker_path`
```python
def set_checkpoint_ttl_from_tinker_path(
tinker_path: str, ttl_seconds: int | None) -> ConcurrentFuture[None]
```
Set or remove the TTL on a checkpoint referenced by a tinker path.
If ttl_seconds is provided, the checkpoint will expire after that many seconds from now.
If ttl_seconds is None, any existing expiration will be removed.
Args:
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
- `ttl_seconds`: Number of seconds until expiration, or None to remove TTL
Returns:
- A `Future` that completes when the TTL is set
Raises:
HTTPException: 400 if checkpoint identifier is invalid or ttl_seconds <= 0
HTTPException: 404 if checkpoint not found or user doesn't own the training run
HTTPException: 500 if there's an error setting the TTL
Example:
```python
future = rest_client.set_checkpoint_ttl_from_tinker_path("tinker://run-id/weights/0001", 86400)
future.result() # Wait for completion
print("Checkpoint TTL set successfully")
```
#### `set_checkpoint_ttl_from_tinker_path_async`
```python
async def set_checkpoint_ttl_from_tinker_path_async(
tinker_path: str, ttl_seconds: int | None) -> None
```
Async version of set_checkpoint_ttl_from_tinker_path.
#### `list_user_checkpoints`
```python
@ -414,7 +467,10 @@ Async version of list_user_checkpoints.
#### `get_session`
```python
def get_session(session_id: str) -> ConcurrentFuture[types.GetSessionResponse]
def get_session(
session_id: str,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.GetSessionResponse]
```
Get session information including all training runs and samplers.
@ -436,7 +492,10 @@ print(f"Samplers: {len(response.sampler_ids)}")
#### `get_session_async`
```python
async def get_session_async(session_id: str) -> types.GetSessionResponse
async def get_session_async(
session_id: str,
access_scope: Literal["owned", "accessible"] = "owned"
) -> types.GetSessionResponse
```
Async version of get_session.
@ -445,8 +504,10 @@ Async version of get_session.
```python
def list_sessions(
limit: int = 20,
offset: int = 0) -> ConcurrentFuture[types.ListSessionsResponse]
limit: int = 20,
offset: int = 0,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.ListSessionsResponse]
```
List sessions with pagination support.
@ -470,8 +531,11 @@ next_page = rest_client.list_sessions(limit=50, offset=50)
#### `list_sessions_async`
```python
async def list_sessions_async(limit: int = 20,
offset: int = 0) -> types.ListSessionsResponse
async def list_sessions_async(
limit: int = 20,
offset: int = 0,
access_scope: Literal["owned", "accessible"] = "owned"
) -> types.ListSessionsResponse
```
Async version of list_sessions.

View file

@ -30,6 +30,20 @@ future = sampling_client.sample(prompt=prompt, sampling_params=params, num_sampl
result = future.result()
```
Multi-processing support:
This class is picklable, so it can be passed to a separate process/worker to sample. It is also
safe to pass the same instance of SamplingClient to multiple processes/workers.
If you are using Tinker SDK with more than one process you should always create SamplingClient from
the main process and then pass it to the other processes/workers.
ServiceClient and TrainingClient should always be managed from the main process.
Subprocess isolation:
Set ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated
subprocess, preventing GIL contention from CPU-heavy user code (grading, environment
interactions) from stalling networking IO and heartbeats. This is transparent — the same
API works with or without it.
#### `sample`
```python
@ -121,3 +135,33 @@ Get the tokenizer for the current model.
Returns:
- `PreTrainedTokenizer` compatible with the model
#### `get_base_model`
```python
def get_base_model() -> str
```
Get the base model name for the current sampling session.
#### `get_base_model_async`
```python
async def get_base_model_async() -> str
```
Async version of get_base_model.
#### `__reduce__`
```python
def __reduce__() -> tuple[Any, tuple[_SamplingClientPickleState]]
```
Enable pickling of SamplingClient for subprocess use.
Serializes into a ``_SamplingClientPickleState`` dataclass. The
``_sampling_client_sidecar_handle`` handle is deliberately omitted — only a
bool flag is stored. The unpickled copy creates its own handle via
the per-process sidecar singleton. Do not add ``__getstate__``
without preserving this behavior.

View file

@ -3,7 +3,7 @@ TrainingClient for Tinker API.
## `TrainingClient` Objects
```python
class TrainingClient(TelemetryProvider, QueueStateObserver)
class TrainingClient(TelemetryProvider)
```
Client for training ML models with forward/backward passes and optimization.
@ -127,8 +127,11 @@ Async version of forward_backward.
```python
def forward_backward_custom(
data: List[types.Datum],
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
data: List[types.Datum],
loss_fn: CustomLossFnV1,
*,
loss_type_input: Literal["logprobs"] = "logprobs"
) -> APIFuture[types.ForwardBackwardOutput]
```
Compute forward/backward with a custom loss function.
@ -139,6 +142,7 @@ The custom function receives logprobs and computes loss and gradients.
Args:
- `data`: List of training data samples
- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`.
Returns:
- `APIFuture` containing the forward/backward outputs with custom loss
@ -161,8 +165,11 @@ print(f"Metrics: {result.metrics}")
```python
async def forward_backward_custom_async(
data: List[types.Datum],
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
data: List[types.Datum],
loss_fn: CustomLossFnV1,
*,
loss_type_input: Literal["logprobs"] = "logprobs"
) -> APIFuture[types.ForwardBackwardOutput]
```
Async version of forward_backward_custom.

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[project]
name = "tinker"
version = "0.15.0"
version = "0.16.0"
description = "The official Python SDK for the tinker API"
readme = "README.md"
license = "Apache-2.0"

View file

@ -10,6 +10,7 @@
import ast
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
@ -95,8 +96,42 @@ class DocumentationGenerator:
def run_pydoc_markdown(self, modules: List[str], output_file: Path) -> bool:
"""Run pydoc-markdown for specific modules."""
try:
# Build the command
cmd = ["pydoc-markdown", "pydoc-markdown.yml", "-I", "src"]
# Invoke pydoc-markdown via uvx/uv tool run so the required
# doc-generation dependencies are resolved consistently.
uvx_path = shutil.which("uvx")
if uvx_path:
cmd = [
uvx_path,
"--from",
"pydoc-markdown>=4.8.0",
"--with",
"pyyaml>=6.0",
"--with",
"setuptools",
"pydoc-markdown",
"pydoc-markdown.yml",
"-I",
"src",
]
else:
uv_path = shutil.which("uv")
if uv_path is None:
raise FileNotFoundError("Could not find `uvx` or `uv` on PATH")
cmd = [
uv_path,
"tool",
"run",
"--from",
"pydoc-markdown>=4.8.0",
"--with",
"pyyaml>=6.0",
"--with",
"setuptools",
"pydoc-markdown",
"pydoc-markdown.yml",
"-I",
"src",
]
# Add modules
for module in modules:

View file

@ -16,6 +16,10 @@ from ._exceptions import (
PermissionDeniedError,
RateLimitError,
RequestFailedError,
SidecarDiedError,
SidecarError,
SidecarIPCError,
SidecarStartupError,
TinkerError,
UnprocessableEntityError,
)
@ -97,6 +101,10 @@ __all__ = [
"UnprocessableEntityError",
"RateLimitError",
"InternalServerError",
"SidecarError",
"SidecarStartupError",
"SidecarDiedError",
"SidecarIPCError",
# Keep types module for advanced use
"types",
# Version info

View file

@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, Dict, List
import click
if TYPE_CHECKING:
from datetime import datetime
from tinker.lib.public_interfaces.rest_client import RestClient
from tinker.types import Checkpoint, TrainingRun
@ -838,53 +840,249 @@ def set_ttl(cli_context: CLIContext, checkpoint_path: str, ttl: int | None, remo
client.set_checkpoint_ttl_from_tinker_path(checkpoint_path, ttl_seconds).result()
def _parse_date(value: str) -> "datetime":
"""Parse an ISO 8601 date or datetime string to a timezone-aware datetime.
Accepts: 2024-01-01, 2024-01-01T12:00:00, 2024-01-01T12:00:00Z,
2024-01-01T12:00:00+00:00. Date-only values are treated as midnight UTC.
"""
from datetime import UTC, datetime
value = value.strip()
# Python < 3.11 doesn't handle trailing 'Z' in fromisoformat
if value.endswith("Z"):
value = value[:-1] + "+00:00"
try:
dt = datetime.fromisoformat(value)
except ValueError:
raise TinkerCliError(
f"Invalid date: {value}",
"Use ISO 8601 format: 2024-01-01, 2024-01-01T12:00:00Z",
)
# Assume UTC if no timezone provided
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
return dt
# Maps user-facing --type values to the checkpoint_type string stored on
# Checkpoint objects. Both the path-segment spellings ("weights",
# "sampler_weights") and the canonical names ("training", "sampler") are
# accepted; unknown keys are rejected by _filter_checkpoints.
_CHECKPOINT_TYPE_MAP = {
    "weights": "training",
    "sampler_weights": "sampler",
    "training": "training",
    "sampler": "sampler",
}
def _filter_checkpoints(
checkpoints: "List[Checkpoint]",
checkpoint_type: str | None,
before: "datetime | None",
after: "datetime | None",
) -> "List[Checkpoint]":
"""Filter checkpoints by type, before date, and/or after date."""
filtered = checkpoints
if checkpoint_type:
mapped_type = _CHECKPOINT_TYPE_MAP.get(checkpoint_type)
if mapped_type is None:
raise TinkerCliError(
f"Invalid checkpoint type: {checkpoint_type}",
"Valid types: weights, sampler_weights",
)
filtered = [c for c in filtered if c.checkpoint_type == mapped_type]
if before is not None:
filtered = [c for c in filtered if c.time < before]
if after is not None:
filtered = [c for c in filtered if c.time > after]
return filtered
def _confirm_deletion(paths: "List[str]", checkpoints: "List[Checkpoint] | None" = None) -> bool:
    """Print what would be deleted, then ask the user to confirm.

    When checkpoint metadata is supplied, each entry is listed with its size
    and creation time and an aggregate total is shown; otherwise only the raw
    tinker paths are printed. Returns True only if the user confirms.
    """
    count = len(paths)
    click.echo(f"Will delete {count} checkpoint(s):")
    if checkpoints is None:
        # No metadata available — list bare paths.
        for tinker_path in paths:
            click.echo(f" - {tinker_path}")
    else:
        total_size = 0
        for cp in checkpoints:
            total_size += cp.size_bytes or 0
            size_str = format_size(cp.size_bytes) if cp.size_bytes is not None else "N/A"
            time_str = format_timestamp(cp.time)
            click.echo(f" - {cp.tinker_path} ({size_str}, created {time_str})")
        click.echo(f"\nTotal size: {format_size(total_size)}")
    click.echo()
    click.echo("WARNING: This action is permanent and cannot be undone.")
    return click.confirm(f"Are you sure you want to delete {count} checkpoint(s)?")
# Upper bound on concurrent checkpoint-deletion requests issued by _delete_paths.
_DELETE_CONCURRENCY = 32
def _delete_one(client: "RestClient", path: str) -> "tuple[str, str] | None":
"""Delete a single checkpoint. Returns (path, error) on failure, None on success."""
try:
client.delete_checkpoint_from_tinker_path(path).result()
return None
except Exception as e:
return (path, str(e))
def _delete_paths(
    client: "RestClient",
    paths: "List[str]",
    format: str,
) -> None:
    """Delete the given tinker paths concurrently and report the outcome.

    Runs up to _DELETE_CONCURRENCY deletions in parallel via _delete_one,
    showing a progress bar for table output, then prints either a JSON summary
    (``format == "json"``) or a human-readable success/failure listing.
    NOTE: the parameter name ``format`` (which shadows the builtin) is kept
    for interface compatibility.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    succeeded = 0
    failures: "List[tuple[str, str]]" = []
    with (
        ThreadPoolExecutor(max_workers=_DELETE_CONCURRENCY) as executor,
        click.progressbar(
            length=len(paths),
            label="Deleting checkpoints",
            show_percent=True,
            show_pos=True,
            hidden=format != "table",
        ) as progress,
    ):
        pending = [executor.submit(_delete_one, client, p) for p in paths]
        for done in as_completed(pending):
            outcome = done.result()
            if outcome is None:
                succeeded += 1
            else:
                failures.append(outcome)
            progress.update(1)

    if format == "json":
        import json

        payload = {
            "deleted_count": succeeded,
            "failed": [{"tinker_path": p, "error": e} for p, e in failures],
        }
        click.echo(json.dumps(payload))
    else:
        click.echo(f"Deleted {succeeded} checkpoint(s)")
        if failures:
            click.echo(f"Failed to delete {len(failures)} checkpoint(s):")
            for tinker_path, error in failures:
                click.echo(f" - {tinker_path}: {error}")
@cli.command()
@click.argument("checkpoint_paths", nargs=-1, required=True)
@click.argument("checkpoint_paths", nargs=-1, required=False)
@click.option("--run-id", default=None, help="Delete all checkpoints for a training run")
@click.option(
"--type", "checkpoint_type", default=None, help="Filter by type: weights or sampler_weights"
)
@click.option(
"--before",
default=None,
help="Filter: created before date in UTC (ISO 8601, e.g. 2024-01-01, 2024-01-01T08:00:00Z)",
)
@click.option(
"--after",
default=None,
help="Filter: created after date in UTC (ISO 8601, e.g. 2024-01-01, 2024-01-01T08:00:00Z)",
)
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
@click.pass_obj
@handle_api_errors
def delete(cli_context: CLIContext, checkpoint_paths: tuple[str, ...], yes: bool) -> None:
def delete(
cli_context: CLIContext,
checkpoint_paths: tuple[str, ...],
run_id: str | None,
checkpoint_type: str | None,
before: str | None,
after: str | None,
yes: bool,
) -> None:
"""Delete one or more checkpoints permanently.
CHECKPOINT_PATHS must be tinker paths (e.g., tinker://run-id/weights/0001).
Delete by explicit paths:
tinker checkpoint delete tinker://run-id/weights/0001 tinker://run-id/weights/0002
Delete all checkpoints for a training run:
tinker checkpoint delete --run-id <run-id>
Delete with filters:
tinker checkpoint delete --run-id <run-id> --type weights --before 2024-06-01
Delete checkpoints in a date range:
tinker checkpoint delete --run-id <run-id> --after 2024-01-01 --before 2024-02-01
Dates are interpreted as UTC. Use full ISO 8601 datetime for precision:
tinker checkpoint delete --run-id <run-id> --before 2024-06-01T08:00:00Z
Only the owner of the training run can delete checkpoints.
WARNING: This action is permanent and cannot be undone.
"""
# Validate all paths upfront
for path in checkpoint_paths:
if not path.startswith("tinker://"):
raise TinkerCliError(
f"Invalid checkpoint path: {path}",
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
)
if not checkpoint_paths and not run_id:
raise TinkerCliError(
"Must specify checkpoint paths or --run-id",
"Examples:\n"
" tinker checkpoint delete tinker://run-id/weights/0001\n"
" tinker checkpoint delete --run-id <run-id>\n"
" tinker checkpoint delete --run-id <run-id> --type weights --before 2024-06-01",
)
# If not using --yes, show checkpoint list and prompt for confirmation
if not yes:
count = len(checkpoint_paths)
click.echo(f"Will delete {count} checkpoint(s):")
if checkpoint_paths and run_id:
raise TinkerCliError(
"Cannot specify both checkpoint paths and --run-id",
"Use either explicit paths or --run-id with optional filters",
)
has_filters = checkpoint_type or before or after
if has_filters and not run_id:
raise TinkerCliError(
"--type, --before, and --after require --run-id",
"Example: tinker checkpoint delete --run-id <run-id> --type weights --before 2024-06-01",
)
client = create_rest_client()
if run_id:
before_dt = _parse_date(before) if before else None
after_dt = _parse_date(after) if after else None
response = client.list_checkpoints(run_id).result()
checkpoints = _filter_checkpoints(
response.checkpoints, checkpoint_type, before_dt, after_dt
)
if not checkpoints:
click.echo(f"No checkpoints found for run {run_id} matching filters")
return
paths_to_delete = [c.tinker_path for c in checkpoints]
if not yes and not _confirm_deletion(paths_to_delete, checkpoints):
click.echo("Deletion cancelled.")
return
else:
for path in checkpoint_paths:
click.echo(f" - {path}")
click.echo()
# Confirmation prompt
click.echo("WARNING: This action is permanent and cannot be undone.")
if not click.confirm(f"Are you sure you want to delete {count} checkpoint(s)?"):
if not path.startswith("tinker://"):
raise TinkerCliError(
f"Invalid checkpoint path: {path}",
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
)
paths_to_delete = list(checkpoint_paths)
if not yes and not _confirm_deletion(paths_to_delete):
click.echo("Deletion cancelled.")
return
# Create client and delete with progress bar
client = create_rest_client()
with click.progressbar(
checkpoint_paths,
label="Deleting checkpoints",
show_percent=True,
show_pos=True,
hidden=cli_context.format != "table",
) as bar:
for path in bar:
client.delete_checkpoint_from_tinker_path(path).result()
_delete_paths(client, paths_to_delete, cli_context.format)
@cli.command()

View file

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import dataclasses
import logging
import os
import time
@ -15,6 +16,7 @@ import tinker
from tinker import types
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
from tinker.lib.sidecar import SidecarHandle, SidecarRPC, create_sidecar_handle
from tinker.lib.telemetry import Telemetry, capture_exceptions
from tinker.lib.telemetry_provider import TelemetryProvider
@ -33,6 +35,56 @@ logger = logging.getLogger(__name__)
U = TypeVar("U")
# ---------------------------------------------------------------------------
# Pickle serialization state
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class _SamplingClientPickleState:
    """Serialized state for pickling SamplingClient across processes."""

    # Session id of the originating holder (InternalClientHolder.get_session_id()).
    session_id: str
    # Id of the sampling session the client is bound to.
    sampling_session_id: str
    # Kwargs used to rebuild a shadow InternalClientHolder on unpickle.
    constructor_kwargs: dict[str, Any]
    # Whether the original client had a sidecar handle. Only this bool is
    # pickled; the unpickled copy creates its own handle when True.
    subprocess_sampling: bool
# ---------------------------------------------------------------------------
# Typed RPCs for subprocess-isolated sampling
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class _SampleRPC(SidecarRPC):
    """Typed RPC for SamplingClient.sample().

    Captures the sample() arguments so they can cross the sidecar IPC
    boundary; execute() replays the call on the target client.
    """

    prompt: types.ModelInput
    num_samples: int
    sampling_params: types.SamplingParams
    include_prompt_logprobs: bool
    topk_prompt_logprobs: int

    async def execute(self, target: Any) -> Any:
        # Forward the captured arguments verbatim to the real sample().
        return target.sample(
            prompt=self.prompt,
            num_samples=self.num_samples,
            sampling_params=self.sampling_params,
            include_prompt_logprobs=self.include_prompt_logprobs,
            topk_prompt_logprobs=self.topk_prompt_logprobs,
        )
@dataclasses.dataclass
class _ComputeLogprobsRPC(SidecarRPC):
    """Typed RPC for SamplingClient.compute_logprobs().

    Carries only the prompt; execute() replays the call on the target client.
    """

    prompt: types.ModelInput

    async def execute(self, target: Any) -> Any:
        return target.compute_logprobs(prompt=self.prompt)
class SamplingClient(TelemetryProvider, QueueStateObserver):
"""Client for text generation and inference from trained or base models.
@ -65,6 +117,12 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
If you are using Tinker SDK with more than one process you should always create SamplingClient from
the main process and then pass it to the other processes/workers.
ServiceClient and TrainingClient should always be managed from the main process.
Subprocess isolation:
Set ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated
subprocess, preventing GIL contention from CPU-heavy user code (grading, environment
interactions) from stalling networking IO and heartbeats. This is transparent — the same
API works with or without it.
"""
def __init__(
@ -74,6 +132,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
sampling_session_id: str,
shadow: bool = False,
retry_config: RetryConfig | None = None,
subprocess_sampling: bool | None = None,
):
self.holder = holder
@ -97,6 +156,20 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
# We use 1B as the base and mod for uuid because the maximum int value is 2^63-1 and 1B*1B is less than 2^63-1.
self._request_id_counter = 1_000_000_000 * (int(uuid.uuid4()) % 1_000_000_000 + 1)
# Subprocess isolation: read env var if not explicitly set
if subprocess_sampling is None:
subprocess_sampling = os.environ.get("TINKER_SUBPROCESS_SAMPLING", "").lower() in (
"1",
"true",
"yes",
)
self._sampling_client_sidecar_handle: SidecarHandle | None = None
if subprocess_sampling:
from tinker.lib.sidecar import _inside_sidecar
if not _inside_sidecar:
self._sampling_client_sidecar_handle = create_sidecar_handle(self)
@staticmethod
async def _create_impl(
holder: InternalClientHolder,
@ -237,6 +310,16 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
print(tokenizer.decode(sample.tokens))
```
"""
if self._sampling_client_sidecar_handle is not None:
return self._sampling_client_sidecar_handle.submit_rpc(
_SampleRPC(
prompt=prompt,
num_samples=num_samples,
sampling_params=sampling_params,
include_prompt_logprobs=include_prompt_logprobs,
topk_prompt_logprobs=topk_prompt_logprobs,
)
)
async def _sample_async():
return await self._sample_async_impl(
@ -294,6 +377,10 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
print(f"Token {i}: logprob = {logprob:.4f}")
```
"""
if self._sampling_client_sidecar_handle is not None:
return self._sampling_client_sidecar_handle.submit_rpc(
_ComputeLogprobsRPC(prompt=prompt)
)
async def _compute_logprobs_async() -> list[float | None]:
sample_res = await self._sample_async_impl(
@ -349,18 +436,24 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
def get_telemetry(self) -> Telemetry | None:
return self.holder.get_telemetry()
def __reduce__(self) -> tuple[Any, tuple[str, str, dict[str, Any]]]:
def __reduce__(self) -> tuple[Any, tuple[_SamplingClientPickleState]]:
"""Enable pickling of SamplingClient for subprocess use.
Stores the sampling_session_id and holder constructor kwargs.
On unpickle, creates a shadow holder and reconstructs the client.
Serializes into a ``_SamplingClientPickleState`` dataclass. The
``_sampling_client_sidecar_handle`` handle is deliberately omitted — only a
bool flag is stored. The unpickled copy creates its own handle via
the per-process sidecar singleton. Do not add ``__getstate__``
without preserving this behavior.
"""
return (
_unpickle_sampling_client,
(
self.holder.get_session_id(),
self._sampling_session_id,
self.holder._constructor_kwargs,
_SamplingClientPickleState(
session_id=self.holder.get_session_id(),
sampling_session_id=self._sampling_session_id,
constructor_kwargs=self.holder._constructor_kwargs,
subprocess_sampling=self._sampling_client_sidecar_handle is not None,
),
),
)
@ -386,21 +479,21 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
)
def _unpickle_sampling_client(
session_id: str,
sampling_session_id: str,
constructor_kwargs: dict[str, Any],
) -> SamplingClient:
"""Reconstruct a SamplingClient from pickled data.
def _unpickle_sampling_client(state: _SamplingClientPickleState) -> SamplingClient:
"""Reconstruct a SamplingClient from pickled state.
Creates a shadow InternalClientHolder and builds a new SamplingClient.
The request_id_counter starts at a random high value to avoid collisions.
Subprocess enablement is handled by the constructor.
"""
from ..internal_client_holder import InternalClientHolder
holder = InternalClientHolder.get_shadow_holder(session_id, constructor_kwargs)
client = SamplingClient(holder, sampling_session_id=sampling_session_id, shadow=True)
return client
holder = InternalClientHolder.get_shadow_holder(state.session_id, state.constructor_kwargs)
return SamplingClient(
holder,
sampling_session_id=state.sampling_session_id,
shadow=True,
subprocess_sampling=state.subprocess_sampling,
)
@lru_cache(maxsize=100)

View file

@ -7,7 +7,7 @@ import logging
import threading
import time
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Tuple
from tinker import types
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
@ -49,6 +49,11 @@ MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initi
# Args: (data: List[Datum], model_outputs: List[Any]) -> (loss: Any, metrics: Dict[str, float])
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
_SUPPORTED_CUSTOM_BACKEND_LOSS_FNS = frozenset({"cross_entropy"})
_CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE: dict[Literal["logprobs"], types.LossFnType] = {
"logprobs": "cross_entropy",
}
class TrainingClient(TelemetryProvider):
"""Client for training ML models with forward/backward passes and optimization.
@ -331,7 +336,11 @@ class TrainingClient(TelemetryProvider):
@sync_only
@capture_exceptions(fatal=True)
def forward_backward_custom(
self, data: List[types.Datum], loss_fn: CustomLossFnV1
self,
data: List[types.Datum],
loss_fn: CustomLossFnV1,
*,
loss_type_input: Literal["logprobs"] = "logprobs",
) -> APIFuture[types.ForwardBackwardOutput]:
"""Compute forward/backward with a custom loss function.
@ -341,6 +350,7 @@ class TrainingClient(TelemetryProvider):
Args:
- `data`: List of training data samples
- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`.
Returns:
- `APIFuture` containing the forward/backward outputs with custom loss
@ -360,23 +370,49 @@ class TrainingClient(TelemetryProvider):
```
"""
return self.holder.run_coroutine_threadsafe(
self.forward_backward_custom_async(data, loss_fn)
self.forward_backward_custom_async(
data,
loss_fn,
loss_type_input=loss_type_input,
)
).result()
@capture_exceptions(fatal=True)
async def forward_backward_custom_async(
self, data: List[types.Datum], loss_fn: CustomLossFnV1
self,
data: List[types.Datum],
loss_fn: CustomLossFnV1,
*,
loss_type_input: Literal["logprobs"] = "logprobs",
) -> APIFuture[types.ForwardBackwardOutput]:
"""Async version of forward_backward_custom."""
if torch is None:
raise ImportError("PyTorch is not installed. Cannot run custom forward_backward.")
if loss_type_input not in _CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE:
supported = ", ".join(sorted(_CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE))
raise ValueError(
f"Unsupported loss_type_input={loss_type_input!r}. "
f"Supported values are: {supported}"
)
surrogate_loss_fn = _CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE[loss_type_input]
forward_data = self._get_custom_loss_forward_data(data, surrogate_loss_fn)
# First do a forward pass and get logprobs
forward_future = await self.forward_async(data, "cross_entropy")
forward_future = await self.forward_async(
forward_data,
surrogate_loss_fn,
None,
)
forward_result = await forward_future.result_async()
logprobs_list = []
for out in forward_result.loss_fn_outputs:
logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
logprob = torch.tensor(out["logprobs"].data)
if out["logprobs"].shape is not None:
logprob = logprob.reshape(out["logprobs"].shape)
logprob = logprob.clone().detach().requires_grad_(True)
logprobs_list.append(logprob)
# Now apply user-provided function
@ -392,7 +428,9 @@ class TrainingClient(TelemetryProvider):
for datum, grad in zip(data, grads, strict=True):
loss_fn_inputs: Any = {
"target_tokens": datum.loss_fn_inputs["target_tokens"],
"weights": -grad, # Pass PyTorch tensor directly (will be converted to TensorData)
# Backend CE is L = sum(-logprobs * weights), so to backpropagate a
# client-side custom loss C(logprobs) we must send weights = -dC/dlogprobs.
"weights": -grad,
}
linear_loss_data.append(
types.Datum(
@ -402,7 +440,11 @@ class TrainingClient(TelemetryProvider):
)
# Do the backward pass with the gradients
backward_future = await self.forward_backward_async(linear_loss_data, "cross_entropy")
backward_future = await self.forward_backward_async(
linear_loss_data,
surrogate_loss_fn,
None,
)
# We need to slightly modify the future to add the custom metrics, so we use _CombinedAPIFuture
# to transform the future.
@ -415,6 +457,49 @@ class TrainingClient(TelemetryProvider):
return _CombinedAPIFuture([backward_future], add_custom_metrics, self.holder)
def _get_custom_loss_forward_data(
    self,
    data: List[types.Datum],
    surrogate_loss_fn: types.LossFnType,
) -> List[types.Datum]:
    """Prepare data for the surrogate forward pass of a custom loss.

    Every returned datum carries a ``weights`` tensor: data that already has
    one is passed through untouched, otherwise a zero-filled weights tensor
    matching ``target_tokens`` is injected (zero weights mean the surrogate
    cross-entropy contributes nothing, so the pass only yields logprobs).

    Raises:
        ValueError: If ``target_tokens`` is missing or loss_fn_inputs contains
            keys other than ``target_tokens`` / ``weights``.
    """
    assert surrogate_loss_fn in _SUPPORTED_CUSTOM_BACKEND_LOSS_FNS, (
        "forward_backward_custom_async should validate surrogate_loss_fn before "
        "_get_custom_loss_forward_data is called"
    )
    allowed_keys = {"target_tokens", "weights"}
    prepared: List[types.Datum] = []
    for datum in data:
        target_tokens = datum.loss_fn_inputs.get("target_tokens")
        if target_tokens is None:
            raise ValueError("target_tokens must be provided when using cross_entropy")
        extras = sorted(set(datum.loss_fn_inputs) - allowed_keys)
        if extras:
            raise ValueError(
                "forward_backward_custom only supports loss_fn_inputs keys "
                "{'target_tokens', 'weights'}; "
                f"found unexpected keys: {extras}"
            )
        if "weights" in datum.loss_fn_inputs:
            prepared.append(datum)
        else:
            inputs = dict(datum.loss_fn_inputs)
            # Zero weights: the surrogate forward pass becomes a pure logprob query.
            inputs["weights"] = types.TensorData(
                data=[0.0] * len(target_tokens.data),
                dtype="float32",
                shape=target_tokens.shape,
            )
            prepared.append(
                types.Datum(model_input=datum.model_input, loss_fn_inputs=inputs)
            )
    return prepared
@capture_exceptions(fatal=True)
def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]:
"""Update model parameters using Adam optimizer.

View file

@ -44,13 +44,30 @@ class Datum(StrictBase):
@classmethod
def _maybe_convert_array(cls, key: str, value: Any) -> Any:
    """Convert torch.Tensor, numpy array, or numeric lists to TensorData if needed.

    Args:
        key: Name of the loss_fn_inputs entry; used to look up the expected
            dtype in ``_key_to_type``.
        value: The raw value supplied by the caller.

    Returns:
        A TensorData for tensor / ndarray / numeric-list inputs; any other
        value is returned unchanged.

    Raises:
        ValueError: If ``value`` is a ragged (non-rectangular) nested list.
    """
    if _HAVE_TORCH and isinstance(value, torch.Tensor):
        return TensorData.from_torch(value)
    elif isinstance(value, np.ndarray):
        return TensorData.from_numpy(value)
    elif isinstance(value, list):
        try:
            array = np.asarray(value)
        except ValueError as exc:
            # Newer numpy raises on ragged nested sequences; surface a
            # clearer, key-specific error in that case.
            if any(isinstance(item, list) for item in value):
                raise ValueError(
                    f"{key} must be a rectangular numeric array; ragged nested lists are not supported"
                ) from exc
            raise
        if array.dtype.kind in ("f", "i", "u"):
            # Normalize numeric dtypes to what the backend expects for this key.
            if _key_to_type[key] == "int64":
                array = array.astype(np.int64)
            else:
                array = array.astype(np.float32)
            return TensorData.from_numpy(array)
        # Older numpy builds an object array from ragged input instead of
        # raising, so re-check explicitly before falling through.
        if any(isinstance(item, list) for item in value):
            raise ValueError(
                f"{key} must be a rectangular numeric array; ragged nested lists are not supported"
            )
        # Non-numeric flat list: trust the declared dtype for this key.
        return TensorData(data=value, dtype=_key_to_type[key], shape=[len(value)])
    else:
        return value

View file

@ -2,4 +2,10 @@ from typing_extensions import Literal, TypeAlias
__all__ = ["LossFnType"]
LossFnType: TypeAlias = Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"]
LossFnType: TypeAlias = Literal[
"cross_entropy",
"importance_sampling",
"ppo",
"cispo",
"dro",
]

View file

@ -0,0 +1,196 @@
"""Tests for bulk checkpoint deletion: CLI flags and date parsing."""
from datetime import UTC, datetime, timedelta
import pytest
from click.testing import CliRunner
from tinker.cli.commands.checkpoint import _filter_checkpoints, _parse_date
class TestParseDate:
    """Tests for the _parse_date ISO 8601 parser."""

    def test_date_only(self) -> None:
        parsed = _parse_date("2024-01-15")
        assert (parsed.year, parsed.month, parsed.day) == (2024, 1, 15)
        assert parsed.tzinfo is not None

    def test_datetime_with_z(self) -> None:
        parsed = _parse_date("2024-06-01T12:00:00Z")
        assert (parsed.year, parsed.month, parsed.hour) == (2024, 6, 12)
        assert parsed.tzinfo is not None

    def test_datetime_with_offset(self) -> None:
        parsed = _parse_date("2024-06-01T12:00:00+00:00")
        assert parsed.year == 2024
        assert parsed.tzinfo is not None

    def test_datetime_naive_gets_utc(self) -> None:
        # A timestamp without an offset must still come back timezone-aware.
        assert _parse_date("2024-06-01T12:00:00").tzinfo is not None

    def test_whitespace_stripped(self) -> None:
        assert _parse_date(" 2024-01-01 ").year == 2024

    def test_invalid_raises(self) -> None:
        from tinker.cli.exceptions import TinkerCliError

        with pytest.raises(TinkerCliError):
            _parse_date("not-a-date")

    def test_invalid_format_raises(self) -> None:
        from tinker.cli.exceptions import TinkerCliError

        # US-style dates are not ISO 8601 and must be rejected.
        with pytest.raises(TinkerCliError):
            _parse_date("01/15/2024")
class TestFilterCheckpoints:
    """Tests for the _filter_checkpoints function."""

    @pytest.fixture()
    def sample_checkpoints(self):
        """Four checkpoints: 'training' at 10d/3d/1h old plus one 'sampler' at 10d old."""
        from tinker.types.checkpoint import Checkpoint

        now = datetime.now(UTC)
        return [
            Checkpoint(
                checkpoint_id="weights/0001",
                checkpoint_type="training",
                time=now - timedelta(days=10),
                tinker_path="tinker://run-1/weights/0001",
                size_bytes=1000,
            ),
            Checkpoint(
                checkpoint_id="weights/0002",
                checkpoint_type="training",
                time=now - timedelta(days=3),
                tinker_path="tinker://run-1/weights/0002",
                size_bytes=2000,
            ),
            Checkpoint(
                checkpoint_id="sampler_weights/0001",
                checkpoint_type="sampler",
                time=now - timedelta(days=10),
                tinker_path="tinker://run-1/sampler_weights/0001",
                size_bytes=500,
            ),
            Checkpoint(
                checkpoint_id="weights/0003",
                checkpoint_type="training",
                time=now - timedelta(hours=1),
                tinker_path="tinker://run-1/weights/0003",
                size_bytes=3000,
            ),
        ]

    def test_no_filters(self, sample_checkpoints) -> None:
        # No type/before/after filters: everything passes through.
        result = _filter_checkpoints(sample_checkpoints, None, None, None)
        assert len(result) == 4

    def test_filter_by_weights_type(self, sample_checkpoints) -> None:
        # The "weights" filter selects checkpoint_type == "training".
        result = _filter_checkpoints(sample_checkpoints, "weights", None, None)
        assert len(result) == 3
        assert all(c.checkpoint_type == "training" for c in result)

    def test_filter_by_sampler_weights_type(self, sample_checkpoints) -> None:
        # The "sampler_weights" filter selects checkpoint_type == "sampler".
        result = _filter_checkpoints(sample_checkpoints, "sampler_weights", None, None)
        assert len(result) == 1
        assert result[0].checkpoint_type == "sampler"

    def test_filter_before(self, sample_checkpoints) -> None:
        # Before 7 days ago → only the 10-day-old checkpoints
        cutoff = datetime.now(UTC) - timedelta(days=7)
        result = _filter_checkpoints(sample_checkpoints, None, cutoff, None)
        assert len(result) == 2
        assert all("0001" in c.checkpoint_id for c in result)

    def test_filter_after(self, sample_checkpoints) -> None:
        # After 7 days ago → the 3-day-old and 1-hour-old checkpoints
        cutoff = datetime.now(UTC) - timedelta(days=7)
        result = _filter_checkpoints(sample_checkpoints, None, None, cutoff)
        assert len(result) == 2
        paths = {c.tinker_path for c in result}
        assert "tinker://run-1/weights/0002" in paths
        assert "tinker://run-1/weights/0003" in paths

    def test_filter_date_range(self, sample_checkpoints) -> None:
        # Between 5 and 2 days ago → only the 3-day-old checkpoint
        after_dt = datetime.now(UTC) - timedelta(days=5)
        before_dt = datetime.now(UTC) - timedelta(days=2)
        result = _filter_checkpoints(sample_checkpoints, None, before_dt, after_dt)
        assert len(result) == 1
        assert result[0].tinker_path == "tinker://run-1/weights/0002"

    def test_filter_by_type_and_before(self, sample_checkpoints) -> None:
        # Filters compose: training type AND older than 7 days.
        cutoff = datetime.now(UTC) - timedelta(days=7)
        result = _filter_checkpoints(sample_checkpoints, "weights", cutoff, None)
        assert len(result) == 1
        assert result[0].tinker_path == "tinker://run-1/weights/0001"

    def test_invalid_type_raises(self, sample_checkpoints) -> None:
        # Unknown type filters are rejected with the CLI's error type.
        from tinker.cli.exceptions import TinkerCliError

        with pytest.raises(TinkerCliError):
            _filter_checkpoints(sample_checkpoints, "invalid_type", None, None)
class TestDeleteCLIValidation:
    """Tests for CLI delete command argument validation."""

    def _get_error_message(self, result) -> str:
        """Get error message from either output or exception."""
        if result.output:
            return result.output
        return str(result.exception) if result.exception else ""

    def _invoke_delete(self, *args: str):
        """Run the `delete` subcommand with the given extra arguments."""
        from tinker.cli.commands.checkpoint import cli

        return CliRunner().invoke(cli, ["delete", *args])

    def test_no_args_shows_error(self) -> None:
        outcome = self._invoke_delete()
        assert outcome.exit_code != 0

    def test_paths_and_run_id_conflict(self) -> None:
        outcome = self._invoke_delete("tinker://run-1/weights/0001", "--run-id", "run-1")
        assert outcome.exit_code != 0
        assert "Cannot specify both" in self._get_error_message(outcome)

    def test_type_without_run_id_error(self) -> None:
        outcome = self._invoke_delete("tinker://run-1/weights/0001", "--type", "weights")
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

    def test_before_without_run_id_error(self) -> None:
        outcome = self._invoke_delete("tinker://run-1/weights/0001", "--before", "2024-01-01")
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

    def test_after_without_run_id_error(self) -> None:
        outcome = self._invoke_delete("tinker://run-1/weights/0001", "--after", "2024-01-01")
        assert outcome.exit_code != 0
        assert "--run-id" in self._get_error_message(outcome)

View file

@ -0,0 +1,384 @@
"""Tests for SamplingClient subprocess mode.
These tests use a picklable fake SamplingClient to verify that
subprocess mode correctly routes sample() and compute_logprobs()
through the sidecar subprocess.
Test organization:
TestRPCRouting sample/compute_logprobs delegation through sidecar
TestErrorHandling error propagation, sidecar death
TestPickle roundtrip with/without sidecar, re-enable mode
TestConcurrency multithreaded, async, cancelled futures, mixed ops
"""
from __future__ import annotations
import asyncio
import pickle
import threading
import time
from concurrent.futures import Future as ConcurrentFuture
from typing import Any
import pytest
from tinker import types
from tinker._exceptions import SidecarDiedError
from tinker.lib.sidecar import create_sidecar_handle
# ---------------------------------------------------------------------------
# Picklable fake SamplingClient (must be module-level for pickling)
# ---------------------------------------------------------------------------
class _FakeSamplingClient:
    """A picklable fake that simulates SamplingClient for testing.

    This is NOT a real SamplingClient it provides just enough interface
    to test the sidecar integration. Real SamplingClient requires an
    InternalClientHolder and API connection.
    """

    def __init__(self, delay: float = 0.0, fail: bool = False, subprocess_sampling: bool = False):
        # delay: seconds before a sample() future resolves (simulated latency).
        # fail: when True, every RPC future resolves with an exception.
        self._delay = delay
        self._fail = fail
        self._sampling_client_sidecar_handle = None  # set by create_sidecar_handle() in tests
        if subprocess_sampling:
            from tinker.lib.sidecar import _inside_sidecar

            # NOTE(review): the _inside_sidecar guard presumably prevents the
            # copy unpickled inside the sidecar process from spawning another
            # sidecar — confirm against tinker.lib.sidecar.
            if not _inside_sidecar:
                self._sampling_client_sidecar_handle = create_sidecar_handle(self)

    def sample(
        self,
        prompt: types.ModelInput,
        num_samples: int,
        sampling_params: types.SamplingParams,
        include_prompt_logprobs: bool = False,
        topk_prompt_logprobs: int = 0,
    ) -> Any:
        """Return a future resolving to a canned SampleResponse (or an error)."""
        # Delegate through sidecar if enabled (mirrors real SamplingClient behavior)
        if self._sampling_client_sidecar_handle is not None:
            from tinker.lib.public_interfaces.sampling_client import _SampleRPC

            return self._sampling_client_sidecar_handle.submit_rpc(
                _SampleRPC(
                    prompt,
                    num_samples,
                    sampling_params,
                    include_prompt_logprobs,
                    topk_prompt_logprobs,
                )
            )
        f: ConcurrentFuture[types.SampleResponse] = ConcurrentFuture()
        if self._fail:
            f.set_exception(RuntimeError("Simulated sample failure"))
        elif self._delay > 0:
            # Resolve on a daemon thread after the configured delay so the
            # caller sees a genuinely in-flight future.
            def _delayed():
                time.sleep(self._delay)
                f.set_result(_make_sample_response())

            threading.Thread(target=_delayed, daemon=True).start()
        else:
            f.set_result(_make_sample_response())
        return f

    def compute_logprobs(self, prompt: types.ModelInput) -> Any:
        """Return a future resolving to a fixed logprobs list (or an error)."""
        # Delegate through sidecar if enabled (mirrors real SamplingClient behavior)
        if self._sampling_client_sidecar_handle is not None:
            from tinker.lib.public_interfaces.sampling_client import _ComputeLogprobsRPC

            return self._sampling_client_sidecar_handle.submit_rpc(_ComputeLogprobsRPC(prompt))
        f: ConcurrentFuture[list[float | None]] = ConcurrentFuture()
        if self._fail:
            f.set_exception(RuntimeError("Simulated logprobs failure"))
        else:
            f.set_result([0.1, 0.2, None])
        return f

    def __reduce__(self) -> tuple[type, tuple[float, bool, bool]]:
        # Pickle support: the sidecar handle itself is dropped; the third
        # constructor argument (subprocess_sampling) re-creates it on unpickle
        # when one was present.
        return (
            _FakeSamplingClient,
            (self._delay, self._fail, self._sampling_client_sidecar_handle is not None),
        )
def _make_sample_response() -> types.SampleResponse:
    """Build the minimal deterministic SampleResponse the fake always returns."""
    sequence = types.SampledSequence(
        stop_reason="length",
        tokens=[1, 2, 3],
        logprobs=[0.1, 0.2, 0.3],
    )
    return types.SampleResponse(sequences=[sequence], type="sample")
def _create_proxy(delay: float = 0.0, fail: bool = False) -> _FakeSamplingClient:
    """Create a fake client with sidecar handle for testing."""
    fake = _FakeSamplingClient(delay=delay, fail=fail)
    # Attach the sidecar handle explicitly rather than via the constructor flag.
    fake._sampling_client_sidecar_handle = create_sidecar_handle(fake)
    return fake
# Shared fixtures: a tiny prompt and minimal sampling params reused by all tests.
_PROMPT = types.ModelInput.from_ints([1, 2, 3])
_PARAMS = types.SamplingParams(max_tokens=10)
# ===========================================================================
# Tests
# ===========================================================================
class TestRPCRouting:
    """Verify sample() and compute_logprobs() are routed through the sidecar."""

    def test_sample(self):
        """sample() → subprocess → SampleResponse."""
        client = _create_proxy()
        response = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(response, types.SampleResponse)
        assert response.sequences[0].tokens == [1, 2, 3]

    def test_constructor_enables_subprocess_mode(self):
        """subprocess_sampling=True in __init__ creates the sidecar handle."""
        fake = _FakeSamplingClient(subprocess_sampling=True)
        assert fake._sampling_client_sidecar_handle is not None
        response = fake.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(response, types.SampleResponse)

    def test_compute_logprobs(self):
        """compute_logprobs() → subprocess → list of logprobs."""
        client = _create_proxy()
        assert client.compute_logprobs(_PROMPT).result(timeout=10) == [0.1, 0.2, None]

    def test_mixed_sample_and_logprobs(self):
        """Interleaved sample() and compute_logprobs() all resolve correctly."""
        client = _create_proxy(delay=0.01)
        sample_futures = [client.sample(_PROMPT, 1, _PARAMS) for _ in range(10)]
        logprob_futures = [client.compute_logprobs(_PROMPT) for _ in range(10)]
        for fut in sample_futures:
            response = fut.result(timeout=10)
            assert isinstance(response, types.SampleResponse)
            assert response.sequences[0].tokens == [1, 2, 3]
        for fut in logprob_futures:
            assert fut.result(timeout=10) == [0.1, 0.2, None]
class TestErrorHandling:
    """Error propagation from subprocess to caller."""

    def test_sample_error(self):
        """Exceptions from sample() in the subprocess are propagated."""
        proxy = _create_proxy(fail=True)
        with pytest.raises(RuntimeError, match="Simulated sample failure"):
            proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)

    def test_compute_logprobs_error(self):
        """Exceptions from compute_logprobs() in the subprocess are propagated."""
        proxy = _create_proxy(fail=True)
        with pytest.raises(RuntimeError, match="Simulated logprobs failure"):
            proxy.compute_logprobs(_PROMPT).result(timeout=10)

    @pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
    def test_sidecar_death_fails_pending_futures(self):
        """When the subprocess is killed, pending futures get SidecarDiedError."""
        # delay=0.5 keeps the request in flight long enough to kill the
        # subprocess while the future is still pending.
        proxy = _create_proxy(delay=0.5)
        future = proxy.sample(_PROMPT, 1, _PARAMS)
        # Kill the underlying subprocess directly
        sidecar = proxy._sampling_client_sidecar_handle._sidecar
        assert sidecar._process is not None
        sidecar._process.kill()
        sidecar._process.join(timeout=5)
        with pytest.raises(SidecarDiedError):
            future.result(timeout=5)
class TestPickle:
    """Pickle roundtrip preserves subprocess mode correctly."""

    def test_roundtrip_preserves_subprocess_mode(self):
        """Pickling a sidecar-enabled client re-enables subprocess mode on unpickle."""
        original = _create_proxy()
        assert original._sampling_client_sidecar_handle is not None
        clone = pickle.loads(pickle.dumps(original))
        assert clone._sampling_client_sidecar_handle is not None
        response = clone.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(response, types.SampleResponse)

    def test_roundtrip_without_sidecar(self):
        """Pickling a client without subprocess mode keeps it disabled."""
        plain = _FakeSamplingClient()
        assert plain._sampling_client_sidecar_handle is None
        clone = pickle.loads(pickle.dumps(plain))
        assert clone._sampling_client_sidecar_handle is None

    def test_re_enable_subprocess_mode(self):
        """Replacing the sidecar handle works cleanly."""
        client = _FakeSamplingClient()
        for _ in range(2):
            # Each iteration installs a fresh handle (the previous one is
            # GC'd and unregistered); sampling must work through both.
            client._sampling_client_sidecar_handle = create_sidecar_handle(client)
            response = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
            assert isinstance(response, types.SampleResponse)
class TestConcurrency:
    """Thread safety and concurrent operations through the sidecar."""

    def test_multithreaded_samples(self):
        """sample() from 20 threads all resolve correctly."""
        proxy = _create_proxy(delay=0.01)
        results: list[types.SampleResponse | None] = [None] * 20
        errors: list[Exception] = []

        def _worker(idx: int) -> None:
            # Each thread writes into its own slot; exceptions are collected
            # rather than raised so the main thread can report them all.
            try:
                results[idx] = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=30)
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=_worker, args=(i,)) for i in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=30)
        assert not errors, f"Threads raised: {errors}"
        for r in results:
            assert isinstance(r, types.SampleResponse)
            assert r.sequences[0].tokens == [1, 2, 3]

    def test_multithreaded_mixed_operations(self):
        """sample() and compute_logprobs() from different threads simultaneously."""
        proxy = _create_proxy(delay=0.01)
        errors: list[Exception] = []

        def _sample_worker() -> None:
            try:
                for _ in range(10):
                    r = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        def _logprobs_worker() -> None:
            try:
                for _ in range(10):
                    r = proxy.compute_logprobs(_PROMPT).result(timeout=10)
                    assert r == [0.1, 0.2, None]
            except Exception as e:
                errors.append(e)

        # Three threads of each kind, interleaving the two RPC types.
        threads = [threading.Thread(target=_sample_worker) for _ in range(3)]
        threads += [threading.Thread(target=_logprobs_worker) for _ in range(3)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=30)
        assert not errors, f"Errors: {errors}"

    def test_async_concurrent_samples(self):
        """Multiple async sample calls via asyncio.gather all resolve."""
        proxy = _create_proxy(delay=0.01)

        async def _run() -> list[types.SampleResponse]:
            from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture

            # Wrap the concurrent futures so they can be awaited together.
            coros = [
                AwaitableConcurrentFuture(proxy.sample(_PROMPT, 1, _PARAMS)) for _ in range(20)
            ]
            return await asyncio.gather(*coros)

        results = asyncio.run(_run())
        assert len(results) == 20
        for r in results:
            assert isinstance(r, types.SampleResponse)

    def test_cancelled_future_does_not_crash_collector(self):
        """Cancelling a future doesn't kill the collector thread."""
        proxy = _create_proxy(delay=0.5)
        future1 = proxy.sample(_PROMPT, 1, _PARAMS)
        future1.cancel()
        # A subsequent request must still resolve normally.
        result = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
        assert isinstance(result, types.SampleResponse)

    def test_multiple_clients_share_sidecar(self):
        """Two independent clients sharing the sidecar singleton work concurrently."""
        proxy1 = _create_proxy(delay=0.01)
        proxy2 = _create_proxy(delay=0.01)
        errors: list[Exception] = []

        def _worker1() -> None:
            try:
                for _ in range(10):
                    r = proxy1.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        def _worker2() -> None:
            try:
                for _ in range(10):
                    r = proxy2.compute_logprobs(_PROMPT).result(timeout=10)
                    assert r == [0.1, 0.2, None]
            except Exception as e:
                errors.append(e)

        t1 = threading.Thread(target=_worker1)
        t2 = threading.Thread(target=_worker2)
        t1.start()
        t2.start()
        t1.join(timeout=30)
        t2.join(timeout=30)
        assert not errors, f"Errors: {errors}"

    def test_pickle_roundtrip_then_concurrent_use(self):
        """Pickle a client, restore it, then use from multiple threads."""
        proxy = _create_proxy(delay=0.01)
        restored = pickle.loads(pickle.dumps(proxy))
        assert restored._sampling_client_sidecar_handle is not None
        errors: list[Exception] = []

        def _worker() -> None:
            try:
                for _ in range(10):
                    r = restored.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                    assert isinstance(r, types.SampleResponse)
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=_worker) for _ in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=30)
        assert not errors, f"Errors: {errors}"

View file

@ -0,0 +1,234 @@
from __future__ import annotations
from concurrent.futures import Future
import pytest
import torch
from tinker import types
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.public_interfaces.training_client import TrainingClient
class _DummyHolder:
def run_coroutine_threadsafe(self, coro):
future: Future = Future()
future.set_result(coro)
return future
def get_telemetry(self):
return None
class _FakeTrainingClient(TrainingClient):
    """TrainingClient double that records calls instead of hitting the API.

    ``forward_async`` resolves to a fixed 2x2 logprobs tensor and
    ``forward_backward_async`` resolves to empty outputs with one metric, so
    ``forward_backward_custom_async`` can be exercised without a server.
    """

    def __init__(self):
        # Deliberately does NOT call super().__init__: no API connection needed.
        self.holder = _DummyHolder()
        # Recorded (data, loss_fn, loss_fn_config) tuples, one per call.
        self.forward_calls: list[
            tuple[list[types.Datum], types.LossFnType, dict[str, float] | None]
        ] = []
        self.backward_calls: list[
            tuple[list[types.Datum], types.LossFnType, dict[str, float] | None]
        ] = []

    async def forward_async(
        self,
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        """Record the call and return an already-completed future with 2x2 logprobs."""
        self.forward_calls.append((data, loss_fn, loss_fn_config))
        result = types.ForwardBackwardOutput(
            metrics={},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[
                {
                    "logprobs": types.TensorData(
                        data=[-3.0, -2.0, -1.0, 0.0],
                        dtype="float32",
                        shape=[2, 2],
                    ),
                }
            ],
        )
        future: Future = Future()
        future.set_result(result)
        return AwaitableConcurrentFuture(future)

    async def forward_backward_async(
        self,
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        """Record the call and return an already-completed future with empty outputs."""
        self.backward_calls.append((data, loss_fn, loss_fn_config))
        result = types.ForwardBackwardOutput(
            metrics={"base:sum": 1.0},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[],
        )
        future: Future = Future()
        future.set_result(result)
        return AwaitableConcurrentFuture(future)
@pytest.mark.asyncio
async def test_forward_backward_custom_supports_2d_cross_entropy_targets():
    """2-D target_tokens flow through the custom-loss path with shapes preserved."""
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={
            "target_tokens": [[101, 102], [201, 202]],
        },
    )
    # Nested-list input is converted to a [2, 2] TensorData by Datum.
    assert datum.loss_fn_inputs["target_tokens"].shape == [2, 2]

    def custom_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data
        logprobs = logprobs_list[0]
        # The logprobs handed to the custom loss must be reshaped to 2-D.
        assert logprobs.shape == (2, 2)
        probs = torch.softmax(logprobs[1], dim=-1)
        target_distribution = torch.tensor([0.0, 1.0], dtype=torch.float32)
        loss = torch.sum((probs - target_distribution) ** 2)
        return loss, {"selected_prob:mean": float(probs[1].detach())}

    result_future = await client.forward_backward_custom_async(
        [datum],
        custom_loss,
        loss_type_input="logprobs",
    )
    result = await result_future.result_async()

    # Both surrogate passes must use cross_entropy ...
    assert client.forward_calls[0][1] == "cross_entropy"
    forward_datum = client.forward_calls[0][0][0]
    # ... and the auto-injected weights keep the 2-D target shape.
    assert forward_datum.loss_fn_inputs["weights"].shape == [2, 2]
    assert client.backward_calls[0][1] == "cross_entropy"
    backward_datum = client.backward_calls[0][0][0]
    assert backward_datum.loss_fn_inputs["target_tokens"].shape == [2, 2]
    assert backward_datum.loss_fn_inputs["weights"].shape == [2, 2]
    # The caller's datum must not be mutated with the injected weights.
    assert "weights" not in datum.loss_fn_inputs
    assert result.metrics["selected_prob:mean"] > 0.0
@pytest.mark.asyncio
async def test_forward_backward_custom_preserves_1d_cross_entropy_targets():
    """1-D target_tokens keep their shape, and weights equal -dC/dlogprobs."""
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={"target_tokens": [101, 102]},
    )

    async def forward_async_1d(
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        # Override the class-level fake (which returns 2x2) with a 1-D [2] response.
        client.forward_calls.append((data, loss_fn, loss_fn_config))
        result = types.ForwardBackwardOutput(
            metrics={},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[
                {
                    "logprobs": types.TensorData(
                        data=[-3.0, -1.0],
                        dtype="float32",
                        shape=[2],
                    ),
                }
            ],
        )
        future: Future = Future()
        future.set_result(result)
        return AwaitableConcurrentFuture(future)

    setattr(client, "forward_async", forward_async_1d)

    def custom_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data
        logprobs = logprobs_list[0]
        assert logprobs.shape == (2,)
        # C = -logprobs[-1], so dC/dlogprobs = [0, -1] and weights = [0, 1].
        loss = -logprobs[-1]
        return loss, {"selected_logprob:last": float(logprobs[-1].detach())}

    result_future = await client.forward_backward_custom_async(
        [datum],
        custom_loss,
        loss_type_input="logprobs",
    )
    result = await result_future.result_async()

    assert client.forward_calls[0][1] == "cross_entropy"
    forward_datum = client.forward_calls[0][0][0]
    assert forward_datum.loss_fn_inputs["weights"].shape == [2]
    assert client.backward_calls[0][1] == "cross_entropy"
    backward_datum = client.backward_calls[0][0][0]
    assert backward_datum.loss_fn_inputs["target_tokens"].shape == [2]
    assert backward_datum.loss_fn_inputs["weights"].shape == [2]
    # Weights sent to the backend must be the negated gradient of the custom loss.
    torch.testing.assert_close(
        torch.tensor(backward_datum.loss_fn_inputs["weights"].data).reshape(
            backward_datum.loss_fn_inputs["weights"].shape
        ),
        torch.tensor([0.0, 1.0], dtype=torch.float32),
    )
    assert result.metrics["selected_logprob:last"] < 0.0
@pytest.mark.asyncio
async def test_forward_backward_custom_rejects_unsupported_loss_type_input():
    """An unknown loss_type_input value raises ValueError."""
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={"target_tokens": [101, 102]},
    )

    def noop_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data, logprobs_list
        return torch.tensor(0.0, requires_grad=True), {}

    with pytest.raises(ValueError, match="Unsupported loss_type_input"):
        await client.forward_backward_custom_async(
            [datum],
            noop_loss,
            loss_type_input="logits",  # type: ignore[arg-type]
        )
def test_datum_rejects_ragged_nested_target_tokens():
    """Ragged nested target_tokens lists are rejected at Datum construction."""
    ragged_targets = [[101, 102], [201]]
    with pytest.raises(ValueError, match="ragged nested lists are not supported"):
        types.Datum(
            model_input=types.ModelInput.from_ints([1, 2]),
            loss_fn_inputs={"target_tokens": ragged_targets},
        )
@pytest.mark.asyncio
async def test_forward_backward_custom_rejects_unexpected_loss_fn_input_keys():
    """loss_fn_inputs keys beyond target_tokens/weights raise ValueError."""
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={
            "target_tokens": [101, 102],
            "advantages": [1.0, 1.0],
        },
    )

    def noop_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data, logprobs_list
        return torch.tensor(0.0, requires_grad=True), {}

    with pytest.raises(ValueError, match="only supports loss_fn_inputs keys"):
        await client.forward_backward_custom_async(
            [datum],
            noop_loss,
            loss_type_input="logprobs",
        )