mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-22 16:49:00 +00:00
Sync contents
This commit is contained in:
parent
ad03d44978
commit
ca40e08bb4
12 changed files with 358 additions and 27 deletions
|
|
@ -176,6 +176,11 @@ def optim_step(
|
|||
|
||||
Update model parameters using Adam optimizer.
|
||||
|
||||
The Adam optimizer used by tinker is identical
|
||||
to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
|
||||
|
||||
|
||||
Args:
|
||||
- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
|
||||
|
||||
|
|
@ -212,13 +217,17 @@ Async version of optim_step.
|
|||
#### `save_state`
|
||||
|
||||
```python
|
||||
def save_state(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
def save_state(
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
Save model weights to persistent storage.
|
||||
|
||||
Args:
|
||||
- `name`: Name for the saved checkpoint
|
||||
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the save response with checkpoint path
|
||||
|
|
@ -234,7 +243,10 @@ print(f"Saved to: {result.path}")
|
|||
#### `save_state_async`
|
||||
|
||||
```python
|
||||
async def save_state_async(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
async def save_state_async(
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
Async version of save_state.
|
||||
|
|
@ -310,13 +322,16 @@ Async version of load_state_with_optimizer.
|
|||
|
||||
```python
|
||||
def save_weights_for_sampler(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
||||
Save model weights for use with a SamplingClient.
|
||||
|
||||
Args:
|
||||
- `name`: Name for the saved sampler weights
|
||||
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the save response with sampler path
|
||||
|
|
@ -338,7 +353,9 @@ sampling_client = service_client.create_sampling_client(
|
|||
|
||||
```python
|
||||
async def save_weights_for_sampler_async(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
||||
Async version of save_weights_for_sampler.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue