mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-29 17:35:09 +00:00
Sync contents
This commit is contained in:
parent
937c36e9b1
commit
9bf0df6796
9 changed files with 5 additions and 179 deletions
|
|
@ -1,5 +1,3 @@
|
|||
# `tinker.lib.public_interfaces.training_client`
|
||||
|
||||
TrainingClient for Tinker API.
|
||||
|
||||
## `TrainingClient` Objects
|
||||
|
|
@ -34,7 +32,6 @@ sampling_client = training_client.save_weights_and_get_sampling_client("my-model
|
|||
#### `forward`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward(
|
||||
data: List[types.Datum],
|
||||
loss_fn: types.LossFnType,
|
||||
|
|
@ -78,7 +75,6 @@ Async version of forward.
|
|||
#### `forward_backward`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward_backward(
|
||||
data: List[types.Datum],
|
||||
loss_fn: types.LossFnType,
|
||||
|
|
@ -130,8 +126,6 @@ Async version of forward_backward.
|
|||
#### `forward_backward_custom`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def forward_backward_custom(
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
||||
|
|
@ -166,7 +160,6 @@ print(f"Metrics: {result.metrics}")
|
|||
#### `forward_backward_custom_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def forward_backward_custom_async(
|
||||
data: List[types.Datum],
|
||||
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
||||
|
|
@ -177,7 +170,6 @@ Async version of forward_backward_custom.
|
|||
#### `optim_step`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def optim_step(
|
||||
adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]
|
||||
```
|
||||
|
|
@ -220,7 +212,6 @@ Async version of optim_step.
|
|||
#### `save_state`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_state(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
|
|
@ -251,7 +242,6 @@ Async version of save_state.
|
|||
#### `load_state`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def load_state(path: str) -> APIFuture[types.LoadWeightsResponse]
|
||||
```
|
||||
|
||||
|
|
@ -282,7 +272,6 @@ Async version of load_state.
|
|||
#### `load_state_with_optimizer`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def load_state_with_optimizer(
|
||||
path: str) -> APIFuture[types.LoadWeightsResponse]
|
||||
```
|
||||
|
|
@ -317,7 +306,6 @@ Async version of load_state_with_optimizer.
|
|||
#### `save_weights_for_sampler`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_weights_for_sampler(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
|
@ -355,8 +343,6 @@ Async version of save_weights_for_sampler.
|
|||
#### `get_info`
|
||||
|
||||
```python
|
||||
@sync_only
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_info() -> types.GetInfoResponse
|
||||
```
|
||||
|
||||
|
|
@ -376,7 +362,6 @@ print(f"LoRA rank: {info.model_data.lora_rank}")
|
|||
#### `get_info_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def get_info_async() -> types.GetInfoResponse
|
||||
```
|
||||
|
||||
|
|
@ -385,8 +370,6 @@ Async version of get_info.
|
|||
#### `get_tokenizer`
|
||||
|
||||
```python
|
||||
@cache
|
||||
@capture_exceptions(fatal=True)
|
||||
def get_tokenizer() -> PreTrainedTokenizer
|
||||
```
|
||||
|
||||
|
|
@ -405,7 +388,6 @@ text = tokenizer.decode(tokens)
|
|||
#### `create_sampling_client`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def create_sampling_client(
|
||||
model_path: str,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
|
|
@ -431,7 +413,6 @@ sampling_client = training_client.create_sampling_client(
|
|||
#### `create_sampling_client_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def create_sampling_client_async(
|
||||
model_path: str,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
|
|
@ -442,7 +423,6 @@ Async version of create_sampling_client.
|
|||
#### `save_weights_and_get_sampling_client`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
def save_weights_and_get_sampling_client(
|
||||
name: str | None = None,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
|
|
@ -471,7 +451,6 @@ result = sampling_client.sample(prompt, 1, params).result()
|
|||
#### `save_weights_and_get_sampling_client_async`
|
||||
|
||||
```python
|
||||
@capture_exceptions(fatal=True)
|
||||
async def save_weights_and_get_sampling_client_async(
|
||||
name: str | None = None,
|
||||
retry_config: RetryConfig | None = None) -> SamplingClient
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue