mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
Sync contents
This commit is contained in:
parent
35a77e79fe
commit
f3c0b1f179
17 changed files with 2069 additions and 632 deletions
|
|
@ -1,4 +1,4 @@
|
|||
{
|
||||
"last_synced_sha": "b31ac99f31c26082a4076c3c79719034b2e0cab2",
|
||||
"last_sync_time": "2026-03-09T00:49:41.558594"
|
||||
"last_synced_sha": "2138c60c730c51b7bc19146a38320d631c38e0cc",
|
||||
"last_sync_time": "2026-03-19T00:10:49.917056"
|
||||
}
|
||||
|
|
@ -121,6 +121,38 @@ class InternalServerError(APIStatusError)
|
|||
|
||||
HTTP 500+: An error occurred on the server.
|
||||
|
||||
## `SidecarError` Objects
|
||||
|
||||
```python
|
||||
class SidecarError(TinkerError)
|
||||
```
|
||||
|
||||
Base exception for subprocess sidecar errors.
|
||||
|
||||
## `SidecarStartupError` Objects
|
||||
|
||||
```python
|
||||
class SidecarStartupError(SidecarError)
|
||||
```
|
||||
|
||||
Raised when the sidecar subprocess fails to start or times out.
|
||||
|
||||
## `SidecarDiedError` Objects
|
||||
|
||||
```python
|
||||
class SidecarDiedError(SidecarError)
|
||||
```
|
||||
|
||||
Raised when the sidecar subprocess exits unexpectedly while requests are pending.
|
||||
|
||||
## `SidecarIPCError` Objects
|
||||
|
||||
```python
|
||||
class SidecarIPCError(SidecarError)
|
||||
```
|
||||
|
||||
Raised when communication with the sidecar subprocess fails.
|
||||
|
||||
## `RequestFailedError` Objects
|
||||
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ Key methods:
|
|||
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
|
||||
- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
|
||||
- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private
|
||||
- set_checkpoint_ttl_from_tinker_path() - set or remove TTL on a checkpoint
|
||||
|
||||
Args:
|
||||
- `holder`: Internal client managing HTTP connections and async operations
|
||||
|
|
@ -39,7 +40,9 @@ for checkpoint in checkpoints.checkpoints:
|
|||
|
||||
```python
|
||||
def get_training_run(
|
||||
training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]
|
||||
training_run_id: types.ModelID,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.TrainingRun]
|
||||
```
|
||||
|
||||
Get training run info.
|
||||
|
|
@ -61,7 +64,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}
|
|||
|
||||
```python
|
||||
async def get_training_run_async(
|
||||
training_run_id: types.ModelID) -> types.TrainingRun
|
||||
training_run_id: types.ModelID,
|
||||
access_scope: Literal["owned",
|
||||
"accessible"] = "owned") -> types.TrainingRun
|
||||
```
|
||||
|
||||
Async version of get_training_run.
|
||||
|
|
@ -70,7 +75,9 @@ Async version of get_training_run.
|
|||
|
||||
```python
|
||||
def get_training_run_by_tinker_path(
|
||||
tinker_path: str) -> ConcurrentFuture[types.TrainingRun]
|
||||
tinker_path: str,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.TrainingRun]
|
||||
```
|
||||
|
||||
Get training run info.
|
||||
|
|
@ -92,7 +99,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}
|
|||
|
||||
```python
|
||||
async def get_training_run_by_tinker_path_async(
|
||||
tinker_path: str) -> types.TrainingRun
|
||||
tinker_path: str,
|
||||
access_scope: Literal["owned",
|
||||
"accessible"] = "owned") -> types.TrainingRun
|
||||
```
|
||||
|
||||
Async version of get_training_run_by_tinker_path.
|
||||
|
|
@ -123,8 +132,10 @@ print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}")
|
|||
|
||||
```python
|
||||
def list_training_runs(
|
||||
limit: int = 20,
|
||||
offset: int = 0) -> ConcurrentFuture[types.TrainingRunsResponse]
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.TrainingRunsResponse]
|
||||
```
|
||||
|
||||
List training runs with pagination support.
|
||||
|
|
@ -149,9 +160,11 @@ next_page = rest_client.list_training_runs(limit=50, offset=50)
|
|||
#### `list_training_runs_async`
|
||||
|
||||
```python
|
||||
async def list_training_runs_async(limit: int = 20,
|
||||
offset: int = 0
|
||||
) -> types.TrainingRunsResponse
|
||||
async def list_training_runs_async(
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> types.TrainingRunsResponse
|
||||
```
|
||||
|
||||
Async version of list_training_runs.
|
||||
|
|
@ -367,6 +380,46 @@ async def unpublish_checkpoint_from_tinker_path_async(
|
|||
|
||||
Async version of unpublish_checkpoint_from_tinker_path.
|
||||
|
||||
#### `set_checkpoint_ttl_from_tinker_path`
|
||||
|
||||
```python
|
||||
def set_checkpoint_ttl_from_tinker_path(
|
||||
tinker_path: str, ttl_seconds: int | None) -> ConcurrentFuture[None]
|
||||
```
|
||||
|
||||
Set or remove the TTL on a checkpoint referenced by a tinker path.
|
||||
|
||||
If ttl_seconds is provided, the checkpoint will expire after that many seconds from now.
|
||||
If ttl_seconds is None, any existing expiration will be removed.
|
||||
|
||||
Args:
|
||||
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
||||
- `ttl_seconds`: Number of seconds until expiration, or None to remove TTL
|
||||
|
||||
Returns:
|
||||
- A `Future` that completes when the TTL is set
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if checkpoint identifier is invalid or ttl_seconds <= 0
|
||||
HTTPException: 404 if checkpoint not found or user doesn't own the training run
|
||||
HTTPException: 500 if there's an error setting the TTL
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.set_checkpoint_ttl_from_tinker_path("tinker://run-id/weights/0001", 86400)
|
||||
future.result() # Wait for completion
|
||||
print("Checkpoint TTL set successfully")
|
||||
```
|
||||
|
||||
#### `set_checkpoint_ttl_from_tinker_path_async`
|
||||
|
||||
```python
|
||||
async def set_checkpoint_ttl_from_tinker_path_async(
|
||||
tinker_path: str, ttl_seconds: int | None) -> None
|
||||
```
|
||||
|
||||
Async version of set_checkpoint_ttl_from_tinker_path.
|
||||
|
||||
#### `list_user_checkpoints`
|
||||
|
||||
```python
|
||||
|
|
@ -414,7 +467,10 @@ Async version of list_user_checkpoints.
|
|||
#### `get_session`
|
||||
|
||||
```python
|
||||
def get_session(session_id: str) -> ConcurrentFuture[types.GetSessionResponse]
|
||||
def get_session(
|
||||
session_id: str,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.GetSessionResponse]
|
||||
```
|
||||
|
||||
Get session information including all training runs and samplers.
|
||||
|
|
@ -436,7 +492,10 @@ print(f"Samplers: {len(response.sampler_ids)}")
|
|||
#### `get_session_async`
|
||||
|
||||
```python
|
||||
async def get_session_async(session_id: str) -> types.GetSessionResponse
|
||||
async def get_session_async(
|
||||
session_id: str,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> types.GetSessionResponse
|
||||
```
|
||||
|
||||
Async version of get_session.
|
||||
|
|
@ -445,8 +504,10 @@ Async version of get_session.
|
|||
|
||||
```python
|
||||
def list_sessions(
|
||||
limit: int = 20,
|
||||
offset: int = 0) -> ConcurrentFuture[types.ListSessionsResponse]
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.ListSessionsResponse]
|
||||
```
|
||||
|
||||
List sessions with pagination support.
|
||||
|
|
@ -470,8 +531,11 @@ next_page = rest_client.list_sessions(limit=50, offset=50)
|
|||
#### `list_sessions_async`
|
||||
|
||||
```python
|
||||
async def list_sessions_async(limit: int = 20,
|
||||
offset: int = 0) -> types.ListSessionsResponse
|
||||
async def list_sessions_async(
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> types.ListSessionsResponse
|
||||
```
|
||||
|
||||
Async version of list_sessions.
|
||||
|
|
|
|||
|
|
@ -30,6 +30,20 @@ future = sampling_client.sample(prompt=prompt, sampling_params=params, num_sampl
|
|||
result = future.result()
|
||||
```
|
||||
|
||||
Multi-processing support:
|
||||
This class is picklable, so it can be passed to a separate process/worker to sample. It is also
|
||||
safe to pass the same instance of SamplingClient to multiple processes/workers.
|
||||
|
||||
If you are using Tinker SDK with more than one process you should always create SamplingClient from
|
||||
the main process and then pass it to the other processes/workers.
|
||||
ServiceClient and TrainingClient should always be managed from the main process.
|
||||
|
||||
Subprocess isolation:
|
||||
Set ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated
|
||||
subprocess, preventing GIL contention from CPU-heavy user code (grading, environment
|
||||
interactions) from stalling networking IO and heartbeats. This is transparent — the same
|
||||
API works with or without it.
|
||||
|
||||
#### `sample`
|
||||
|
||||
```python
|
||||
|
|
@ -121,3 +135,33 @@ Get the tokenizer for the current model.
|
|||
|
||||
Returns:
|
||||
- `PreTrainedTokenizer` compatible with the model
|
||||
|
||||
#### `get_base_model`
|
||||
|
||||
```python
|
||||
def get_base_model() -> str
|
||||
```
|
||||
|
||||
Get the base model name for the current sampling session.
|
||||
|
||||
#### `get_base_model_async`
|
||||
|
||||
```python
|
||||
async def get_base_model_async() -> str
|
||||
```
|
||||
|
||||
Async version of get_base_model.
|
||||
|
||||
#### `__reduce__`
|
||||
|
||||
```python
|
||||
def __reduce__() -> tuple[Any, tuple[_SamplingClientPickleState]]
|
||||
```
|
||||
|
||||
Enable pickling of SamplingClient for subprocess use.
|
||||
|
||||
Serializes into a ``_SamplingClientPickleState`` dataclass. The
|
||||
``_sampling_client_sidecar_handle`` handle is deliberately omitted — only a
|
||||
bool flag is stored. The unpickled copy creates its own handle via
|
||||
the per-process sidecar singleton. Do not add ``__getstate__``
|
||||
without preserving this behavior.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ TrainingClient for Tinker API.
|
|||
## `TrainingClient` Objects
|
||||
|
||||
```python
|
||||
class TrainingClient(TelemetryProvider, QueueStateObserver)
|
||||
class TrainingClient(TelemetryProvider)
|
||||
```
|
||||
|
||||
Client for training ML models with forward/backward passes and optimization.
|
||||
|
|
@ -127,8 +127,11 @@ Async version of forward_backward.
|
|||
|
||||
```python
|
||||
def forward_backward_custom(
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1,
|
||||
*,
|
||||
loss_type_input: Literal["logprobs"] = "logprobs"
|
||||
) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Compute forward/backward with a custom loss function.
|
||||
|
|
@ -139,6 +142,7 @@ The custom function receives logprobs and computes loss and gradients.
|
|||
Args:
|
||||
- `data`: List of training data samples
|
||||
- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
|
||||
- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`.
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the forward/backward outputs with custom loss
|
||||
|
|
@ -161,8 +165,11 @@ print(f"Metrics: {result.metrics}")
|
|||
|
||||
```python
|
||||
async def forward_backward_custom_async(
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1,
|
||||
*,
|
||||
loss_type_input: Literal["logprobs"] = "logprobs"
|
||||
) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Async version of forward_backward_custom.
|
||||
|
|
|
|||
1130
docs/api/types.md
1130
docs/api/types.md
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "tinker"
|
||||
version = "0.15.0"
|
||||
version = "0.16.0"
|
||||
description = "The official Python SDK for the tinker API"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
import ast
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
|
@ -95,8 +96,42 @@ class DocumentationGenerator:
|
|||
def run_pydoc_markdown(self, modules: List[str], output_file: Path) -> bool:
|
||||
"""Run pydoc-markdown for specific modules."""
|
||||
try:
|
||||
# Build the command
|
||||
cmd = ["pydoc-markdown", "pydoc-markdown.yml", "-I", "src"]
|
||||
# Invoke pydoc-markdown via uvx/uv tool run so the required
|
||||
# doc-generation dependencies are resolved consistently.
|
||||
uvx_path = shutil.which("uvx")
|
||||
if uvx_path:
|
||||
cmd = [
|
||||
uvx_path,
|
||||
"--from",
|
||||
"pydoc-markdown>=4.8.0",
|
||||
"--with",
|
||||
"pyyaml>=6.0",
|
||||
"--with",
|
||||
"setuptools",
|
||||
"pydoc-markdown",
|
||||
"pydoc-markdown.yml",
|
||||
"-I",
|
||||
"src",
|
||||
]
|
||||
else:
|
||||
uv_path = shutil.which("uv")
|
||||
if uv_path is None:
|
||||
raise FileNotFoundError("Could not find `uvx` or `uv` on PATH")
|
||||
cmd = [
|
||||
uv_path,
|
||||
"tool",
|
||||
"run",
|
||||
"--from",
|
||||
"pydoc-markdown>=4.8.0",
|
||||
"--with",
|
||||
"pyyaml>=6.0",
|
||||
"--with",
|
||||
"setuptools",
|
||||
"pydoc-markdown",
|
||||
"pydoc-markdown.yml",
|
||||
"-I",
|
||||
"src",
|
||||
]
|
||||
|
||||
# Add modules
|
||||
for module in modules:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,10 @@ from ._exceptions import (
|
|||
PermissionDeniedError,
|
||||
RateLimitError,
|
||||
RequestFailedError,
|
||||
SidecarDiedError,
|
||||
SidecarError,
|
||||
SidecarIPCError,
|
||||
SidecarStartupError,
|
||||
TinkerError,
|
||||
UnprocessableEntityError,
|
||||
)
|
||||
|
|
@ -97,6 +101,10 @@ __all__ = [
|
|||
"UnprocessableEntityError",
|
||||
"RateLimitError",
|
||||
"InternalServerError",
|
||||
"SidecarError",
|
||||
"SidecarStartupError",
|
||||
"SidecarDiedError",
|
||||
"SidecarIPCError",
|
||||
# Keep types module for advanced use
|
||||
"types",
|
||||
# Version info
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from typing import TYPE_CHECKING, Any, Dict, List
|
|||
import click
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
|
||||
from tinker.lib.public_interfaces.rest_client import RestClient
|
||||
from tinker.types import Checkpoint, TrainingRun
|
||||
|
||||
|
|
@ -838,53 +840,249 @@ def set_ttl(cli_context: CLIContext, checkpoint_path: str, ttl: int | None, remo
|
|||
client.set_checkpoint_ttl_from_tinker_path(checkpoint_path, ttl_seconds).result()
|
||||
|
||||
|
||||
def _parse_date(value: str) -> "datetime":
|
||||
"""Parse an ISO 8601 date or datetime string to a timezone-aware datetime.
|
||||
|
||||
Accepts: 2024-01-01, 2024-01-01T12:00:00, 2024-01-01T12:00:00Z,
|
||||
2024-01-01T12:00:00+00:00. Date-only values are treated as midnight UTC.
|
||||
"""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
value = value.strip()
|
||||
# Python < 3.11 doesn't handle trailing 'Z' in fromisoformat
|
||||
if value.endswith("Z"):
|
||||
value = value[:-1] + "+00:00"
|
||||
try:
|
||||
dt = datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
raise TinkerCliError(
|
||||
f"Invalid date: {value}",
|
||||
"Use ISO 8601 format: 2024-01-01, 2024-01-01T12:00:00Z",
|
||||
)
|
||||
# Assume UTC if no timezone provided
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return dt
|
||||
|
||||
|
||||
_CHECKPOINT_TYPE_MAP = {
|
||||
"weights": "training",
|
||||
"sampler_weights": "sampler",
|
||||
"training": "training",
|
||||
"sampler": "sampler",
|
||||
}
|
||||
|
||||
|
||||
def _filter_checkpoints(
|
||||
checkpoints: "List[Checkpoint]",
|
||||
checkpoint_type: str | None,
|
||||
before: "datetime | None",
|
||||
after: "datetime | None",
|
||||
) -> "List[Checkpoint]":
|
||||
"""Filter checkpoints by type, before date, and/or after date."""
|
||||
filtered = checkpoints
|
||||
if checkpoint_type:
|
||||
mapped_type = _CHECKPOINT_TYPE_MAP.get(checkpoint_type)
|
||||
if mapped_type is None:
|
||||
raise TinkerCliError(
|
||||
f"Invalid checkpoint type: {checkpoint_type}",
|
||||
"Valid types: weights, sampler_weights",
|
||||
)
|
||||
filtered = [c for c in filtered if c.checkpoint_type == mapped_type]
|
||||
if before is not None:
|
||||
filtered = [c for c in filtered if c.time < before]
|
||||
if after is not None:
|
||||
filtered = [c for c in filtered if c.time > after]
|
||||
return filtered
|
||||
|
||||
|
||||
def _confirm_deletion(paths: "List[str]", checkpoints: "List[Checkpoint] | None" = None) -> bool:
|
||||
"""Show deletion summary and prompt for confirmation. Returns True if confirmed."""
|
||||
count = len(paths)
|
||||
if checkpoints is not None:
|
||||
total_size = sum(c.size_bytes or 0 for c in checkpoints)
|
||||
click.echo(f"Will delete {count} checkpoint(s):")
|
||||
for ckpt in checkpoints:
|
||||
size_str = format_size(ckpt.size_bytes) if ckpt.size_bytes is not None else "N/A"
|
||||
time_str = format_timestamp(ckpt.time)
|
||||
click.echo(f" - {ckpt.tinker_path} ({size_str}, created {time_str})")
|
||||
click.echo(f"\nTotal size: {format_size(total_size)}")
|
||||
else:
|
||||
click.echo(f"Will delete {count} checkpoint(s):")
|
||||
for path in paths:
|
||||
click.echo(f" - {path}")
|
||||
click.echo()
|
||||
click.echo("WARNING: This action is permanent and cannot be undone.")
|
||||
return click.confirm(f"Are you sure you want to delete {count} checkpoint(s)?")
|
||||
|
||||
|
||||
_DELETE_CONCURRENCY = 32
|
||||
|
||||
|
||||
def _delete_one(client: "RestClient", path: str) -> "tuple[str, str] | None":
|
||||
"""Delete a single checkpoint. Returns (path, error) on failure, None on success."""
|
||||
try:
|
||||
client.delete_checkpoint_from_tinker_path(path).result()
|
||||
return None
|
||||
except Exception as e:
|
||||
return (path, str(e))
|
||||
|
||||
|
||||
def _delete_paths(
|
||||
client: "RestClient",
|
||||
paths: "List[str]",
|
||||
format: str,
|
||||
) -> None:
|
||||
"""Delete a list of tinker paths concurrently and print results."""
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
deleted_count = 0
|
||||
failed: "List[tuple[str, str]]" = []
|
||||
with (
|
||||
ThreadPoolExecutor(max_workers=_DELETE_CONCURRENCY) as pool,
|
||||
click.progressbar(
|
||||
length=len(paths),
|
||||
label="Deleting checkpoints",
|
||||
show_percent=True,
|
||||
show_pos=True,
|
||||
hidden=format != "table",
|
||||
) as bar,
|
||||
):
|
||||
futures = {pool.submit(_delete_one, client, p): p for p in paths}
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
if result is None:
|
||||
deleted_count += 1
|
||||
else:
|
||||
failed.append(result)
|
||||
bar.update(1)
|
||||
|
||||
if format == "json":
|
||||
import json
|
||||
|
||||
click.echo(
|
||||
json.dumps(
|
||||
{
|
||||
"deleted_count": deleted_count,
|
||||
"failed": [{"tinker_path": p, "error": e} for p, e in failed],
|
||||
}
|
||||
)
|
||||
)
|
||||
else:
|
||||
click.echo(f"Deleted {deleted_count} checkpoint(s)")
|
||||
if failed:
|
||||
click.echo(f"Failed to delete {len(failed)} checkpoint(s):")
|
||||
for path, error in failed:
|
||||
click.echo(f" - {path}: {error}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("checkpoint_paths", nargs=-1, required=True)
|
||||
@click.argument("checkpoint_paths", nargs=-1, required=False)
|
||||
@click.option("--run-id", default=None, help="Delete all checkpoints for a training run")
|
||||
@click.option(
|
||||
"--type", "checkpoint_type", default=None, help="Filter by type: weights or sampler_weights"
|
||||
)
|
||||
@click.option(
|
||||
"--before",
|
||||
default=None,
|
||||
help="Filter: created before date in UTC (ISO 8601, e.g. 2024-01-01, 2024-01-01T08:00:00Z)",
|
||||
)
|
||||
@click.option(
|
||||
"--after",
|
||||
default=None,
|
||||
help="Filter: created after date in UTC (ISO 8601, e.g. 2024-01-01, 2024-01-01T08:00:00Z)",
|
||||
)
|
||||
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
|
||||
@click.pass_obj
|
||||
@handle_api_errors
|
||||
def delete(cli_context: CLIContext, checkpoint_paths: tuple[str, ...], yes: bool) -> None:
|
||||
def delete(
|
||||
cli_context: CLIContext,
|
||||
checkpoint_paths: tuple[str, ...],
|
||||
run_id: str | None,
|
||||
checkpoint_type: str | None,
|
||||
before: str | None,
|
||||
after: str | None,
|
||||
yes: bool,
|
||||
) -> None:
|
||||
"""Delete one or more checkpoints permanently.
|
||||
|
||||
CHECKPOINT_PATHS must be tinker paths (e.g., tinker://run-id/weights/0001).
|
||||
Delete by explicit paths:
|
||||
|
||||
tinker checkpoint delete tinker://run-id/weights/0001 tinker://run-id/weights/0002
|
||||
|
||||
Delete all checkpoints for a training run:
|
||||
|
||||
tinker checkpoint delete --run-id <run-id>
|
||||
|
||||
Delete with filters:
|
||||
|
||||
tinker checkpoint delete --run-id <run-id> --type weights --before 2024-06-01
|
||||
|
||||
Delete checkpoints in a date range:
|
||||
|
||||
tinker checkpoint delete --run-id <run-id> --after 2024-01-01 --before 2024-02-01
|
||||
|
||||
Dates are interpreted as UTC. Use full ISO 8601 datetime for precision:
|
||||
|
||||
tinker checkpoint delete --run-id <run-id> --before 2024-06-01T08:00:00Z
|
||||
|
||||
Only the owner of the training run can delete checkpoints.
|
||||
|
||||
WARNING: This action is permanent and cannot be undone.
|
||||
"""
|
||||
# Validate all paths upfront
|
||||
for path in checkpoint_paths:
|
||||
if not path.startswith("tinker://"):
|
||||
raise TinkerCliError(
|
||||
f"Invalid checkpoint path: {path}",
|
||||
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
|
||||
)
|
||||
if not checkpoint_paths and not run_id:
|
||||
raise TinkerCliError(
|
||||
"Must specify checkpoint paths or --run-id",
|
||||
"Examples:\n"
|
||||
" tinker checkpoint delete tinker://run-id/weights/0001\n"
|
||||
" tinker checkpoint delete --run-id <run-id>\n"
|
||||
" tinker checkpoint delete --run-id <run-id> --type weights --before 2024-06-01",
|
||||
)
|
||||
|
||||
# If not using --yes, show checkpoint list and prompt for confirmation
|
||||
if not yes:
|
||||
count = len(checkpoint_paths)
|
||||
click.echo(f"Will delete {count} checkpoint(s):")
|
||||
if checkpoint_paths and run_id:
|
||||
raise TinkerCliError(
|
||||
"Cannot specify both checkpoint paths and --run-id",
|
||||
"Use either explicit paths or --run-id with optional filters",
|
||||
)
|
||||
|
||||
has_filters = checkpoint_type or before or after
|
||||
if has_filters and not run_id:
|
||||
raise TinkerCliError(
|
||||
"--type, --before, and --after require --run-id",
|
||||
"Example: tinker checkpoint delete --run-id <run-id> --type weights --before 2024-06-01",
|
||||
)
|
||||
|
||||
client = create_rest_client()
|
||||
|
||||
if run_id:
|
||||
before_dt = _parse_date(before) if before else None
|
||||
after_dt = _parse_date(after) if after else None
|
||||
response = client.list_checkpoints(run_id).result()
|
||||
checkpoints = _filter_checkpoints(
|
||||
response.checkpoints, checkpoint_type, before_dt, after_dt
|
||||
)
|
||||
|
||||
if not checkpoints:
|
||||
click.echo(f"No checkpoints found for run {run_id} matching filters")
|
||||
return
|
||||
|
||||
paths_to_delete = [c.tinker_path for c in checkpoints]
|
||||
if not yes and not _confirm_deletion(paths_to_delete, checkpoints):
|
||||
click.echo("Deletion cancelled.")
|
||||
return
|
||||
else:
|
||||
for path in checkpoint_paths:
|
||||
click.echo(f" - {path}")
|
||||
click.echo()
|
||||
|
||||
# Confirmation prompt
|
||||
click.echo("WARNING: This action is permanent and cannot be undone.")
|
||||
if not click.confirm(f"Are you sure you want to delete {count} checkpoint(s)?"):
|
||||
if not path.startswith("tinker://"):
|
||||
raise TinkerCliError(
|
||||
f"Invalid checkpoint path: {path}",
|
||||
"Checkpoint path must be in the format: tinker://run-id/weights/0001",
|
||||
)
|
||||
paths_to_delete = list(checkpoint_paths)
|
||||
if not yes and not _confirm_deletion(paths_to_delete):
|
||||
click.echo("Deletion cancelled.")
|
||||
return
|
||||
|
||||
# Create client and delete with progress bar
|
||||
client = create_rest_client()
|
||||
|
||||
with click.progressbar(
|
||||
checkpoint_paths,
|
||||
label="Deleting checkpoints",
|
||||
show_percent=True,
|
||||
show_pos=True,
|
||||
hidden=cli_context.format != "table",
|
||||
) as bar:
|
||||
for path in bar:
|
||||
client.delete_checkpoint_from_tinker_path(path).result()
|
||||
_delete_paths(client, paths_to_delete, cli_context.format)
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
|
@ -15,6 +16,7 @@ import tinker
|
|||
from tinker import types
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
|
||||
from tinker.lib.sidecar import SidecarHandle, SidecarRPC, create_sidecar_handle
|
||||
from tinker.lib.telemetry import Telemetry, capture_exceptions
|
||||
from tinker.lib.telemetry_provider import TelemetryProvider
|
||||
|
||||
|
|
@ -33,6 +35,56 @@ logger = logging.getLogger(__name__)
|
|||
U = TypeVar("U")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pickle serialization state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _SamplingClientPickleState:
|
||||
"""Serialized state for pickling SamplingClient across processes."""
|
||||
|
||||
session_id: str
|
||||
sampling_session_id: str
|
||||
constructor_kwargs: dict[str, Any]
|
||||
subprocess_sampling: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Typed RPCs for subprocess-isolated sampling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _SampleRPC(SidecarRPC):
|
||||
"""Typed RPC for SamplingClient.sample()."""
|
||||
|
||||
prompt: types.ModelInput
|
||||
num_samples: int
|
||||
sampling_params: types.SamplingParams
|
||||
include_prompt_logprobs: bool
|
||||
topk_prompt_logprobs: int
|
||||
|
||||
async def execute(self, target: Any) -> Any:
|
||||
return target.sample(
|
||||
prompt=self.prompt,
|
||||
num_samples=self.num_samples,
|
||||
sampling_params=self.sampling_params,
|
||||
include_prompt_logprobs=self.include_prompt_logprobs,
|
||||
topk_prompt_logprobs=self.topk_prompt_logprobs,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _ComputeLogprobsRPC(SidecarRPC):
|
||||
"""Typed RPC for SamplingClient.compute_logprobs()."""
|
||||
|
||||
prompt: types.ModelInput
|
||||
|
||||
async def execute(self, target: Any) -> Any:
|
||||
return target.compute_logprobs(prompt=self.prompt)
|
||||
|
||||
|
||||
class SamplingClient(TelemetryProvider, QueueStateObserver):
|
||||
"""Client for text generation and inference from trained or base models.
|
||||
|
||||
|
|
@ -65,6 +117,12 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
If you are using Tinker SDK with more than one process you should always create SamplingClient from
|
||||
the main process and then pass it to the other processes/workers.
|
||||
ServiceClient and TrainingClient should always be managed from the main process.
|
||||
|
||||
Subprocess isolation:
|
||||
Set ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated
|
||||
subprocess, preventing GIL contention from CPU-heavy user code (grading, environment
|
||||
interactions) from stalling networking IO and heartbeats. This is transparent — the same
|
||||
API works with or without it.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -74,6 +132,7 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
sampling_session_id: str,
|
||||
shadow: bool = False,
|
||||
retry_config: RetryConfig | None = None,
|
||||
subprocess_sampling: bool | None = None,
|
||||
):
|
||||
self.holder = holder
|
||||
|
||||
|
|
@ -97,6 +156,20 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
# We use 1B as the base and mod for uuid because the maximum int value is 2^63-1 and 1B*1B is less than 2^63-1.
|
||||
self._request_id_counter = 1_000_000_000 * (int(uuid.uuid4()) % 1_000_000_000 + 1)
|
||||
|
||||
# Subprocess isolation: read env var if not explicitly set
|
||||
if subprocess_sampling is None:
|
||||
subprocess_sampling = os.environ.get("TINKER_SUBPROCESS_SAMPLING", "").lower() in (
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
)
|
||||
self._sampling_client_sidecar_handle: SidecarHandle | None = None
|
||||
if subprocess_sampling:
|
||||
from tinker.lib.sidecar import _inside_sidecar
|
||||
|
||||
if not _inside_sidecar:
|
||||
self._sampling_client_sidecar_handle = create_sidecar_handle(self)
|
||||
|
||||
@staticmethod
|
||||
async def _create_impl(
|
||||
holder: InternalClientHolder,
|
||||
|
|
@ -237,6 +310,16 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
print(tokenizer.decode(sample.tokens))
|
||||
```
|
||||
"""
|
||||
if self._sampling_client_sidecar_handle is not None:
|
||||
return self._sampling_client_sidecar_handle.submit_rpc(
|
||||
_SampleRPC(
|
||||
prompt=prompt,
|
||||
num_samples=num_samples,
|
||||
sampling_params=sampling_params,
|
||||
include_prompt_logprobs=include_prompt_logprobs,
|
||||
topk_prompt_logprobs=topk_prompt_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
async def _sample_async():
|
||||
return await self._sample_async_impl(
|
||||
|
|
@ -294,6 +377,10 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
print(f"Token {i}: logprob = {logprob:.4f}")
|
||||
```
|
||||
"""
|
||||
if self._sampling_client_sidecar_handle is not None:
|
||||
return self._sampling_client_sidecar_handle.submit_rpc(
|
||||
_ComputeLogprobsRPC(prompt=prompt)
|
||||
)
|
||||
|
||||
async def _compute_logprobs_async() -> list[float | None]:
|
||||
sample_res = await self._sample_async_impl(
|
||||
|
|
@ -349,18 +436,24 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
def get_telemetry(self) -> Telemetry | None:
|
||||
return self.holder.get_telemetry()
|
||||
|
||||
def __reduce__(self) -> tuple[Any, tuple[str, str, dict[str, Any]]]:
|
||||
def __reduce__(self) -> tuple[Any, tuple[_SamplingClientPickleState]]:
|
||||
"""Enable pickling of SamplingClient for subprocess use.
|
||||
|
||||
Stores the sampling_session_id and holder constructor kwargs.
|
||||
On unpickle, creates a shadow holder and reconstructs the client.
|
||||
Serializes into a ``_SamplingClientPickleState`` dataclass. The
|
||||
``_sampling_client_sidecar_handle`` handle is deliberately omitted — only a
|
||||
bool flag is stored. The unpickled copy creates its own handle via
|
||||
the per-process sidecar singleton. Do not add ``__getstate__``
|
||||
without preserving this behavior.
|
||||
"""
|
||||
return (
|
||||
_unpickle_sampling_client,
|
||||
(
|
||||
self.holder.get_session_id(),
|
||||
self._sampling_session_id,
|
||||
self.holder._constructor_kwargs,
|
||||
_SamplingClientPickleState(
|
||||
session_id=self.holder.get_session_id(),
|
||||
sampling_session_id=self._sampling_session_id,
|
||||
constructor_kwargs=self.holder._constructor_kwargs,
|
||||
subprocess_sampling=self._sampling_client_sidecar_handle is not None,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -386,21 +479,21 @@ class SamplingClient(TelemetryProvider, QueueStateObserver):
|
|||
)
|
||||
|
||||
|
||||
def _unpickle_sampling_client(
|
||||
session_id: str,
|
||||
sampling_session_id: str,
|
||||
constructor_kwargs: dict[str, Any],
|
||||
) -> SamplingClient:
|
||||
"""Reconstruct a SamplingClient from pickled data.
|
||||
def _unpickle_sampling_client(state: _SamplingClientPickleState) -> SamplingClient:
|
||||
"""Reconstruct a SamplingClient from pickled state.
|
||||
|
||||
Creates a shadow InternalClientHolder and builds a new SamplingClient.
|
||||
The request_id_counter starts at a random high value to avoid collisions.
|
||||
Subprocess enablement is handled by the constructor.
|
||||
"""
|
||||
from ..internal_client_holder import InternalClientHolder
|
||||
|
||||
holder = InternalClientHolder.get_shadow_holder(session_id, constructor_kwargs)
|
||||
client = SamplingClient(holder, sampling_session_id=sampling_session_id, shadow=True)
|
||||
return client
|
||||
holder = InternalClientHolder.get_shadow_holder(state.session_id, state.constructor_kwargs)
|
||||
return SamplingClient(
|
||||
holder,
|
||||
sampling_session_id=state.sampling_session_id,
|
||||
shadow=True,
|
||||
subprocess_sampling=state.subprocess_sampling,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import logging
|
|||
import threading
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Tuple
|
||||
|
||||
from tinker import types
|
||||
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
|
||||
|
|
@ -49,6 +49,11 @@ MODEL_ID_NOT_SET_ERROR = "model_id must be set before calling forward. Try initi
|
|||
# Args: (data: List[Datum], model_outputs: List[Any]) -> (loss: Any, metrics: Dict[str, float])
|
||||
CustomLossFnV1 = Callable[[List[types.Datum], List[Any]], Tuple[Any, Dict[str, float]]]
|
||||
|
||||
_SUPPORTED_CUSTOM_BACKEND_LOSS_FNS = frozenset({"cross_entropy"})
|
||||
_CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE: dict[Literal["logprobs"], types.LossFnType] = {
|
||||
"logprobs": "cross_entropy",
|
||||
}
|
||||
|
||||
|
||||
class TrainingClient(TelemetryProvider):
|
||||
"""Client for training ML models with forward/backward passes and optimization.
|
||||
|
|
@ -331,7 +336,11 @@ class TrainingClient(TelemetryProvider):
|
|||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward_backward_custom(
|
||||
self, data: List[types.Datum], loss_fn: CustomLossFnV1
|
||||
self,
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1,
|
||||
*,
|
||||
loss_type_input: Literal["logprobs"] = "logprobs",
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
"""Compute forward/backward with a custom loss function.
|
||||
|
||||
|
|
@ -341,6 +350,7 @@ class TrainingClient(TelemetryProvider):
|
|||
Args:
|
||||
- `data`: List of training data samples
|
||||
- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
|
||||
- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`.
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the forward/backward outputs with custom loss
|
||||
|
|
@ -360,23 +370,49 @@ class TrainingClient(TelemetryProvider):
|
|||
```
|
||||
"""
|
||||
return self.holder.run_coroutine_threadsafe(
|
||||
self.forward_backward_custom_async(data, loss_fn)
|
||||
self.forward_backward_custom_async(
|
||||
data,
|
||||
loss_fn,
|
||||
loss_type_input=loss_type_input,
|
||||
)
|
||||
).result()
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
async def forward_backward_custom_async(
|
||||
self, data: List[types.Datum], loss_fn: CustomLossFnV1
|
||||
self,
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1,
|
||||
*,
|
||||
loss_type_input: Literal["logprobs"] = "logprobs",
|
||||
) -> APIFuture[types.ForwardBackwardOutput]:
|
||||
"""Async version of forward_backward_custom."""
|
||||
if torch is None:
|
||||
raise ImportError("PyTorch is not installed. Cannot run custom forward_backward.")
|
||||
|
||||
if loss_type_input not in _CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE:
|
||||
supported = ", ".join(sorted(_CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE))
|
||||
raise ValueError(
|
||||
f"Unsupported loss_type_input={loss_type_input!r}. "
|
||||
f"Supported values are: {supported}"
|
||||
)
|
||||
|
||||
surrogate_loss_fn = _CUSTOM_BACKEND_LOSS_FN_BY_INPUT_TYPE[loss_type_input]
|
||||
|
||||
forward_data = self._get_custom_loss_forward_data(data, surrogate_loss_fn)
|
||||
|
||||
# First do a forward pass and get logprobs
|
||||
forward_future = await self.forward_async(data, "cross_entropy")
|
||||
forward_future = await self.forward_async(
|
||||
forward_data,
|
||||
surrogate_loss_fn,
|
||||
None,
|
||||
)
|
||||
forward_result = await forward_future.result_async()
|
||||
logprobs_list = []
|
||||
for out in forward_result.loss_fn_outputs:
|
||||
logprob = torch.tensor(out["logprobs"].data).clone().detach().requires_grad_(True)
|
||||
logprob = torch.tensor(out["logprobs"].data)
|
||||
if out["logprobs"].shape is not None:
|
||||
logprob = logprob.reshape(out["logprobs"].shape)
|
||||
logprob = logprob.clone().detach().requires_grad_(True)
|
||||
logprobs_list.append(logprob)
|
||||
|
||||
# Now apply user-provided function
|
||||
|
|
@ -392,7 +428,9 @@ class TrainingClient(TelemetryProvider):
|
|||
for datum, grad in zip(data, grads, strict=True):
|
||||
loss_fn_inputs: Any = {
|
||||
"target_tokens": datum.loss_fn_inputs["target_tokens"],
|
||||
"weights": -grad, # Pass PyTorch tensor directly (will be converted to TensorData)
|
||||
# Backend CE is L = sum(-logprobs * weights), so to backpropagate a
|
||||
# client-side custom loss C(logprobs) we must send weights = -dC/dlogprobs.
|
||||
"weights": -grad,
|
||||
}
|
||||
linear_loss_data.append(
|
||||
types.Datum(
|
||||
|
|
@ -402,7 +440,11 @@ class TrainingClient(TelemetryProvider):
|
|||
)
|
||||
|
||||
# Do the backward pass with the gradients
|
||||
backward_future = await self.forward_backward_async(linear_loss_data, "cross_entropy")
|
||||
backward_future = await self.forward_backward_async(
|
||||
linear_loss_data,
|
||||
surrogate_loss_fn,
|
||||
None,
|
||||
)
|
||||
|
||||
# We need to slightly modify the future to add the custom metrics, so we use _CombinedAPIFuture
|
||||
# to transform the future.
|
||||
|
|
@ -415,6 +457,49 @@ class TrainingClient(TelemetryProvider):
|
|||
|
||||
return _CombinedAPIFuture([backward_future], add_custom_metrics, self.holder)
|
||||
|
||||
def _get_custom_loss_forward_data(
|
||||
self,
|
||||
data: List[types.Datum],
|
||||
surrogate_loss_fn: types.LossFnType,
|
||||
) -> List[types.Datum]:
|
||||
assert surrogate_loss_fn in _SUPPORTED_CUSTOM_BACKEND_LOSS_FNS, (
|
||||
"forward_backward_custom_async should validate surrogate_loss_fn before "
|
||||
"_get_custom_loss_forward_data is called"
|
||||
)
|
||||
|
||||
forward_data = []
|
||||
for datum in data:
|
||||
target_tokens = datum.loss_fn_inputs.get("target_tokens")
|
||||
if target_tokens is None:
|
||||
raise ValueError("target_tokens must be provided when using cross_entropy")
|
||||
|
||||
unexpected_keys = sorted(set(datum.loss_fn_inputs) - {"target_tokens", "weights"})
|
||||
if unexpected_keys:
|
||||
raise ValueError(
|
||||
"forward_backward_custom only supports loss_fn_inputs keys "
|
||||
"{'target_tokens', 'weights'}; "
|
||||
f"found unexpected keys: {unexpected_keys}"
|
||||
)
|
||||
|
||||
if "weights" in datum.loss_fn_inputs:
|
||||
forward_data.append(datum)
|
||||
continue
|
||||
|
||||
forward_loss_fn_inputs = dict(datum.loss_fn_inputs)
|
||||
forward_loss_fn_inputs["weights"] = types.TensorData(
|
||||
data=[0.0] * len(target_tokens.data),
|
||||
dtype="float32",
|
||||
shape=target_tokens.shape,
|
||||
)
|
||||
forward_data.append(
|
||||
types.Datum(
|
||||
model_input=datum.model_input,
|
||||
loss_fn_inputs=forward_loss_fn_inputs,
|
||||
)
|
||||
)
|
||||
|
||||
return forward_data
|
||||
|
||||
@capture_exceptions(fatal=True)
|
||||
def optim_step(self, adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]:
|
||||
"""Update model parameters using Adam optimizer.
|
||||
|
|
|
|||
|
|
@ -44,13 +44,30 @@ class Datum(StrictBase):
|
|||
|
||||
@classmethod
|
||||
def _maybe_convert_array(cls, key: str, value: Any) -> Any:
|
||||
"""Convert torch.Tensor, numpy array, or 1-D list to TensorData if needed."""
|
||||
"""Convert torch.Tensor, numpy array, or numeric lists to TensorData if needed."""
|
||||
if _HAVE_TORCH and isinstance(value, torch.Tensor):
|
||||
return TensorData.from_torch(value)
|
||||
elif isinstance(value, np.ndarray):
|
||||
return TensorData.from_numpy(value)
|
||||
elif isinstance(value, list):
|
||||
# assume it's 1d and infer the dtype from the key
|
||||
try:
|
||||
array = np.asarray(value)
|
||||
except ValueError as exc:
|
||||
if any(isinstance(item, list) for item in value):
|
||||
raise ValueError(
|
||||
f"{key} must be a rectangular numeric array; ragged nested lists are not supported"
|
||||
) from exc
|
||||
raise
|
||||
if array.dtype.kind in ("f", "i", "u"):
|
||||
if _key_to_type[key] == "int64":
|
||||
array = array.astype(np.int64)
|
||||
else:
|
||||
array = array.astype(np.float32)
|
||||
return TensorData.from_numpy(array)
|
||||
if any(isinstance(item, list) for item in value):
|
||||
raise ValueError(
|
||||
f"{key} must be a rectangular numeric array; ragged nested lists are not supported"
|
||||
)
|
||||
return TensorData(data=value, dtype=_key_to_type[key], shape=[len(value)])
|
||||
else:
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -2,4 +2,10 @@ from typing_extensions import Literal, TypeAlias
|
|||
|
||||
__all__ = ["LossFnType"]
|
||||
|
||||
LossFnType: TypeAlias = Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"]
|
||||
LossFnType: TypeAlias = Literal[
|
||||
"cross_entropy",
|
||||
"importance_sampling",
|
||||
"ppo",
|
||||
"cispo",
|
||||
"dro",
|
||||
]
|
||||
|
|
|
|||
196
tests/test_checkpoint_delete.py
Normal file
196
tests/test_checkpoint_delete.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
"""Tests for bulk checkpoint deletion: CLI flags and date parsing."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from tinker.cli.commands.checkpoint import _filter_checkpoints, _parse_date
|
||||
|
||||
|
||||
class TestParseDate:
|
||||
"""Tests for the _parse_date ISO 8601 parser."""
|
||||
|
||||
def test_date_only(self) -> None:
|
||||
dt = _parse_date("2024-01-15")
|
||||
assert dt.year == 2024
|
||||
assert dt.month == 1
|
||||
assert dt.day == 15
|
||||
assert dt.tzinfo is not None
|
||||
|
||||
def test_datetime_with_z(self) -> None:
|
||||
dt = _parse_date("2024-06-01T12:00:00Z")
|
||||
assert dt.year == 2024
|
||||
assert dt.month == 6
|
||||
assert dt.hour == 12
|
||||
assert dt.tzinfo is not None
|
||||
|
||||
def test_datetime_with_offset(self) -> None:
|
||||
dt = _parse_date("2024-06-01T12:00:00+00:00")
|
||||
assert dt.year == 2024
|
||||
assert dt.tzinfo is not None
|
||||
|
||||
def test_datetime_naive_gets_utc(self) -> None:
|
||||
dt = _parse_date("2024-06-01T12:00:00")
|
||||
assert dt.tzinfo is not None
|
||||
|
||||
def test_whitespace_stripped(self) -> None:
|
||||
dt = _parse_date(" 2024-01-01 ")
|
||||
assert dt.year == 2024
|
||||
|
||||
def test_invalid_raises(self) -> None:
|
||||
from tinker.cli.exceptions import TinkerCliError
|
||||
|
||||
with pytest.raises(TinkerCliError):
|
||||
_parse_date("not-a-date")
|
||||
|
||||
def test_invalid_format_raises(self) -> None:
|
||||
from tinker.cli.exceptions import TinkerCliError
|
||||
|
||||
with pytest.raises(TinkerCliError):
|
||||
_parse_date("01/15/2024")
|
||||
|
||||
|
||||
class TestFilterCheckpoints:
|
||||
"""Tests for the _filter_checkpoints function."""
|
||||
|
||||
@pytest.fixture()
|
||||
def sample_checkpoints(self):
|
||||
from tinker.types.checkpoint import Checkpoint
|
||||
|
||||
now = datetime.now(UTC)
|
||||
return [
|
||||
Checkpoint(
|
||||
checkpoint_id="weights/0001",
|
||||
checkpoint_type="training",
|
||||
time=now - timedelta(days=10),
|
||||
tinker_path="tinker://run-1/weights/0001",
|
||||
size_bytes=1000,
|
||||
),
|
||||
Checkpoint(
|
||||
checkpoint_id="weights/0002",
|
||||
checkpoint_type="training",
|
||||
time=now - timedelta(days=3),
|
||||
tinker_path="tinker://run-1/weights/0002",
|
||||
size_bytes=2000,
|
||||
),
|
||||
Checkpoint(
|
||||
checkpoint_id="sampler_weights/0001",
|
||||
checkpoint_type="sampler",
|
||||
time=now - timedelta(days=10),
|
||||
tinker_path="tinker://run-1/sampler_weights/0001",
|
||||
size_bytes=500,
|
||||
),
|
||||
Checkpoint(
|
||||
checkpoint_id="weights/0003",
|
||||
checkpoint_type="training",
|
||||
time=now - timedelta(hours=1),
|
||||
tinker_path="tinker://run-1/weights/0003",
|
||||
size_bytes=3000,
|
||||
),
|
||||
]
|
||||
|
||||
def test_no_filters(self, sample_checkpoints) -> None:
|
||||
result = _filter_checkpoints(sample_checkpoints, None, None, None)
|
||||
assert len(result) == 4
|
||||
|
||||
def test_filter_by_weights_type(self, sample_checkpoints) -> None:
|
||||
result = _filter_checkpoints(sample_checkpoints, "weights", None, None)
|
||||
assert len(result) == 3
|
||||
assert all(c.checkpoint_type == "training" for c in result)
|
||||
|
||||
def test_filter_by_sampler_weights_type(self, sample_checkpoints) -> None:
|
||||
result = _filter_checkpoints(sample_checkpoints, "sampler_weights", None, None)
|
||||
assert len(result) == 1
|
||||
assert result[0].checkpoint_type == "sampler"
|
||||
|
||||
def test_filter_before(self, sample_checkpoints) -> None:
|
||||
# Before 7 days ago → only the 10-day-old checkpoints
|
||||
cutoff = datetime.now(UTC) - timedelta(days=7)
|
||||
result = _filter_checkpoints(sample_checkpoints, None, cutoff, None)
|
||||
assert len(result) == 2
|
||||
assert all("0001" in c.checkpoint_id for c in result)
|
||||
|
||||
def test_filter_after(self, sample_checkpoints) -> None:
|
||||
# After 7 days ago → the 3-day-old and 1-hour-old checkpoints
|
||||
cutoff = datetime.now(UTC) - timedelta(days=7)
|
||||
result = _filter_checkpoints(sample_checkpoints, None, None, cutoff)
|
||||
assert len(result) == 2
|
||||
paths = {c.tinker_path for c in result}
|
||||
assert "tinker://run-1/weights/0002" in paths
|
||||
assert "tinker://run-1/weights/0003" in paths
|
||||
|
||||
def test_filter_date_range(self, sample_checkpoints) -> None:
|
||||
# Between 5 and 2 days ago → only the 3-day-old checkpoint
|
||||
after_dt = datetime.now(UTC) - timedelta(days=5)
|
||||
before_dt = datetime.now(UTC) - timedelta(days=2)
|
||||
result = _filter_checkpoints(sample_checkpoints, None, before_dt, after_dt)
|
||||
assert len(result) == 1
|
||||
assert result[0].tinker_path == "tinker://run-1/weights/0002"
|
||||
|
||||
def test_filter_by_type_and_before(self, sample_checkpoints) -> None:
|
||||
cutoff = datetime.now(UTC) - timedelta(days=7)
|
||||
result = _filter_checkpoints(sample_checkpoints, "weights", cutoff, None)
|
||||
assert len(result) == 1
|
||||
assert result[0].tinker_path == "tinker://run-1/weights/0001"
|
||||
|
||||
def test_invalid_type_raises(self, sample_checkpoints) -> None:
|
||||
from tinker.cli.exceptions import TinkerCliError
|
||||
|
||||
with pytest.raises(TinkerCliError):
|
||||
_filter_checkpoints(sample_checkpoints, "invalid_type", None, None)
|
||||
|
||||
|
||||
class TestDeleteCLIValidation:
|
||||
"""Tests for CLI delete command argument validation."""
|
||||
|
||||
def _get_error_message(self, result) -> str:
|
||||
"""Get error message from either output or exception."""
|
||||
if result.output:
|
||||
return result.output
|
||||
if result.exception:
|
||||
return str(result.exception)
|
||||
return ""
|
||||
|
||||
def test_no_args_shows_error(self) -> None:
|
||||
from tinker.cli.commands.checkpoint import cli
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["delete"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
def test_paths_and_run_id_conflict(self) -> None:
|
||||
from tinker.cli.commands.checkpoint import cli
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["delete", "tinker://run-1/weights/0001", "--run-id", "run-1"])
|
||||
assert result.exit_code != 0
|
||||
assert "Cannot specify both" in self._get_error_message(result)
|
||||
|
||||
def test_type_without_run_id_error(self) -> None:
|
||||
from tinker.cli.commands.checkpoint import cli
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["delete", "tinker://run-1/weights/0001", "--type", "weights"])
|
||||
assert result.exit_code != 0
|
||||
assert "--run-id" in self._get_error_message(result)
|
||||
|
||||
def test_before_without_run_id_error(self) -> None:
|
||||
from tinker.cli.commands.checkpoint import cli
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli, ["delete", "tinker://run-1/weights/0001", "--before", "2024-01-01"]
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "--run-id" in self._get_error_message(result)
|
||||
|
||||
def test_after_without_run_id_error(self) -> None:
|
||||
from tinker.cli.commands.checkpoint import cli
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli, ["delete", "tinker://run-1/weights/0001", "--after", "2024-01-01"]
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "--run-id" in self._get_error_message(result)
|
||||
384
tests/test_subprocess_sampling_client.py
Normal file
384
tests/test_subprocess_sampling_client.py
Normal file
|
|
@ -0,0 +1,384 @@
|
|||
"""Tests for SamplingClient subprocess mode.
|
||||
|
||||
These tests use a picklable fake SamplingClient to verify that
|
||||
subprocess mode correctly routes sample() and compute_logprobs()
|
||||
through the sidecar subprocess.
|
||||
|
||||
Test organization:
|
||||
TestRPCRouting — sample/compute_logprobs delegation through sidecar
|
||||
TestErrorHandling — error propagation, sidecar death
|
||||
TestPickle — roundtrip with/without sidecar, re-enable mode
|
||||
TestConcurrency — multithreaded, async, cancelled futures, mixed ops
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future as ConcurrentFuture
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from tinker import types
|
||||
from tinker._exceptions import SidecarDiedError
|
||||
from tinker.lib.sidecar import create_sidecar_handle
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Picklable fake SamplingClient (must be module-level for pickling)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeSamplingClient:
|
||||
"""A picklable fake that simulates SamplingClient for testing.
|
||||
|
||||
This is NOT a real SamplingClient — it provides just enough interface
|
||||
to test the sidecar integration. Real SamplingClient requires an
|
||||
InternalClientHolder and API connection.
|
||||
"""
|
||||
|
||||
def __init__(self, delay: float = 0.0, fail: bool = False, subprocess_sampling: bool = False):
|
||||
self._delay = delay
|
||||
self._fail = fail
|
||||
self._sampling_client_sidecar_handle = None # set by create_sidecar_handle() in tests
|
||||
if subprocess_sampling:
|
||||
from tinker.lib.sidecar import _inside_sidecar
|
||||
|
||||
if not _inside_sidecar:
|
||||
self._sampling_client_sidecar_handle = create_sidecar_handle(self)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
prompt: types.ModelInput,
|
||||
num_samples: int,
|
||||
sampling_params: types.SamplingParams,
|
||||
include_prompt_logprobs: bool = False,
|
||||
topk_prompt_logprobs: int = 0,
|
||||
) -> Any:
|
||||
# Delegate through sidecar if enabled (mirrors real SamplingClient behavior)
|
||||
if self._sampling_client_sidecar_handle is not None:
|
||||
from tinker.lib.public_interfaces.sampling_client import _SampleRPC
|
||||
|
||||
return self._sampling_client_sidecar_handle.submit_rpc(
|
||||
_SampleRPC(
|
||||
prompt,
|
||||
num_samples,
|
||||
sampling_params,
|
||||
include_prompt_logprobs,
|
||||
topk_prompt_logprobs,
|
||||
)
|
||||
)
|
||||
|
||||
f: ConcurrentFuture[types.SampleResponse] = ConcurrentFuture()
|
||||
if self._fail:
|
||||
f.set_exception(RuntimeError("Simulated sample failure"))
|
||||
elif self._delay > 0:
|
||||
|
||||
def _delayed():
|
||||
time.sleep(self._delay)
|
||||
f.set_result(_make_sample_response())
|
||||
|
||||
threading.Thread(target=_delayed, daemon=True).start()
|
||||
else:
|
||||
f.set_result(_make_sample_response())
|
||||
return f
|
||||
|
||||
def compute_logprobs(self, prompt: types.ModelInput) -> Any:
|
||||
# Delegate through sidecar if enabled (mirrors real SamplingClient behavior)
|
||||
if self._sampling_client_sidecar_handle is not None:
|
||||
from tinker.lib.public_interfaces.sampling_client import _ComputeLogprobsRPC
|
||||
|
||||
return self._sampling_client_sidecar_handle.submit_rpc(_ComputeLogprobsRPC(prompt))
|
||||
|
||||
f: ConcurrentFuture[list[float | None]] = ConcurrentFuture()
|
||||
if self._fail:
|
||||
f.set_exception(RuntimeError("Simulated logprobs failure"))
|
||||
else:
|
||||
f.set_result([0.1, 0.2, None])
|
||||
return f
|
||||
|
||||
def __reduce__(self) -> tuple[type, tuple[float, bool, bool]]:
|
||||
return (
|
||||
_FakeSamplingClient,
|
||||
(self._delay, self._fail, self._sampling_client_sidecar_handle is not None),
|
||||
)
|
||||
|
||||
|
||||
def _make_sample_response() -> types.SampleResponse:
|
||||
return types.SampleResponse(
|
||||
sequences=[
|
||||
types.SampledSequence(
|
||||
stop_reason="length",
|
||||
tokens=[1, 2, 3],
|
||||
logprobs=[0.1, 0.2, 0.3],
|
||||
)
|
||||
],
|
||||
type="sample",
|
||||
)
|
||||
|
||||
|
||||
def _create_proxy(delay: float = 0.0, fail: bool = False) -> _FakeSamplingClient:
|
||||
"""Create a fake client with sidecar handle for testing."""
|
||||
client = _FakeSamplingClient(delay=delay, fail=fail)
|
||||
client._sampling_client_sidecar_handle = create_sidecar_handle(client)
|
||||
return client
|
||||
|
||||
|
||||
_PROMPT = types.ModelInput.from_ints([1, 2, 3])
|
||||
_PARAMS = types.SamplingParams(max_tokens=10)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestRPCRouting:
|
||||
"""Verify sample() and compute_logprobs() are routed through the sidecar."""
|
||||
|
||||
def test_sample(self):
|
||||
"""sample() → subprocess → SampleResponse."""
|
||||
proxy = _create_proxy()
|
||||
result = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
|
||||
assert isinstance(result, types.SampleResponse)
|
||||
assert result.sequences[0].tokens == [1, 2, 3]
|
||||
|
||||
def test_constructor_enables_subprocess_mode(self):
|
||||
"""subprocess_sampling=True in __init__ creates the sidecar handle."""
|
||||
client = _FakeSamplingClient(subprocess_sampling=True)
|
||||
assert client._sampling_client_sidecar_handle is not None
|
||||
result = client.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
|
||||
assert isinstance(result, types.SampleResponse)
|
||||
|
||||
def test_compute_logprobs(self):
|
||||
"""compute_logprobs() → subprocess → list of logprobs."""
|
||||
proxy = _create_proxy()
|
||||
result = proxy.compute_logprobs(_PROMPT).result(timeout=10)
|
||||
assert result == [0.1, 0.2, None]
|
||||
|
||||
def test_mixed_sample_and_logprobs(self):
|
||||
"""Interleaved sample() and compute_logprobs() all resolve correctly."""
|
||||
proxy = _create_proxy(delay=0.01)
|
||||
|
||||
futures_sample = [proxy.sample(_PROMPT, 1, _PARAMS) for _ in range(10)]
|
||||
futures_logprobs = [proxy.compute_logprobs(_PROMPT) for _ in range(10)]
|
||||
|
||||
for f in futures_sample:
|
||||
result = f.result(timeout=10)
|
||||
assert isinstance(result, types.SampleResponse)
|
||||
assert result.sequences[0].tokens == [1, 2, 3]
|
||||
|
||||
for f in futures_logprobs:
|
||||
assert f.result(timeout=10) == [0.1, 0.2, None]
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Error propagation from subprocess to caller."""
|
||||
|
||||
def test_sample_error(self):
|
||||
"""Exceptions from sample() in the subprocess are propagated."""
|
||||
proxy = _create_proxy(fail=True)
|
||||
with pytest.raises(RuntimeError, match="Simulated sample failure"):
|
||||
proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
|
||||
|
||||
def test_compute_logprobs_error(self):
|
||||
"""Exceptions from compute_logprobs() in the subprocess are propagated."""
|
||||
proxy = _create_proxy(fail=True)
|
||||
with pytest.raises(RuntimeError, match="Simulated logprobs failure"):
|
||||
proxy.compute_logprobs(_PROMPT).result(timeout=10)
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
|
||||
def test_sidecar_death_fails_pending_futures(self):
|
||||
"""When the subprocess is killed, pending futures get SidecarDiedError."""
|
||||
proxy = _create_proxy(delay=0.5)
|
||||
future = proxy.sample(_PROMPT, 1, _PARAMS)
|
||||
|
||||
# Kill the underlying subprocess directly
|
||||
sidecar = proxy._sampling_client_sidecar_handle._sidecar
|
||||
assert sidecar._process is not None
|
||||
sidecar._process.kill()
|
||||
sidecar._process.join(timeout=5)
|
||||
|
||||
with pytest.raises(SidecarDiedError):
|
||||
future.result(timeout=5)
|
||||
|
||||
|
||||
class TestPickle:
|
||||
"""Pickle roundtrip preserves subprocess mode correctly."""
|
||||
|
||||
def test_roundtrip_preserves_subprocess_mode(self):
|
||||
"""Pickling a sidecar-enabled client re-enables subprocess mode on unpickle."""
|
||||
proxy = _create_proxy()
|
||||
assert proxy._sampling_client_sidecar_handle is not None
|
||||
|
||||
restored = pickle.loads(pickle.dumps(proxy))
|
||||
assert restored._sampling_client_sidecar_handle is not None
|
||||
|
||||
result = restored.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
|
||||
assert isinstance(result, types.SampleResponse)
|
||||
|
||||
def test_roundtrip_without_sidecar(self):
|
||||
"""Pickling a client without subprocess mode keeps it disabled."""
|
||||
client = _FakeSamplingClient()
|
||||
assert client._sampling_client_sidecar_handle is None
|
||||
restored = pickle.loads(pickle.dumps(client))
|
||||
assert restored._sampling_client_sidecar_handle is None
|
||||
|
||||
def test_re_enable_subprocess_mode(self):
|
||||
"""Replacing the sidecar handle works cleanly."""
|
||||
client = _FakeSamplingClient()
|
||||
client._sampling_client_sidecar_handle = create_sidecar_handle(client)
|
||||
|
||||
# First handle works
|
||||
assert isinstance(
|
||||
client.sample(_PROMPT, 1, _PARAMS).result(timeout=10), types.SampleResponse
|
||||
)
|
||||
|
||||
# Replace with a new handle (old one is GC'd and unregistered)
|
||||
client._sampling_client_sidecar_handle = create_sidecar_handle(client)
|
||||
|
||||
# New handle also works
|
||||
assert isinstance(
|
||||
client.sample(_PROMPT, 1, _PARAMS).result(timeout=10), types.SampleResponse
|
||||
)
|
||||
|
||||
|
||||
class TestConcurrency:
|
||||
"""Thread safety and concurrent operations through the sidecar."""
|
||||
|
||||
def test_multithreaded_samples(self):
|
||||
"""sample() from 20 threads all resolve correctly."""
|
||||
proxy = _create_proxy(delay=0.01)
|
||||
results: list[types.SampleResponse | None] = [None] * 20
|
||||
errors: list[Exception] = []
|
||||
|
||||
def _worker(idx: int) -> None:
|
||||
try:
|
||||
results[idx] = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=30)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=_worker, args=(i,)) for i in range(20)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=30)
|
||||
|
||||
assert not errors, f"Threads raised: {errors}"
|
||||
for r in results:
|
||||
assert isinstance(r, types.SampleResponse)
|
||||
assert r.sequences[0].tokens == [1, 2, 3]
|
||||
|
||||
def test_multithreaded_mixed_operations(self):
|
||||
"""sample() and compute_logprobs() from different threads simultaneously."""
|
||||
proxy = _create_proxy(delay=0.01)
|
||||
errors: list[Exception] = []
|
||||
|
||||
def _sample_worker() -> None:
|
||||
try:
|
||||
for _ in range(10):
|
||||
r = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
|
||||
assert isinstance(r, types.SampleResponse)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
def _logprobs_worker() -> None:
|
||||
try:
|
||||
for _ in range(10):
|
||||
r = proxy.compute_logprobs(_PROMPT).result(timeout=10)
|
||||
assert r == [0.1, 0.2, None]
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=_sample_worker) for _ in range(3)]
|
||||
threads += [threading.Thread(target=_logprobs_worker) for _ in range(3)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=30)
|
||||
|
||||
assert not errors, f"Errors: {errors}"
|
||||
|
||||
def test_async_concurrent_samples(self):
|
||||
"""Multiple async sample calls via asyncio.gather all resolve."""
|
||||
proxy = _create_proxy(delay=0.01)
|
||||
|
||||
async def _run() -> list[types.SampleResponse]:
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
|
||||
coros = [
|
||||
AwaitableConcurrentFuture(proxy.sample(_PROMPT, 1, _PARAMS)) for _ in range(20)
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
||||
|
||||
results = asyncio.run(_run())
|
||||
assert len(results) == 20
|
||||
for r in results:
|
||||
assert isinstance(r, types.SampleResponse)
|
||||
|
||||
def test_cancelled_future_does_not_crash_collector(self):
|
||||
"""Cancelling a future doesn't kill the collector thread."""
|
||||
proxy = _create_proxy(delay=0.5)
|
||||
|
||||
future1 = proxy.sample(_PROMPT, 1, _PARAMS)
|
||||
future1.cancel()
|
||||
|
||||
result = proxy.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
|
||||
assert isinstance(result, types.SampleResponse)
|
||||
|
||||
def test_multiple_clients_share_sidecar(self):
|
||||
"""Two independent clients sharing the sidecar singleton work concurrently."""
|
||||
proxy1 = _create_proxy(delay=0.01)
|
||||
proxy2 = _create_proxy(delay=0.01)
|
||||
errors: list[Exception] = []
|
||||
|
||||
def _worker1() -> None:
|
||||
try:
|
||||
for _ in range(10):
|
||||
r = proxy1.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
|
||||
assert isinstance(r, types.SampleResponse)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
def _worker2() -> None:
|
||||
try:
|
||||
for _ in range(10):
|
||||
r = proxy2.compute_logprobs(_PROMPT).result(timeout=10)
|
||||
assert r == [0.1, 0.2, None]
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
t1 = threading.Thread(target=_worker1)
|
||||
t2 = threading.Thread(target=_worker2)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join(timeout=30)
|
||||
t2.join(timeout=30)
|
||||
|
||||
assert not errors, f"Errors: {errors}"
|
||||
|
||||
def test_pickle_roundtrip_then_concurrent_use(self):
    """Pickle a client, restore it, then use from multiple threads."""
    original = _create_proxy(delay=0.01)
    restored = pickle.loads(pickle.dumps(original))
    # The restored client must have reattached to the sidecar singleton.
    assert restored._sampling_client_sidecar_handle is not None

    errors: list[Exception] = []

    def _hammer() -> None:
        # Each thread issues ten sequential requests against the restored client.
        try:
            for _ in range(10):
                response = restored.sample(_PROMPT, 1, _PARAMS).result(timeout=10)
                assert isinstance(response, types.SampleResponse)
        except Exception as exc:
            errors.append(exc)

    workers = [threading.Thread(target=_hammer) for _ in range(5)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join(timeout=30)

    assert not errors, f"Errors: {errors}"
|
||||
234
tests/test_training_client_custom_loss.py
Normal file
234
tests/test_training_client_custom_loss.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import Future
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tinker import types
|
||||
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
|
||||
from tinker.lib.public_interfaces.training_client import TrainingClient
|
||||
|
||||
|
||||
class _DummyHolder:
|
||||
def run_coroutine_threadsafe(self, coro):
|
||||
future: Future = Future()
|
||||
future.set_result(coro)
|
||||
return future
|
||||
|
||||
def get_telemetry(self):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeTrainingClient(TrainingClient):
    """TrainingClient double that records calls and returns canned outputs."""

    def __init__(self):
        self.holder = _DummyHolder()
        # Each entry records (data, loss_fn, loss_fn_config) exactly as passed.
        self.forward_calls: list[
            tuple[list[types.Datum], types.LossFnType, dict[str, float] | None]
        ] = []
        self.backward_calls: list[
            tuple[list[types.Datum], types.LossFnType, dict[str, float] | None]
        ] = []

    @staticmethod
    def _immediate(result) -> AwaitableConcurrentFuture:
        # Wrap an already-computed result in an awaitable, done future.
        done: Future = Future()
        done.set_result(result)
        return AwaitableConcurrentFuture(done)

    async def forward_async(
        self,
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        self.forward_calls.append((data, loss_fn, loss_fn_config))
        # Canned 2x2 logprobs payload for the custom-loss forward pass.
        canned = types.ForwardBackwardOutput(
            metrics={},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[
                {
                    "logprobs": types.TensorData(
                        data=[-3.0, -2.0, -1.0, 0.0],
                        dtype="float32",
                        shape=[2, 2],
                    ),
                }
            ],
        )
        return self._immediate(canned)

    async def forward_backward_async(
        self,
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        self.backward_calls.append((data, loss_fn, loss_fn_config))
        # Backward pass returns metrics only; no per-token outputs.
        canned = types.ForwardBackwardOutput(
            metrics={"base:sum": 1.0},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[],
        )
        return self._immediate(canned)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_forward_backward_custom_supports_2d_cross_entropy_targets():
    """A 2-D target_tokens matrix flows through the custom-loss path intact."""
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={"target_tokens": [[101, 102], [201, 202]]},
    )
    assert datum.loss_fn_inputs["target_tokens"].shape == [2, 2]

    def custom_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data
        per_datum = logprobs_list[0]
        assert per_datum.shape == (2, 2)
        probs = torch.softmax(per_datum[1], dim=-1)
        target_distribution = torch.tensor([0.0, 1.0], dtype=torch.float32)
        squared_error = torch.sum((probs - target_distribution) ** 2)
        return squared_error, {"selected_prob:mean": float(probs[1].detach())}

    pending = await client.forward_backward_custom_async(
        [datum],
        custom_loss,
        loss_type_input="logprobs",
    )
    outcome = await pending.result_async()

    # Forward pass was dispatched as cross_entropy with synthesized 2-D weights.
    fwd_data, fwd_loss_fn, _ = client.forward_calls[0]
    assert fwd_loss_fn == "cross_entropy"
    assert fwd_data[0].loss_fn_inputs["weights"].shape == [2, 2]

    # Backward pass: same loss fn; targets and weights keep the 2-D shape.
    bwd_data, bwd_loss_fn, _ = client.backward_calls[0]
    assert bwd_loss_fn == "cross_entropy"
    assert bwd_data[0].loss_fn_inputs["target_tokens"].shape == [2, 2]
    assert bwd_data[0].loss_fn_inputs["weights"].shape == [2, 2]
    # The caller's datum must not be mutated in place.
    assert "weights" not in datum.loss_fn_inputs
    assert outcome.metrics["selected_prob:mean"] > 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_forward_backward_custom_preserves_1d_cross_entropy_targets():
    """1-D target_tokens keep their shape through the custom-loss path.

    Overrides forward_async on the fake so the logprobs come back as a flat
    [2] tensor, then checks that the synthesized weights are 1-D, match
    [0.0, 1.0], and that the custom metric is propagated.
    """
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={"target_tokens": [101, 102]},
    )

    async def forward_async_1d(
        data: list[types.Datum],
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None = None,
    ):
        # Record the call like the fake's own method, but answer with 1-D logprobs.
        client.forward_calls.append((data, loss_fn, loss_fn_config))
        result = types.ForwardBackwardOutput(
            metrics={},
            loss_fn_output_type="target_token_logprobs",
            loss_fn_outputs=[
                {
                    "logprobs": types.TensorData(
                        data=[-3.0, -1.0],
                        dtype="float32",
                        shape=[2],
                    ),
                }
            ],
        )
        future: Future = Future()
        future.set_result(result)
        return AwaitableConcurrentFuture(future)

    # Direct assignment instead of setattr() with a constant name (ruff B010).
    client.forward_async = forward_async_1d  # type: ignore[method-assign]

    def custom_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        del data
        logprobs = logprobs_list[0]
        assert logprobs.shape == (2,)
        loss = -logprobs[-1]
        return loss, {"selected_logprob:last": float(logprobs[-1].detach())}

    result_future = await client.forward_backward_custom_async(
        [datum],
        custom_loss,
        loss_type_input="logprobs",
    )
    result = await result_future.result_async()

    # Forward call was rewritten to cross_entropy with 1-D weights.
    assert client.forward_calls[0][1] == "cross_entropy"
    forward_datum = client.forward_calls[0][0][0]
    assert forward_datum.loss_fn_inputs["weights"].shape == [2]

    # Backward call preserves the 1-D targets and carries the weights.
    assert client.backward_calls[0][1] == "cross_entropy"
    backward_datum = client.backward_calls[0][0][0]
    assert backward_datum.loss_fn_inputs["target_tokens"].shape == [2]
    assert backward_datum.loss_fn_inputs["weights"].shape == [2]
    torch.testing.assert_close(
        torch.tensor(backward_datum.loss_fn_inputs["weights"].data).reshape(
            backward_datum.loss_fn_inputs["weights"].shape
        ),
        torch.tensor([0.0, 1.0], dtype=torch.float32),
    )
    assert result.metrics["selected_logprob:last"] < 0.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_forward_backward_custom_rejects_unsupported_loss_type_input():
    """An unsupported loss_type_input value is rejected with ValueError."""
    client = _FakeTrainingClient()
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={"target_tokens": [101, 102]},
    )

    def noop_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        # The loss body is irrelevant; validation should reject the call first.
        del data, logprobs_list
        return torch.tensor(0.0, requires_grad=True), {}

    with pytest.raises(ValueError, match="Unsupported loss_type_input"):
        await client.forward_backward_custom_async(
            [datum],
            noop_loss,
            loss_type_input="logits",  # type: ignore[arg-type]
        )
|
||||
|
||||
|
||||
def test_datum_rejects_ragged_nested_target_tokens():
    """Nested token lists of unequal length fail at Datum construction."""
    ragged_inputs = {"target_tokens": [[101, 102], [201]]}
    with pytest.raises(ValueError, match="ragged nested lists are not supported"):
        types.Datum(
            model_input=types.ModelInput.from_ints([1, 2]),
            loss_fn_inputs=ragged_inputs,
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_forward_backward_custom_rejects_unexpected_loss_fn_input_keys():
    """A loss_fn_inputs key beyond the supported set raises ValueError."""
    client = _FakeTrainingClient()
    # "advantages" is not among the accepted custom-loss input keys.
    datum = types.Datum(
        model_input=types.ModelInput.from_ints([1, 2]),
        loss_fn_inputs={
            "target_tokens": [101, 102],
            "advantages": [1.0, 1.0],
        },
    )

    def noop_loss(
        data: list[types.Datum], logprobs_list: list[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, float]]:
        # The loss body is irrelevant; key validation should reject the call.
        del data, logprobs_list
        return torch.tensor(0.0, requires_grad=True), {}

    with pytest.raises(ValueError, match="only supports loss_fn_inputs keys"):
        await client.forward_backward_custom_async(
            [datum],
            noop_loss,
            loss_type_input="logprobs",
        )
|
||||
Loading…
Add table
Add a link
Reference in a new issue