2026-01-25 05:52:42 +00:00


TrainingClient for the Tinker API.

TrainingClient Objects

class TrainingClient(TelemetryProvider, QueueStateObserver)

Client for training ML models with forward/backward passes and optimization.

The TrainingClient corresponds to a fine-tuned model that you can train and sample from. You typically get one by calling service_client.create_lora_training_client(). Key methods:

  • forward_backward() - compute gradients for training
  • optim_step() - update model parameters with Adam optimizer
  • save_weights_and_get_sampling_client() - export trained model for inference

Args:

  • holder: Internal client managing HTTP connections and async operations
  • model_id: Unique identifier for the model to train. Required for training operations.

Example:

import tinker
from tinker import types

service_client = tinker.ServiceClient()
training_client = service_client.create_lora_training_client(base_model="Qwen/Qwen3-8B")
# training_data is a list of types.Datum (see forward_backward below)
fwdbwd_future = training_client.forward_backward(training_data, "cross_entropy")
optim_future = training_client.optim_step(types.AdamParams(learning_rate=1e-4))
fwdbwd_result = fwdbwd_future.result()  # Wait for gradients
optim_result = optim_future.result()    # Wait for parameter update
sampling_client = training_client.save_weights_and_get_sampling_client("my-model")

forward

def forward(
    data: List[types.Datum],
    loss_fn: types.LossFnType,
    loss_fn_config: Dict[str, float] | None = None
) -> APIFuture[types.ForwardBackwardOutput]

Compute forward pass without gradients.

Args:

  • data: List of training data samples
  • loss_fn: Loss function type (e.g., "cross_entropy")
  • loss_fn_config: Optional configuration for the loss function

Returns:

  • APIFuture containing the forward pass outputs and loss

Example:

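tokenizer = training_client.get_tokenizer()
tokens = tokenizer.encode("Hello world")
# Input position i is scored against target token i + 1,
# so inputs and targets have the same length
data = [types.Datum(
    model_input=types.ModelInput.from_ints(tokens[:-1]),
    loss_fn_inputs={
        "target_tokens": tokens[1:],
        "weights": [1.0] * (len(tokens) - 1),
    },
)]
future = training_client.forward(data, "cross_entropy")
result = future.result()  # Block until the forward pass finishes
print(f"Metrics: {result.metrics}")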
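forward runs the model without computing gradients, which makes it a natural fit for evaluation. A minimal sketch of reducing per-token log-probabilities to a single loss value, assuming you have already pulled the logprobs for one datum out of the result as a plain list of floats (how they are exposed on ForwardBackwardOutput is not covered here). The helper name weighted_mean_nll is hypothetical, and the weighted mean is one common normalization, not necessarily the one Tinker's cross_entropy reports:

```python
from typing import Sequence


def weighted_mean_nll(logprobs: Sequence[float], weights: Sequence[float]) -> float:
    """Weighted mean negative log-likelihood over tokens.

    Tokens with weight 0.0 (e.g. prompt tokens) do not contribute.
    """
    if len(logprobs) != len(weights):
        raise ValueError("logprobs and weights must have the same length")
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("at least one token must have a non-zero weight")
    weighted_sum = sum(lp * w for lp, w in zip(logprobs, weights))
    return -weighted_sum / total_weight


# Two scored tokens and one masked-out prompt token
print(weighted_mean_nll([-0.5, -1.5, -3.0], [1.0, 1.0, 0.0]))  # 1.0
```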

forward_async

async def forward_async(
    data: List[types.Datum],
    loss_fn: types.LossFnType,
    loss_fn_config: Dict[str, float] | None = None
) -> APIFuture[types.ForwardBackwardOutput]

Async version of forward.

forward_backward

def forward_backward(
    data: List[types.Datum],
    loss_fn: types.LossFnType,
    loss_fn_config: Dict[str, float] | None = None
) -> APIFuture[types.ForwardBackwardOutput]

Run a forward pass and a backward pass, computing the gradients that the next optim_step applies.

Args:

  • data: List of training data samples
  • loss_fn: Loss function type (e.g., "cross_entropy")
  • loss_fn_config: Optional configuration for the loss function

Returns:

  • APIFuture containing the forward/backward outputs and loss; the gradients themselves are kept on the server and applied by the next optim_step

Example:

tokenizer = training_client.get_tokenizer()
tokens = tokenizer.encode("Hello world")
data = [types.Datum(
    model_input=types.ModelInput.from_ints(tokens[:-1]),
    loss_fn_inputs={
        "target_tokens": tokens[1:],
        "weights": [1.0] * (len(tokens) - 1),
    },
)]

# Compute gradients
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")

# Submit the parameter update before waiting on the gradients
optim_future = training_client.optim_step(
    types.AdamParams(learning_rate=1e-4)
)

fwdbwd_result = fwdbwd_future.result()
optim_result = optim_future.result()
print(f"Metrics: {fwdbwd_result.metrics}")
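Training on a text sequence usually means pairing each input position with the next token as its target. A small pure-Python sketch of that shift, with prompt tokens masked out via zero weights; next_token_example and prompt_len are hypothetical names, and the weights list assumes the cross_entropy loss accepts per-token weights alongside target_tokens:

```python
from typing import List, Tuple


def next_token_example(
    tokens: List[int], prompt_len: int = 0
) -> Tuple[List[int], List[int], List[float]]:
    """Split a token sequence into (input_tokens, target_tokens, weights).

    Input position j predicts tokens[j + 1], so inputs drop the last token
    and targets drop the first. Targets that fall inside the prompt get
    weight 0.0 so they do not contribute to the loss.
    """
    if len(tokens) < 2:
        raise ValueError("need at least two tokens to form a prediction target")
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    # target_tokens[j] is tokens[j + 1]; it is a prompt token if j + 1 < prompt_len
    weights = [0.0 if j + 1 < prompt_len else 1.0 for j in range(len(target_tokens))]
    return input_tokens, target_tokens, weights


print(next_token_example([10, 11, 12, 13], prompt_len=2))
# ([10, 11, 12], [11, 12, 13], [0.0, 1.0, 1.0])
```

Under the same assumption, the three lists would become types.ModelInput.from_ints(input_tokens) and loss_fn_inputs={"target_tokens": target_tokens, "weights": weights}.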

forward_backward_async

async def forward_backward_async(
    data: List[types.Datum],
    loss_fn: types.LossFnType,
    loss_fn_config: Dict[str, float] | None = None
) -> APIFuture[types.ForwardBackwardOutput]

Async version of forward_backward.

forward_backward_custom

def forward_backward_custom(
        data: List[types.Datum],
        loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]

Compute forward/backward with a custom loss function.

Lets you define a custom loss function that operates on log-probabilities. The function receives the training data and the per-datum log-probabilities and returns a scalar loss plus a metrics dict; gradients are computed from the returned loss.

Args:

  • data: List of training data samples
  • loss_fn: Custom loss function that takes (data, logprobs) and returns (loss, metrics)

Returns:

  • APIFuture containing the forward/backward outputs with custom loss

Example:

import torch

def custom_loss(data, logprobs_list):
    # Negative mean log-probability: minimizing it raises the likelihood of the targets
    loss = -torch.stack([lp.mean() for lp in logprobs_list]).mean()
    metrics = {"neg_mean_logprob": loss.item()}
    return loss, metrics

future = training_client.forward_backward_custom(data, custom_loss)
result = future.result()
print(f"Metrics: {result.metrics}")
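A loss passed to forward_backward_custom is a function of log-probabilities, so its gradient with respect to each logprob is what drives the update. As a plain-Python illustration (no torch; the real callback receives tensors and returns only (loss, metrics)), the negative mean log-probability has a constant per-token gradient. The function name neg_mean_logprob_loss and the extra grads return value are hypothetical, added only to make the gradient visible:

```python
from typing import Dict, List, Tuple


def neg_mean_logprob_loss(
    logprobs_list: List[List[float]],
) -> Tuple[float, Dict[str, float], List[List[float]]]:
    """Negative mean log-probability, averaged first per sequence, then over sequences.

    Returns (loss, metrics, grads), where grads[s][t] is
    d(loss) / d(logprobs_list[s][t]). Every sequence must be non-empty.
    """
    num_seqs = len(logprobs_list)
    per_seq_means = [sum(lp) / len(lp) for lp in logprobs_list]
    loss = -sum(per_seq_means) / num_seqs
    # loss = -(1/S) * sum_s (1/T_s) * sum_t lp[s][t], so each partial is -1 / (S * T_s)
    grads = [[-1.0 / (num_seqs * len(lp))] * len(lp) for lp in logprobs_list]
    metrics = {"neg_mean_logprob": loss}
    return loss, metrics, grads


loss, metrics, grads = neg_mean_logprob_loss([[-1.0, -3.0], [-2.0]])
print(loss)   # 2.0
print(grads)  # [[-0.25, -0.25], [-0.5]]
```

The negative gradients mean that increasing any target logprob lowers the loss, which is the direction a likelihood-raising loss should push.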

forward_backward_custom_async

async def forward_backward_custom_async(
        data: List[types.Datum],
        loss_fn: CustomLossFnV1) -> APIFuture[types.ForwardBackwardOutput]

Async version of forward_backward_custom.

optim_step

def optim_step(
        adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]

Update model parameters using Adam optimizer.

Tinker's Adam optimizer is identical to torch.optim.AdamW, with one difference in defaults: Tinker's default weight decay is 0.0 (no weight decay), whereas torch.optim.AdamW defaults to weight_decay=0.01.

Args:

  • adam_params: Adam optimizer parameters (learning_rate, betas, eps, weight_decay)

Returns:

  • APIFuture containing optimizer step response

Example:

# First compute gradients
fwdbwd_future = training_client.forward_backward(data, "cross_entropy")

# Then update parameters
optim_future = training_client.optim_step(
    types.AdamParams(
        learning_rate=1e-4,
        weight_decay=0.01
    )
)

# Wait for both to complete
fwdbwd_result = fwdbwd_future.result()
optim_result = optim_future.result()
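Because optim_step is documented as matching torch.optim.AdamW, the update it applies to each parameter can be sketched in plain Python. This follows PyTorch's AdamW formulation (decoupled weight decay, bias-corrected moments); the beta1, beta2 and eps defaults below are PyTorch's and are assumptions about Tinker, while weight_decay defaults to Tinker's documented 0.0. The function name adamw_step is hypothetical:

```python
import math
from typing import List, Tuple


def adamw_step(
    params: List[float],
    grads: List[float],
    m: List[float],
    v: List[float],
    step: int,
    learning_rate: float,
    beta1: float = 0.9,         # PyTorch default (assumed for Tinker)
    beta2: float = 0.999,       # PyTorch default (assumed for Tinker)
    eps: float = 1e-8,          # PyTorch default (assumed for Tinker)
    weight_decay: float = 0.0,  # Tinker's documented default; torch.optim.AdamW uses 0.01
) -> Tuple[List[float], List[float], List[float]]:
    """One AdamW update. `step` is 1-based and is used for bias correction.

    Returns (new_params, new_m, new_v).
    """
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        # Decoupled weight decay: shrink the parameter directly, not through the gradient
        p = p * (1.0 - learning_rate * weight_decay)
        # Exponential moving averages of the gradient and squared gradient
        m_i = beta1 * m_i + (1.0 - beta1) * g
        v_i = beta2 * v_i + (1.0 - beta2) * g * g
        # Bias correction compensates for the moments starting at zero
        m_hat = m_i / (1.0 - beta1 ** step)
        v_hat = v_i / (1.0 - beta2 ** step)
        p = p - learning_rate * m_hat / (math.sqrt(v_hat) + eps)
        new_params.append(p)
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


new_params, m, v = adamw_step([1.0], [0.5], [0.0], [0.0], step=1, learning_rate=0.1)
print(new_params)
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) equals the sign of the gradient, so each parameter moves by roughly learning_rate regardless of the gradient's magnitude.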

optim_step_async

async def optim_step_async(
        adam_params: types.AdamParams) -> APIFuture[types.OptimStepResponse]

Async version of optim_step.

save_state

def save_state(
        name: str,
        ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsResponse]

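Save the training state (model weights and optimizer state) to persistent storage, so training can later be resumed with load_state or load_state_with_optimizer.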

Args:

  • name: Name for the saved checkpoint
  • ttl_seconds: Optional TTL in seconds for the checkpoint (None = never expires)

Returns:

  • APIFuture containing the save response with checkpoint path

Example:

# Save after training
save_future = training_client.save_state("checkpoint-001")
result = save_future.result()
print(f"Saved to: {result.path}")

save_state_async

async def save_state_async(
        name: str,
        ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsResponse]

Async version of save_state.

load_state

def load_state(path: str) -> APIFuture[types.LoadWeightsResponse]

Load model weights from a saved checkpoint.

This loads only the model weights, not optimizer state (e.g., Adam momentum). To also restore optimizer state, use load_state_with_optimizer.

Args:

  • path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")

Returns:

  • APIFuture containing the load response

Example:

# Load checkpoint to continue training (weights only, optimizer resets)
load_future = training_client.load_state("tinker://run-id/weights/checkpoint-001")
load_future.result()
# Continue training from loaded state

load_state_async

async def load_state_async(path: str) -> APIFuture[types.LoadWeightsResponse]

Async version of load_state.

load_state_with_optimizer

def load_state_with_optimizer(
        path: str) -> APIFuture[types.LoadWeightsResponse]

Load model weights and optimizer state from a checkpoint.

Args:

  • path: Tinker path to saved weights (e.g., "tinker://run-id/weights/checkpoint-001")

Returns:

  • APIFuture containing the load response

Example:

# Resume training with optimizer state
load_future = training_client.load_state_with_optimizer(
    "tinker://run-id/weights/checkpoint-001"
)
load_future.result()
# Continue training with restored optimizer momentum

load_state_with_optimizer_async

async def load_state_with_optimizer_async(
        path: str) -> APIFuture[types.LoadWeightsResponse]

Async version of load_state_with_optimizer.

save_weights_for_sampler

def save_weights_for_sampler(
    name: str,
    ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsForSamplerResponse]

Save model weights for use with a SamplingClient.

Args:

  • name: Name for the saved sampler weights
  • ttl_seconds: Optional TTL in seconds for the checkpoint (None = never expires)

Returns:

  • APIFuture containing the save response with sampler path

Example:

# Save weights for inference
save_future = training_client.save_weights_for_sampler("sampler-001")
result = save_future.result()
print(f"Sampler weights saved to: {result.path}")

# Use the path to create a sampling client
sampling_client = service_client.create_sampling_client(
    model_path=result.path
)

save_weights_for_sampler_async

async def save_weights_for_sampler_async(
    name: str,
    ttl_seconds: int | None = None
) -> APIFuture[types.SaveWeightsForSamplerResponse]

Async version of save_weights_for_sampler.

get_info

def get_info() -> types.GetInfoResponse

Get information about the current model.

Returns:

  • GetInfoResponse with model configuration and metadata

Example:

info = training_client.get_info()
print(f"Model ID: {info.model_data.model_id}")
print(f"Base model: {info.model_data.model_name}")
print(f"LoRA rank: {info.model_data.lora_rank}")

get_info_async

async def get_info_async() -> types.GetInfoResponse

Async version of get_info.

get_tokenizer

def get_tokenizer() -> PreTrainedTokenizer

Get the tokenizer for the current model.

Returns:

  • PreTrainedTokenizer compatible with the model

Example:

tokenizer = training_client.get_tokenizer()
tokens = tokenizer.encode("Hello world")
text = tokenizer.decode(tokens)

create_sampling_client

def create_sampling_client(
        model_path: str,
        retry_config: RetryConfig | None = None) -> SamplingClient

Create a SamplingClient from saved weights.

Args:

  • model_path: Tinker path to sampler weights, such as the path returned by save_weights_for_sampler
  • retry_config: Optional configuration for retrying failed requests

Returns:

  • SamplingClient configured with the specified weights

Example:

save_result = training_client.save_weights_for_sampler("sampler-001").result()
sampling_client = training_client.create_sampling_client(save_result.path)
# Use sampling_client for inference

create_sampling_client_async

async def create_sampling_client_async(
        model_path: str,
        retry_config: RetryConfig | None = None) -> SamplingClient

Async version of create_sampling_client.

save_weights_and_get_sampling_client

def save_weights_and_get_sampling_client(
        name: str | None = None,
        retry_config: RetryConfig | None = None) -> SamplingClient

Save current weights and create a SamplingClient for inference.

Args:

  • name: Optional name for the saved weights (currently ignored for ephemeral saves)
  • retry_config: Optional configuration for retrying failed requests

Returns:

  • SamplingClient configured with the current model weights

Example:

# After training, create a sampling client directly
sampling_client = training_client.save_weights_and_get_sampling_client()

# Now use it for inference
tokenizer = training_client.get_tokenizer()
prompt = types.ModelInput.from_ints(tokenizer.encode("Hello"))
params = types.SamplingParams(max_tokens=20)
result = sampling_client.sample(prompt, 1, params).result()

save_weights_and_get_sampling_client_async

async def save_weights_and_get_sampling_client_async(
        name: str | None = None,
        retry_config: RetryConfig | None = None) -> SamplingClient

Async version of save_weights_and_get_sampling_client.