mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-24 17:04:56 +00:00
Sync contents
This commit is contained in:
parent
ad03d44978
commit
ca40e08bb4
12 changed files with 358 additions and 27 deletions
|
|
@ -110,3 +110,14 @@ async def compute_logprobs_async(
|
|||
```
|
||||
|
||||
Async version of compute_logprobs.
|
||||
|
||||
#### `get_tokenizer`
|
||||
|
||||
```python
|
||||
def get_tokenizer() -> PreTrainedTokenizer
|
||||
```
|
||||
|
||||
Get the tokenizer for the current model.
|
||||
|
||||
Returns:
|
||||
- `PreTrainedTokenizer` compatible with the model
|
||||
|
|
|
|||
|
|
@ -176,6 +176,11 @@ def optim_step(
|
|||
|
||||
Update model parameters using the Adam optimizer.
|
||||
|
||||
The Adam optimizer used by tinker is identical
|
||||
to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
|
||||
|
||||
|
||||
Args:
|
||||
- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
|
||||
|
||||
|
|
@ -212,13 +217,17 @@ Async version of optim_step.
|
|||
#### `save_state`
|
||||
|
||||
```python
|
||||
def save_state(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
def save_state(
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
Save model weights to persistent storage.
|
||||
|
||||
Args:
|
||||
- `name`: Name for the saved checkpoint
|
||||
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the save response with checkpoint path
|
||||
|
|
@ -234,7 +243,10 @@ print(f"Saved to: {result.path}")
|
|||
#### `save_state_async`
|
||||
|
||||
```python
|
||||
async def save_state_async(name: str) -> APIFuture[types.SaveWeightsResponse]
|
||||
async def save_state_async(
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsResponse]
|
||||
```
|
||||
|
||||
Async version of save_state.
|
||||
|
|
@ -310,13 +322,16 @@ Async version of load_state_with_optimizer.
|
|||
|
||||
```python
|
||||
def save_weights_for_sampler(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
||||
Save model weights for use with a SamplingClient.
|
||||
|
||||
Args:
|
||||
- `name`: Name for the saved sampler weights
|
||||
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
||||
|
||||
Returns:
|
||||
- `APIFuture` containing the save response with sampler path
|
||||
|
|
@ -338,7 +353,9 @@ sampling_client = service_client.create_sampling_client(
|
|||
|
||||
```python
|
||||
async def save_weights_for_sampler_async(
|
||||
name: str) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
name: str,
|
||||
ttl_seconds: int | None = None
|
||||
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
||||
```
|
||||
|
||||
Async version of save_weights_for_sampler.
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ Weight decay for the optimizer. Uses decoupled weight decay.
|
|||
|
||||
#### `grad_clip_norm`
|
||||
|
||||
Gradient clip norm for the optimizer. 0.0 means no clipping.
|
||||
Maximum global gradient norm. If the global gradient norm exceeds this value, gradients are scaled down so that the norm equals it. 0.0 means no clipping.
|
||||
|
||||
## `SupportedModel` Objects
|
||||
|
||||
|
|
@ -159,6 +159,10 @@ The size of the checkpoint in bytes
|
|||
|
||||
Whether the checkpoint is publicly accessible
|
||||
|
||||
#### `expires_at`
|
||||
|
||||
When this checkpoint expires (None = never expires)
|
||||
|
||||
## `ParsedCheckpointTinkerPath` Objects
|
||||
|
||||
```python
|
||||
|
|
@ -725,6 +729,10 @@ class SaveWeightsRequest(StrictBase)
|
|||
|
||||
A file/directory name for the weights
|
||||
|
||||
#### `ttl_seconds`
|
||||
|
||||
TTL in seconds for this checkpoint (None = never expires)
|
||||
|
||||
## `LoraConfig` Objects
|
||||
|
||||
```python
|
||||
|
|
@ -834,6 +842,10 @@ class SaveWeightsForSamplerRequest(StrictBase)
|
|||
|
||||
A file/directory name for the weights
|
||||
|
||||
#### `ttl_seconds`
|
||||
|
||||
TTL in seconds for this checkpoint (None = never expires)
|
||||
|
||||
## `SamplingParams` Objects
|
||||
|
||||
```python
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue