Sync contents

This commit is contained in:
Daniel Xu 2025-11-25 06:15:14 +00:00
parent 937c36e9b1
commit 9bf0df6796
9 changed files with 5 additions and 179 deletions

View file

@ -1,5 +1,3 @@
# `tinker.lib.public_interfaces.training_client`
TrainingClient for Tinker API.
## `TrainingClient` Objects
@ -34,7 +32,6 @@ sampling_client = training_client.save_weights_and_get_sampling_client("my-model
#### `forward`
```python
@capture_exceptions(fatal=True)
def forward(
data: List[types.Datum],
loss_fn: types.LossFnType,
@ -78,7 +75,6 @@ Async version of forward.
#### `forward_backward`
```python
@capture_exceptions(fatal=True)
def forward_backward(
data: List[types.Datum],
loss_fn: types.LossFnType,
@ -130,8 +126,6 @@ Async version of forward_backward.
#### `forward_backward_custom`
```python
@sync_only
@capture_exceptions(fatal=True)
def forward_backward_custom(
data: List[types.Datum],
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
@ -166,7 +160,6 @@ print(f"Metrics: {result.metrics}")
#### `forward_backward_custom_async`
```python
@capture_exceptions(fatal=True)
async def forward_backward_custom_async(
data: List[types.Datum],
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
@ -177,7 +170,6 @@ Async version of forward_backward_custom.
#### `optim_step`
```python
@capture_exceptions(fatal=True)
def optim_step(
adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]
```
@ -220,7 +212,6 @@ Async version of optim_step.
#### `save_state`
```python
@capture_exceptions(fatal=True)
def save_state(name: str) -> APIFuture[types.SaveWeightsResponse]
```
@ -251,7 +242,6 @@ Async version of save_state.
#### `load_state`
```python
@capture_exceptions(fatal=True)
def load_state(path: str) -> APIFuture[types.LoadWeightsResponse]
```
@ -282,7 +272,6 @@ Async version of load_state.
#### `load_state_with_optimizer`
```python
@capture_exceptions(fatal=True)
def load_state_with_optimizer(
path: str) -> APIFuture[types.LoadWeightsResponse]
```
@ -317,7 +306,6 @@ Async version of load_state_with_optimizer.
#### `save_weights_for_sampler`
```python
@capture_exceptions(fatal=True)
def save_weights_for_sampler(
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
```
@ -355,8 +343,6 @@ Async version of save_weights_for_sampler.
#### `get_info`
```python
@sync_only
@capture_exceptions(fatal=True)
def get_info() -> types.GetInfoResponse
```
@ -376,7 +362,6 @@ print(f"LoRA rank: {info.model_data.lora_rank}")
#### `get_info_async`
```python
@capture_exceptions(fatal=True)
async def get_info_async() -> types.GetInfoResponse
```
@ -385,8 +370,6 @@ Async version of get_info.
#### `get_tokenizer`
```python
@cache
@capture_exceptions(fatal=True)
def get_tokenizer() -> PreTrainedTokenizer
```
@ -405,7 +388,6 @@ text = tokenizer.decode(tokens)
#### `create_sampling_client`
```python
@capture_exceptions(fatal=True)
def create_sampling_client(
model_path: str,
retry_config: RetryConfig | None = None) -> SamplingClient
@ -431,7 +413,6 @@ sampling_client = training_client.create_sampling_client(
#### `create_sampling_client_async`
```python
@capture_exceptions(fatal=True)
async def create_sampling_client_async(
model_path: str,
retry_config: RetryConfig | None = None) -> SamplingClient
@ -442,7 +423,6 @@ Async version of create_sampling_client.
#### `save_weights_and_get_sampling_client`
```python
@capture_exceptions(fatal=True)
def save_weights_and_get_sampling_client(
name: str | None = None,
retry_config: RetryConfig | None = None) -> SamplingClient
@ -471,7 +451,6 @@ result = sampling_client.sample(prompt, 1, params).result()
#### `save_weights_and_get_sampling_client_async`
```python
@capture_exceptions(fatal=True)
async def save_weights_and_get_sampling_client_async(
name: str | None = None,
retry_config: RetryConfig | None = None) -> SamplingClient