mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-05-01 17:45:18 +00:00
Sync contents
This commit is contained in:
parent
35a77e79fe
commit
f3c0b1f179
17 changed files with 2069 additions and 632 deletions
|
|
@ -121,6 +121,38 @@ class InternalServerError(APIStatusError)
|
|||
|
||||
HTTP 500+: An error occurred on the server.
|
||||
|
||||
## `SidecarError` Objects
|
||||
|
||||
```python
|
||||
class SidecarError(TinkerError)
|
||||
```
|
||||
|
||||
Base exception for subprocess sidecar errors.
|
||||
|
||||
## `SidecarStartupError` Objects
|
||||
|
||||
```python
|
||||
class SidecarStartupError(SidecarError)
|
||||
```
|
||||
|
||||
Raised when the sidecar subprocess fails to start or times out.
|
||||
|
||||
## `SidecarDiedError` Objects
|
||||
|
||||
```python
|
||||
class SidecarDiedError(SidecarError)
|
||||
```
|
||||
|
||||
Raised when the sidecar subprocess exits unexpectedly while requests are pending.
|
||||
|
||||
## `SidecarIPCError` Objects
|
||||
|
||||
```python
|
||||
class SidecarIPCError(SidecarError)
|
||||
```
|
||||
|
||||
Raised when communication with the sidecar subprocess fails.
|
||||
|
||||
## `RequestFailedError` Objects
|
||||
|
||||
```python
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ Key methods:
|
|||
- get_checkpoint_archive_url() - get signed URL to download checkpoint archive
|
||||
- publish_checkpoint_from_tinker_path() - publish a checkpoint to make it public
|
||||
- unpublish_checkpoint_from_tinker_path() - unpublish a checkpoint to make it private
|
||||
- set_checkpoint_ttl_from_tinker_path() - set or remove TTL on a checkpoint
|
||||
|
||||
Args:
|
||||
- `holder`: Internal client managing HTTP connections and async operations
|
||||
|
|
@ -39,7 +40,9 @@ for checkpoint in checkpoints.checkpoints:
|
|||
|
||||
```python
|
||||
def get_training_run(
|
||||
training_run_id: types.ModelID) -> ConcurrentFuture[types.TrainingRun]
|
||||
training_run_id: types.ModelID,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.TrainingRun]
|
||||
```
|
||||
|
||||
Get training run info.
|
||||
|
|
@ -61,7 +64,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}
|
|||
|
||||
```python
|
||||
async def get_training_run_async(
|
||||
training_run_id: types.ModelID) -> types.TrainingRun
|
||||
training_run_id: types.ModelID,
|
||||
access_scope: Literal["owned",
|
||||
"accessible"] = "owned") -> types.TrainingRun
|
||||
```
|
||||
|
||||
Async version of get_training_run.
|
||||
|
|
@ -70,7 +75,9 @@ Async version of get_training_run.
|
|||
|
||||
```python
|
||||
def get_training_run_by_tinker_path(
|
||||
tinker_path: str) -> ConcurrentFuture[types.TrainingRun]
|
||||
tinker_path: str,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.TrainingRun]
|
||||
```
|
||||
|
||||
Get training run info.
|
||||
|
|
@ -92,7 +99,9 @@ print(f"Training Run ID: {response.training_run_id}, Base: {response.base_model}
|
|||
|
||||
```python
|
||||
async def get_training_run_by_tinker_path_async(
|
||||
tinker_path: str) -> types.TrainingRun
|
||||
tinker_path: str,
|
||||
access_scope: Literal["owned",
|
||||
"accessible"] = "owned") -> types.TrainingRun
|
||||
```
|
||||
|
||||
Async version of get_training_run_by_tinker_path.
|
||||
|
|
@ -123,8 +132,10 @@ print(f"Base Model: {response.base_model}, LoRA Rank: {response.lora_rank}")
|
|||
|
||||
```python
|
||||
def list_training_runs(
|
||||
limit: int = 20,
|
||||
offset: int = 0) -> ConcurrentFuture[types.TrainingRunsResponse]
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.TrainingRunsResponse]
|
||||
```
|
||||
|
||||
List training runs with pagination support.
|
||||
|
|
@ -149,9 +160,11 @@ next_page = rest_client.list_training_runs(limit=50, offset=50)
|
|||
#### `list_training_runs_async`
|
||||
|
||||
```python
|
||||
async def list_training_runs_async(limit: int = 20,
|
||||
offset: int = 0
|
||||
) -> types.TrainingRunsResponse
|
||||
async def list_training_runs_async(
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> types.TrainingRunsResponse
|
||||
```
|
||||
|
||||
Async version of list_training_runs.
|
||||
|
|
@ -367,6 +380,46 @@ async def unpublish_checkpoint_from_tinker_path_async(
|
|||
|
||||
Async version of unpublish_checkpoint_from_tinker_path.
|
||||
|
||||
#### `set_checkpoint_ttl_from_tinker_path`
|
||||
|
||||
```python
|
||||
def set_checkpoint_ttl_from_tinker_path(
|
||||
tinker_path: str, ttl_seconds: int | None) -> ConcurrentFuture[None]
|
||||
```
|
||||
|
||||
Set or remove the TTL on a checkpoint referenced by a tinker path.
|
||||
|
||||
If ttl_seconds is provided, the checkpoint will expire after that many seconds from now.
|
||||
If ttl_seconds is None, any existing expiration will be removed.
|
||||
|
||||
Args:
|
||||
- `tinker_path`: The tinker path to the checkpoint (e.g., "tinker://run-id/weights/0001")
|
||||
- `ttl_seconds`: Number of seconds until expiration, or None to remove TTL
|
||||
|
||||
Returns:
|
||||
- A `Future` that completes when the TTL is set
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if checkpoint identifier is invalid or ttl_seconds <= 0
|
||||
HTTPException: 404 if checkpoint not found or user doesn't own the training run
|
||||
HTTPException: 500 if there's an error setting the TTL
|
||||
|
||||
Example:
|
||||
```python
|
||||
future = rest_client.set_checkpoint_ttl_from_tinker_path("tinker://run-id/weights/0001", 86400)
|
||||
future.result() # Wait for completion
|
||||
print("Checkpoint TTL set successfully")
|
||||
```
|
||||
|
||||
#### `set_checkpoint_ttl_from_tinker_path_async`
|
||||
|
||||
```python
|
||||
async def set_checkpoint_ttl_from_tinker_path_async(
|
||||
tinker_path: str, ttl_seconds: int | None) -> None
|
||||
```
|
||||
|
||||
Async version of set_checkpoint_ttl_from_tinker_path.
|
||||
|
||||
#### `list_user_checkpoints`
|
||||
|
||||
```python
|
||||
|
|
@ -414,7 +467,10 @@ Async version of list_user_checkpoints.
|
|||
#### `get_session`
|
||||
|
||||
```python
|
||||
def get_session(session_id: str) -> ConcurrentFuture[types.GetSessionResponse]
|
||||
def get_session(
|
||||
session_id: str,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.GetSessionResponse]
|
||||
```
|
||||
|
||||
Get session information including all training runs and samplers.
|
||||
|
|
@ -436,7 +492,10 @@ print(f"Samplers: {len(response.sampler_ids)}")
|
|||
#### `get_session_async`
|
||||
|
||||
```python
|
||||
async def get_session_async(session_id: str) -> types.GetSessionResponse
|
||||
async def get_session_async(
|
||||
session_id: str,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> types.GetSessionResponse
|
||||
```
|
||||
|
||||
Async version of get_session.
|
||||
|
|
@ -445,8 +504,10 @@ Async version of get_session.
|
|||
|
||||
```python
|
||||
def list_sessions(
|
||||
limit: int = 20,
|
||||
offset: int = 0) -> ConcurrentFuture[types.ListSessionsResponse]
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> ConcurrentFuture[types.ListSessionsResponse]
|
||||
```
|
||||
|
||||
List sessions with pagination support.
|
||||
|
|
@ -470,8 +531,11 @@ next_page = rest_client.list_sessions(limit=50, offset=50)
|
|||
#### `list_sessions_async`
|
||||
|
||||
```python
|
||||
async def list_sessions_async(limit: int = 20,
|
||||
offset: int = 0) -> types.ListSessionsResponse
|
||||
async def list_sessions_async(
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
access_scope: Literal["owned", "accessible"] = "owned"
|
||||
) -> types.ListSessionsResponse
|
||||
```
|
||||
|
||||
Async version of list_sessions.
|
||||
|
|
|
|||
|
|
@ -30,6 +30,20 @@ future = sampling_client.sample(prompt=prompt, sampling_params=params, num_sampl
|
|||
result = future.result()
|
||||
```
|
||||
|
||||
Multi-processing support:
|
||||
This class is picklable, so it can be passed to a separate process/worker to sample. It is also
|
||||
safe to pass the same instance of SamplingClient to multiple processes/workers.
|
||||
|
||||
If you are using Tinker SDK with more than one process you should always create SamplingClient from
|
||||
the main process and then pass it to the other processes/workers.
|
||||
ServiceClient and TrainingClient should always be managed from the main process.
|
||||
|
||||
Subprocess isolation:
|
||||
Set ``TINKER_SUBPROCESS_SAMPLING=1`` to run sample() and compute_logprobs() in a dedicated
|
||||
subprocess, preventing GIL contention from CPU-heavy user code (grading, environment
|
||||
interactions) from stalling networking IO and heartbeats. This is transparent — the same
|
||||
API works with or without it.
|
||||
|
||||
#### `sample`
|
||||
|
||||
```python
|
||||
|
|
@ -121,3 +135,33 @@ Get the tokenizer for the current model.
|
|||
|
||||
Returns:
|
||||
- `PreTrainedTokenizer` compatible with the model
|
||||
|
||||
#### `get_base_model`
|
||||
|
||||
```python
|
||||
def get_base_model() -> str
|
||||
```
|
||||
|
||||
Get the base model name for the current sampling session.
|
||||
|
||||
#### `get_base_model_async`
|
||||
|
||||
```python
|
||||
async def get_base_model_async() -> str
|
||||
```
|
||||
|
||||
Async version of get_base_model.
|
||||
|
||||
#### `__reduce__`
|
||||
|
||||
```python
|
||||
def __reduce__() -> tuple[Any, tuple[_SamplingClientPickleState]]
|
||||
```
|
||||
|
||||
Enable pickling of SamplingClient for subprocess use.
|
||||
|
||||
Serializes into a ``_SamplingClientPickleState`` dataclass. The
|
||||
``_sampling_client_sidecar_handle`` handle is deliberately omitted — only a
|
||||
bool flag is stored. The unpickled copy creates its own handle via
|
||||
the per-process sidecar singleton. Do not add ``__getstate__``
|
||||
without preserving this behavior.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ TrainingClient for Tinker API.
|
|||
## `TrainingClient` Objects
|
||||
|
||||
```python
|
||||
class TrainingClient(TelemetryProvider, QueueStateObserver)
|
||||
class TrainingClient(TelemetryProvider)
|
||||
```
|
||||
|
||||
Client for training ML models with forward/backward passes and optimization.
|
||||
|
|
@ -127,8 +127,11 @@ Async version of forward_backward.
|
|||
|
||||
```python
|
||||
def forward_backward_custom(
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1,
|
||||
*,
|
||||
loss_type_input: Literal["logprobs"] = "logprobs"
|
||||
) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Compute forward/backward with a custom loss function.
|
||||
|
|
@ -139,6 +142,7 @@ The custom function receives logprobs and computes loss and gradients.
|
|||
Args:
|
||||
- `data`: List of training data samples
|
||||
- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
|
||||
- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`.
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the forward/backward outputs with custom loss
|
||||
|
|
@ -161,8 +165,11 @@ print(f"Metrics: {result.metrics}")
|
|||
|
||||
```python
|
||||
async def forward_backward_custom_async(
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1,
|
||||
*,
|
||||
loss_type_input: Literal["logprobs"] = "logprobs"
|
||||
) -> APIFuture[types.ForwardBackwardOutput]
|
||||
```
|
||||
|
||||
Async version of forward_backward_custom.
|
||||
|
|
|
|||
1130
docs/api/types.md
1130
docs/api/types.md
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue