Sync contents

This commit is contained in:
Dylan Huang 2026-03-19 00:10:49 +00:00
parent 35a77e79fe
commit f3c0b1f179
17 changed files with 2069 additions and 632 deletions

View file

@ -3,7 +3,7 @@ TrainingClient for Tinker API.
## `TrainingClient` Objects
```python
class TrainingClient(TelemetryProvider, QueueStateObserver)
class TrainingClient(TelemetryProvider)
```
Client for training ML models with forward/backward passes and optimization.
@ -127,8 +127,11 @@ Async version of forward_backward.
```python
def forward_backward_custom(
data: List[types.Datum],
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
data: List[types.Datum],
loss_fn: CustomLossFnV1,
*,
loss_type_input: Literal["logprobs"] = "logprobs"
) -> APIFuture[types.ForwardBackwardOutput]
```
Compute forward/backward with a custom loss function.
@ -139,6 +142,7 @@ The custom function receives logprobs and computes loss and gradients.
Args:
- `data`: List of training data samples
- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
- `loss_type_input`: Input space for `loss_fn`. Currently the only supported value is `"logprobs"`.
Returns:
- `APIFuture` containing the forward/backward outputs with custom loss
@ -161,8 +165,11 @@ print(f"Metrics: {result.metrics}")
```python
async def forward_backward_custom_async(
data: List[types.Datum],
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
data: List[types.Datum],
loss_fn: CustomLossFnV1,
*,
loss_type_input: Literal["logprobs"] = "logprobs"
) -> APIFuture[types.ForwardBackwardOutput]
```
Async version of forward_backward_custom.