Sync contents

This commit is contained in:
Dylan Huang 2026-03-19 00:10:49 +00:00
parent 35a77e79fe
commit f3c0b1f179
17 changed files with 2069 additions and 632 deletions

View file

@ -121,6 +121,38 @@ class InternalServerError(APIStatusError)
HTTP 500+: An error occurred on the server.
## `SidecarError` Objects
```python
class SidecarError(TinkerError)
```
Base exception for subprocess sidecar errors.
## `SidecarStartupError` Objects
```python
class SidecarStartupError(SidecarError)
```
Raised when the sidecar subprocess fails to start or times out.
## `SidecarDiedError` Objects
```python
class SidecarDiedError(SidecarError)
```
Raised when the sidecar subprocess exits unexpectedly while requests are pending.
## `SidecarIPCError` Objects
```python
class SidecarIPCError(SidecarError)
```
Raised when communication with the sidecar subprocess fails.
## `RequestFailedError` Objects
```python

View file

@ -20,6 +20,7 @@ Key methods:
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private
- set_checkpoint_ttl_from_tinker_path() - set or remove TTL on a checkpoint
Args:
- `holder`: Internal client managing HTTP connections and async operations
@ -39,7 +40,9 @@ for checkpoint in checkpoints.checkpoints:
```python
def get_training_run(
training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]
training_run_id: types.ModelID,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.TrainingRun]
```
Get training run info.
@ -61,7 +64,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}
```python
async def get_training_run_async(
training_run_id: types.ModelID) -> types.TrainingRun
training_run_id: types.ModelID,
access_scope: Literal["owned",
"accessible"] = "owned") -> types.TrainingRun
```
Async version of get_training_run.
@ -70,7 +75,9 @@ Async version of get_training_run.
```python
def get_training_run_by_tinker_path(
tinker_path: str) -> ConcurrentFuture[types.TrainingRun]
tinker_path: str,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.TrainingRun]
```
Get training run info.
@ -92,7 +99,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}
```python
async def get_training_run_by_tinker_path_async(
tinker_path: str) -> types.TrainingRun
tinker_path: str,
access_scope: Literal["owned",
"accessible"] = "owned") -> types.TrainingRun
```
Async version of get_training_run_by_tinker_path.
@ -123,8 +132,10 @@ print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}")
```python
def list_training_runs(
limit: int = 20,
offset: int = 0) -> ConcurrentFuture[types.TrainingRunsResponse]
limit: int = 20,
offset: int = 0,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.TrainingRunsResponse]
```
List training runs with pagination support.
@ -149,9 +160,11 @@ next_page = rest_client.list_training_runs(limit=50, offset=50)
#### `list_training_runs_async`
```python
async def list_training_runs_async(limit: int = 20,
offset: int = 0
) -> types.TrainingRunsResponse
async def list_training_runs_async(
limit: int = 20,
offset: int = 0,
access_scope: Literal["owned", "accessible"] = "owned"
) -> types.TrainingRunsResponse
```
Async version of list_training_runs.
@ -367,6 +380,46 @@ async def unpublish_checkpoint_from_tinker_path_async(
Async version of unpublish_checkpoint_from_tinker_path.
#### `set_checkpoint_ttl_from_tinker_path`
```python
def set_checkpoint_ttl_from_tinker_path(
tinker_path: str, ttl_seconds: int | None) -> ConcurrentFuture[None]
```
Set or remove the TTL on a checkpoint referenced by a tinker path.
If ttl_seconds is provided, the checkpoint will expire after that many seconds from now.
If ttl_seconds is None, any existing expiration will be removed.
Args:
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
- `ttl_seconds`: Number of seconds until expiration, or None to remove TTL
Returns:
- A `Future` that completes when the TTL is set
Raises:
HTTPException: 400 if checkpoint identifier is invalid or ttl_seconds <= 0
HTTPException: 404 if checkpoint not found or user doesn't own the training run
HTTPException: 500 if there's an error setting the TTL
Example:
```python
future = rest_client.set_checkpoint_ttl_from_tinker_path("tinker://run-id/weights/0001", 86400)
future.result() # Wait for completion
print("Checkpoint TTL set successfully")
```
#### `set_checkpoint_ttl_from_tinker_path_async`
```python
async def set_checkpoint_ttl_from_tinker_path_async(
tinker_path: str, ttl_seconds: int | None) -> None
```
Async version of set_checkpoint_ttl_from_tinker_path.
#### `list_user_checkpoints`
```python
@ -414,7 +467,10 @@ Async version of list_user_checkpoints.
#### `get_session`
```python
def get_session(session_id: str) -> ConcurrentFuture[types.GetSessionResponse]
def get_session(
session_id: str,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.GetSessionResponse]
```
Get session information including all training runs and samplers.
@ -436,7 +492,10 @@ print(f"Samplers: {len(response.sampler_ids)}")
#### `get_session_async`
```python
async def get_session_async(session_id: str) -> types.GetSessionResponse
async def get_session_async(
session_id: str,
access_scope: Literal["owned", "accessible"] = "owned"
) -> types.GetSessionResponse
```
Async version of get_session.
@ -445,8 +504,10 @@ Async version of get_session.
```python
def list_sessions(
limit: int = 20,
offset: int = 0) -> ConcurrentFuture[types.ListSessionsResponse]
limit: int = 20,
offset: int = 0,
access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.ListSessionsResponse]
```
List sessions with pagination support.
@ -470,8 +531,11 @@ next_page = rest_client.list_sessions(limit=50, offset=50)
#### `list_sessions_async`
```python
async def list_sessions_async(limit: int = 20,
offset: int = 0) -> types.ListSessionsResponse
async def list_sessions_async(
limit: int = 20,
offset: int = 0,
access_scope: Literal["owned", "accessible"] = "owned"
) -> types.ListSessionsResponse
```
Async version of list_sessions.

View file

@ -30,6 +30,20 @@ future = sampling_client.sample(prompt=prompt, sampling_params=params, num_sampl
result = future.result()
```
Multi-processing support:
This class is picklable, so it can be passed to a separate process/worker to sample. It is also
safe to pass the same instance of SamplingClient to multiple processes/workers.
If you are using the Tinker SDK with more than one process, you should always create the SamplingClient
in the main process and then pass it to the other processes/workers.
ServiceClient and TrainingClient should always be managed from the main process.
Subprocess isolation:
Set ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated
subprocess, preventing GIL contention caused by CPU-heavy user code (grading, environment
interactions) from stalling networking IO and heartbeats. This is transparent — the same
API works with or without it.
#### `sample`
```python
@ -121,3 +135,33 @@ Get the tokenizer for the current model.
Returns:
- `PreTrainedTokenizer` compatible with the model
#### `get_base_model`
```python
def get_base_model() -> str
```
Get the base model name for the current sampling session.
#### `get_base_model_async`
```python
async def get_base_model_async() -> str
```
Async version of get_base_model.
#### `__reduce__`
```python
def __reduce__() -> tuple[Any, tuple[_SamplingClientPickleState]]
```
Enable pickling of SamplingClient for subprocess use.
Serializes into a ``_SamplingClientPickleState`` dataclass. The
``_sampling_client_sidecar_handle`` is deliberately omitted — only a
bool flag is stored. The unpickled copy creates its own handle via
the per-process sidecar singleton. Do not add ``__getstate__``
without preserving this behavior.

View file

@ -3,7 +3,7 @@ TrainingClient for Tinker API.
## `TrainingClient` Objects
```python
class TrainingClient(TelemetryProvider, QueueStateObserver)
class TrainingClient(TelemetryProvider)
```
Client for training ML models with forward/backward passes and optimization.
@ -127,8 +127,11 @@ Async version of forward_backward.
```python
def forward_backward_custom(
data: List[types.Datum],
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
data: List[types.Datum],
loss_fn: CustomLossFnV1,
*,
loss_type_input: Literal["logprobs"] = "logprobs"
) -> APIFuture[types.ForwardBackwardOutput]
```
Compute forward/backward with a custom loss function.
@ -139,6 +142,7 @@ The custom function receives logprobs and computes loss and gradients.
Args:
- `data`: List of training data samples
- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`.
Returns:
- `APIFuture` containing the forward/backward outputs with custom loss
@ -161,8 +165,11 @@ print(f"Metrics: {result.metrics}")
```python
async def forward_backward_custom_async(
data: List[types.Datum],
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
data: List[types.Datum],
loss_fn: CustomLossFnV1,
*,
loss_type_input: Literal["logprobs"] = "logprobs"
) -> APIFuture[types.ForwardBackwardOutput]
```
Async version of forward_backward_custom.

File diff suppressed because it is too large Load diff