mirror of
https://github.com/thinking-machines-lab/tinker.git
synced 2026-04-19 12:58:01 +00:00
479 lines
12 KiB
Markdown
479 lines
12 KiB
Markdown
TrainingClient for the Tinker API.
|
|
|
|
## `TrainingClient` Objects
|
|
|
|
```python
|
|
class TrainingClient(TelemetryProvider, QueueStateObserver)
|
|
```
|
|
|
|
Client for training ML models with forward/backward passes and optimization.
|
|
|
|
The TrainingClient corresponds to a fine-tuned model that you can train and sample from.
|
|
You typically get one by calling `service_client.create_lora_training_client()`.
|
|
Key methods:
|
|
- forward_backward() - compute gradients for training
|
|
- optim_step() - update model parameters with the Adam optimizer
|
|
- save_weights_and_get_sampling_client() - export trained model for inference
|
|
|
|
Args:
|
|
- `holder`: Internal client managing HTTP connections and async operations
|
|
- `model_id`: Unique identifier for the model to train. Required for training operations.
|
|
|
|
Example:
|
|
```python
|
|
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
|
|
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
|
|
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
|
|
fwdbwd_result = fwdbwd_future.result() # Wait for gradients
|
|
optim_result = optim_future.result() # Wait for parameter update
|
|
sampling_client = training_client.save_weights_and_get_sampling_client("my-model")
|
|
```
|
|
|
|
#### `forward`
|
|
|
|
```python
|
|
def forward(
|
|
data: List[types.Datum],
|
|
loss_fn: types.LossFnType,
|
|
loss_fn_config: Dict[str, float] | None = None
|
|
) -> APIFuture[types.ForwardBackwardOutput]
|
|
```
|
|
|
|
Compute forward pass without gradients.
|
|
|
|
Args:
|
|
- `data`: List of training data samples
|
|
- `loss_fn`: Loss function type (e.g., "cross_entropy")
|
|
- `loss_fn_config`: Optional configuration for the loss function
|
|
|
|
Returns:
|
|
- `APIFuture` containing the forward pass outputs and loss
|
|
|
|
Example:
|
|
```python
|
|
data = [types.Datum(
|
|
model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")),
|
|
loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))}
|
|
)]
|
|
future = training_client.forward(data, "cross_entropy")
|
|
result = await future
|
|
print(f"Loss: {result.loss}")
|
|
```
|
|
|
|
#### `forward_async`
|
|
|
|
```python
|
|
async def forward_async(
|
|
data: List[types.Datum],
|
|
loss_fn: types.LossFnType,
|
|
loss_fn_config: Dict[str, float] | None = None
|
|
) -> APIFuture[types.ForwardBackwardOutput]
|
|
```
|
|
|
|
Async version of forward.
|
|
|
|
#### `forward_backward`
|
|
|
|
```python
|
|
def forward_backward(
|
|
data: List[types.Datum],
|
|
loss_fn: types.LossFnType,
|
|
loss_fn_config: Dict[str, float] | None = None
|
|
) -> APIFuture[types.ForwardBackwardOutput]
|
|
```
|
|
|
|
Compute forward pass and backward pass to calculate gradients.
|
|
|
|
Args:
|
|
- `data`: List of training data samples
|
|
- `loss_fn`: Loss function type (e.g., "cross_entropy")
|
|
- `loss_fn_config`: Optional configuration for the loss function
|
|
|
|
Returns:
|
|
- `APIFuture` containing the forward/backward outputs, loss, and gradients
|
|
|
|
Example:
|
|
```python
|
|
data = [types.Datum(
|
|
model_input=types.ModelInput.from_ints(tokenizer.encode("Hello")),
|
|
loss_fn_inputs={"target_tokens": types.ModelInput.from_ints(tokenizer.encode("world"))}
|
|
)]
|
|
|
|
# Compute gradients
|
|
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
|
|
|
|
# Update parameters
|
|
optim_future = training_client.optim_step(
|
|
types.AdamParams(learning_rate=1e-4)
|
|
)
|
|
|
|
fwdbwd_result = await fwdbwd_future
|
|
print(f"Loss: {fwdbwd_result.loss}")
|
|
```
|
|
|
|
#### `forward_backward_async`
|
|
|
|
```python
|
|
async def forward_backward_async(
|
|
data: List[types.Datum],
|
|
loss_fn: types.LossFnType,
|
|
loss_fn_config: Dict[str, float] | None = None
|
|
) -> APIFuture[types.ForwardBackwardOutput]
|
|
```
|
|
|
|
Async version of forward_backward.
|
|
|
|
#### `forward_backward_custom`
|
|
|
|
```python
|
|
def forward_backward_custom(
|
|
data: List[types.Datum],
|
|
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
|
```
|
|
|
|
Compute forward/backward with a custom loss function.
|
|
|
|
Allows you to define custom loss functions that operate on log probabilities.
|
|
The custom function receives the data and logprobs and returns the loss together with a metrics dictionary; gradients are then computed from that loss.
|
|
|
|
Args:
|
|
- `data`: List of training data samples
|
|
- `loss_fn`: Custom loss function that takes (data, logprobs) and returns (loss, metrics)
|
|
|
|
Returns:
|
|
- `APIFuture` containing the forward/backward outputs with custom loss
|
|
|
|
Example:
|
|
```python
|
|
def custom_loss(data, logprobs_list):
|
|
# Custom loss computation
|
|
loss = torch.mean(torch.stack([torch.mean(lp) for lp in logprobs_list]))
|
|
metrics = {"custom_metric": loss.item()}
|
|
return loss, metrics
|
|
|
|
future = training_client.forward_backward_custom(data, custom_loss)
|
|
result = future.result()
|
|
print(f"Custom loss: {result.loss}")
|
|
print(f"Metrics: {result.metrics}")
|
|
```
|
|
|
|
#### `forward_backward_custom_async`
|
|
|
|
```python
|
|
async def forward_backward_custom_async(
|
|
data: List[types.Datum],
|
|
loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]
|
|
```
|
|
|
|
Async version of forward_backward_custom.
|
|
|
|
#### `optim_step`
|
|
|
|
```python
|
|
def optim_step(
|
|
adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]
|
|
```
|
|
|
|
Update model parameters using the Adam optimizer.
|
|
|
|
The Adam optimizer used by Tinker is identical
|
|
to [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
|
Note that unlike PyTorch, Tinker's default weight decay value is 0.0 (no weight decay).
|
|
|
|
|
|
Args:
|
|
- `adam_params`: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)
|
|
|
|
Returns:
|
|
- `APIFuture` containing optimizer step response
|
|
|
|
Example:
|
|
```python
|
|
# First compute gradients
|
|
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")
|
|
|
|
# Then update parameters
|
|
optim_future = training_client.optim_step(
|
|
types.AdamParams(
|
|
learning_rate=1e-4,
|
|
weight_decay=0.01
|
|
)
|
|
)
|
|
|
|
# Wait for both to complete
|
|
fwdbwd_result = await fwdbwd_future
|
|
optim_result = await optim_future
|
|
```
|
|
|
|
#### `optim_step_async`
|
|
|
|
```python
|
|
async def optim_step_async(
|
|
adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]
|
|
```
|
|
|
|
Async version of optim_step.
|
|
|
|
#### `save_state`
|
|
|
|
```python
|
|
def save_state(
|
|
name: str,
|
|
ttl_seconds: int | None = None
|
|
) -> APIFuture[types.SaveWeightsResponse]
|
|
```
|
|
|
|
Save model weights to persistent storage.
|
|
|
|
Args:
|
|
- `name`: Name for the saved checkpoint
|
|
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
|
|
|
Returns:
|
|
- `APIFuture` containing the save response with checkpoint path
|
|
|
|
Example:
|
|
```python
|
|
# Save after training
|
|
save_future = training_client.save_state("checkpoint-001")
|
|
result = await save_future
|
|
print(f"Saved to: {result.path}")
|
|
```
|
|
|
|
#### `save_state_async`
|
|
|
|
```python
|
|
async def save_state_async(
|
|
name: str,
|
|
ttl_seconds: int | None = None
|
|
) -> APIFuture[types.SaveWeightsResponse]
|
|
```
|
|
|
|
Async version of save_state.
|
|
|
|
#### `load_state`
|
|
|
|
```python
|
|
def load_state(path: str) -> APIFuture[types.LoadWeightsResponse]
|
|
```
|
|
|
|
Load model weights from a saved checkpoint.
|
|
|
|
This loads only the model weights, not optimizer state (e.g., Adam momentum).
|
|
To also restore optimizer state, use load_state_with_optimizer.
|
|
|
|
Args:
|
|
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
|
|
|
Returns:
|
|
- `APIFuture` containing the load response
|
|
|
|
Example:
|
|
```python
|
|
# Load checkpoint to continue training (weights only, optimizer resets)
|
|
load_future = training_client.load_state("tinker://run-id/weights/checkpoint-001")
|
|
await load_future
|
|
# Continue training from loaded state
|
|
```
|
|
|
|
#### `load_state_async`
|
|
|
|
```python
|
|
async def load_state_async(path: str) -> APIFuture[types.LoadWeightsResponse]
|
|
```
|
|
|
|
Async version of load_state.
|
|
|
|
#### `load_state_with_optimizer`
|
|
|
|
```python
|
|
def load_state_with_optimizer(
|
|
path: str) -> APIFuture[types.LoadWeightsResponse]
|
|
```
|
|
|
|
Load model weights and optimizer state from a checkpoint.
|
|
|
|
Args:
|
|
- `path`: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")
|
|
|
|
Returns:
|
|
- `APIFuture` containing the load response
|
|
|
|
Example:
|
|
```python
|
|
# Resume training with optimizer state
|
|
load_future = training_client.load_state_with_optimizer(
|
|
"tinker://run-id/weights/checkpoint-001"
|
|
)
|
|
await load_future
|
|
# Continue training with restored optimizer momentum
|
|
```
|
|
|
|
#### `load_state_with_optimizer_async`
|
|
|
|
```python
|
|
async def load_state_with_optimizer_async(
|
|
path: str) -> APIFuture[types.LoadWeightsResponse]
|
|
```
|
|
|
|
Async version of load_state_with_optimizer.
|
|
|
|
#### `save_weights_for_sampler`
|
|
|
|
```python
|
|
def save_weights_for_sampler(
|
|
name: str,
|
|
ttl_seconds: int | None = None
|
|
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
|
```
|
|
|
|
Save model weights for use with a SamplingClient.
|
|
|
|
Args:
|
|
- `name`: Name for the saved sampler weights
|
|
- `ttl_seconds`: Optional TTL in seconds for the checkpoint (None = never expires)
|
|
|
|
Returns:
|
|
- `APIFuture` containing the save response with sampler path
|
|
|
|
Example:
|
|
```python
|
|
# Save weights for inference
|
|
save_future = training_client.save_weights_for_sampler("sampler-001")
|
|
result = await save_future
|
|
print(f"Sampler weights saved to: {result.path}")
|
|
|
|
# Use the path to create a sampling client
|
|
sampling_client = service_client.create_sampling_client(
|
|
model_path=result.path
|
|
)
|
|
```
|
|
|
|
#### `save_weights_for_sampler_async`
|
|
|
|
```python
|
|
async def save_weights_for_sampler_async(
|
|
name: str,
|
|
ttl_seconds: int | None = None
|
|
) -> APIFuture[types.SaveWeightsForSamplerResponse]
|
|
```
|
|
|
|
Async version of save_weights_for_sampler.
|
|
|
|
#### `get_info`
|
|
|
|
```python
|
|
def get_info() -> types.GetInfoResponse
|
|
```
|
|
|
|
Get information about the current model.
|
|
|
|
Returns:
|
|
- `GetInfoResponse` with model configuration and metadata
|
|
|
|
Example:
|
|
```python
|
|
info = training_client.get_info()
|
|
print(f"Model ID: {info.model_data.model_id}")
|
|
print(f"Base model: {info.model_data.model_name}")
|
|
print(f"LoRA rank: {info.model_data.lora_rank}")
|
|
```
|
|
|
|
#### `get_info_async`
|
|
|
|
```python
|
|
async def get_info_async() -> types.GetInfoResponse
|
|
```
|
|
|
|
Async version of get_info.
|
|
|
|
#### `get_tokenizer`
|
|
|
|
```python
|
|
def get_tokenizer() -> PreTrainedTokenizer
|
|
```
|
|
|
|
Get the tokenizer for the current model.
|
|
|
|
Returns:
|
|
- `PreTrainedTokenizer` compatible with the model
|
|
|
|
Example:
|
|
```python
|
|
tokenizer = training_client.get_tokenizer()
|
|
tokens = tokenizer.encode("Hello world")
|
|
text = tokenizer.decode(tokens)
|
|
```
|
|
|
|
#### `create_sampling_client`
|
|
|
|
```python
|
|
def create_sampling_client(
|
|
model_path: str,
|
|
retry_config: RetryConfig | None = None) -> SamplingClient
|
|
```
|
|
|
|
Create a SamplingClient from saved weights.
|
|
|
|
Args:
|
|
- `model_path`: Tinker path to saved weights
|
|
- `retry_config`: Optional configuration for retrying failed requests
|
|
|
|
Returns:
|
|
- `SamplingClient` configured with the specified weights
|
|
|
|
Example:
|
|
```python
|
|
sampling_client = training_client.create_sampling_client(
|
|
"tinker://run-id/weights/checkpoint-001"
|
|
)
|
|
# Use sampling_client for inference
|
|
```
|
|
|
|
#### `create_sampling_client_async`
|
|
|
|
```python
|
|
async def create_sampling_client_async(
|
|
model_path: str,
|
|
retry_config: RetryConfig | None = None) -> SamplingClient
|
|
```
|
|
|
|
Async version of create_sampling_client.
|
|
|
|
#### `save_weights_and_get_sampling_client`
|
|
|
|
```python
|
|
def save_weights_and_get_sampling_client(
|
|
name: str | None = None,
|
|
retry_config: RetryConfig | None = None) -> SamplingClient
|
|
```
|
|
|
|
Save current weights and create a SamplingClient for inference.
|
|
|
|
Args:
|
|
- `name`: Optional name for the saved weights (currently ignored for ephemeral saves)
|
|
- `retry_config`: Optional configuration for retrying failed requests
|
|
|
|
Returns:
|
|
- `SamplingClient` configured with the current model weights
|
|
|
|
Example:
|
|
```python
|
|
# After training, create a sampling client directly
|
|
sampling_client = training_client.save_weights_and_get_sampling_client()
|
|
|
|
# Now use it for inference
|
|
prompt = types.ModelInput.from_ints(tokenizer.encode("Hello"))
|
|
params = types.SamplingParams(max_tokens=20)
|
|
result = sampling_client.sample(prompt, 1, params).result()
|
|
```
|
|
|
|
#### `save_weights_and_get_sampling_client_async`
|
|
|
|
```python
|
|
async def save_weights_and_get_sampling_client_async(
|
|
name: str | None = None,
|
|
retry_config: RetryConfig | None = None) -> SamplingClient
|
|
```
|
|
|
|
Async version of save_weights_and_get_sampling_client.
|