atropos/example_trainer/grpo.py
Jai Suphavadeeprasit 80d2608c4e basic changes
2026-03-02 11:18:52 -05:00

1611 lines
54 KiB
Python

import argparse
import atexit
import json
import math
import os
import random
import shutil
import string
import subprocess
import time
from typing import List, Literal, Optional, Tuple
import numpy as np
import requests
import torch
import torch.nn.functional as F
import wandb # Added for logging
from pydantic import BaseModel, Field
from tenacity import retry, stop_after_attempt, wait_exponential
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer
# Import weight bridge for shared vLLM mode
try:
from example_trainer.vllm_weight_bridge import (
BridgeConfig,
VLLMWeightBridge,
create_bridge_from_training_config,
)
BRIDGE_AVAILABLE = True
except ImportError:
BRIDGE_AVAILABLE = False
# Import PEFT for LoRA training
try:
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
PEFT_AVAILABLE = True
except ImportError:
PEFT_AVAILABLE = False
# Global variable to keep track of the vLLM process
vllm_process = None
def cleanup_vllm():
global vllm_process
if vllm_process:
print("\nTerminating vLLM process...")
vllm_process.terminate()
try:
vllm_process.wait(timeout=5) # Wait a bit for graceful shutdown
print("vLLM process terminated.")
except subprocess.TimeoutExpired:
print("vLLM process did not terminate gracefully, killing.")
vllm_process.kill()
vllm_process.wait()
print("vLLM process killed.")
vllm_process = None
# Register the cleanup function to be called on script exit
atexit.register(cleanup_vllm)
class TrainingConfig(BaseModel):
"""
Training details, model, etc
"""
model_name: str = Field(..., description="Name of the base model to train")
lr: float = Field(1e-5, description="Learning rate for the optimizer")
training_steps: int = Field(
10, description="Number of training steps"
) # Renamed from epochs
batch_size: int = Field(
2, description="Batch size for training (will be handled by get_data)"
)
seq_len: int = Field(2048, description="Sequence length for training")
gradient_accumulation_steps: int = Field(
32, description="Number of gradient accumulation steps"
)
device: str = Field(
"cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
)
save_path: str = Field(
"trained_model_checkpoints", description="Base path to save model checkpoints"
)
vllm_restart_interval: int = Field(
3, description="Restart vLLM every N training steps"
)
vllm_port: int = Field(9001, description="Port for the vLLM server")
vllm_gpu_memory_utilization: float = Field(
0.45, description="GPU memory utilization for vLLM server (0.0-1.0)"
)
# Wandb configuration
use_wandb: bool = Field(
False, description="Whether to use Weights & Biases for logging"
)
wandb_project: Optional[str] = Field(None, description="Wandb project name")
wandb_group: Optional[str] = Field(None, description="Wandb group name")
# Pipeline / weight bridge configuration
weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field(
"none",
description=(
"How to synchronize weights with inference server. "
"'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
"'lora_only': keep base model frozen, train/swap LoRA adapters. "
"'none': legacy mode, restart vLLM with new checkpoint files."
),
)
trainer_rank: int = Field(
0,
description="Rank of this trainer in the distributed group (for shared_vllm mode)",
)
world_size: int = Field(
1,
description="Total processes in the distributed group (for shared_vllm mode)",
)
init_method: str = Field(
"env://",
description=(
"PyTorch distributed init method URL. "
"Use 'env://' to read MASTER_ADDR/MASTER_PORT from environment, "
"or 'tcp://host:port' for explicit rendezvous."
),
)
num_inference_nodes: int = Field(
0,
description=(
"Number of inference nodes (vLLM servers) to coordinate with. "
"0 means single-node local mode."
),
)
# LoRA configuration (for lora_only mode)
lora_r: int = Field(16, description="LoRA rank (dimension of low-rank matrices)")
lora_alpha: int = Field(32, description="LoRA alpha (scaling factor)")
lora_dropout: float = Field(0.05, description="Dropout probability for LoRA layers")
lora_target_modules: Optional[List[str]] = Field(
None,
description=(
"List of module names to apply LoRA to. "
"If None, defaults to ['q_proj', 'v_proj'] for most models."
),
)
# Shared memory mode (for shared_vllm mode - NCCL weight broadcast)
use_shared_memory: bool = Field(
False,
description=(
"Enable shared memory weight updates via NCCL. "
"vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1. "
"Weight updates are broadcast to vLLM's daemon process."
),
)
def check_atropos_api(timeout: float = 30.0) -> bool:
"""
Check if the Atropos API server is reachable.
Args:
timeout: Maximum time to wait for the server
Returns:
True if server is reachable
"""
import time as _time
start = _time.time()
while _time.time() - start < timeout:
try:
response = requests.get("http://localhost:8000/info", timeout=2)
if response.status_code == 200:
print("[Trainer] ✓ Atropos API server is reachable")
return True
except requests.exceptions.ConnectionError:
pass
except Exception as e:
print(f"[Trainer] Waiting for Atropos API... ({e})")
_time.sleep(1)
print("[Trainer] ⚠ Warning: Atropos API server not reachable")
return False
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
def register_trainer(config: TrainingConfig):
"""
Register the trainer with the Atropos API.
Verifies registration succeeded before returning.
"""
response = requests.post(
"http://localhost:8000/register",
json={
# wandb fields are required strings - use empty string if None
"wandb_group": config.wandb_group or "",
"wandb_project": config.wandb_project or "",
"batch_size": config.batch_size * config.gradient_accumulation_steps,
"max_token_len": config.seq_len,
"starting_step": 0,
"checkpoint_dir": config.save_path,
"save_checkpoint_interval": config.training_steps,
"num_steps": config.training_steps,
},
timeout=10,
)
# Check for HTTP errors
response.raise_for_status()
# Verify we got a valid response with UUID
data = response.json()
if "uuid" not in data:
raise RuntimeError(f"Registration failed: {data}")
print(f"[Trainer] ✓ Registered with Atropos API (uuid: {data['uuid']})")
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
def get_batch():
data = requests.get("http://localhost:8000/batch", timeout=10).json()
# Check if there was an error (trainer not registered)
if data.get("status") == "error":
raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}")
return data
def pad_data_to_good_offset(data, batch_size: int):
max_token_len = max(
[max([len(x) for x in item["tokens"]]) for item in data["batch"]]
)
# usually 64 is a good choice to ensure nonweird scaling behavior on GPUS
# so we pad to the nearest multiple of 64
good_multiple = 64
if (max_token_len - 1) % (good_multiple) != 0:
max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple
token_setup_len = (
max_token_len + 1
) # add 1 so we can make it causal at the proper length
else:
token_setup_len = max_token_len
max_token_len = (
max_token_len - 1
) # since it's causal we need to remove the last bit...
# pad all tokens to max_token_len and add to lists
input_ids = list()
labels = list()
advantages = list()
lengths = list()
temperatures = list()
for item in data["batch"]:
scores = item["scores"]
scores = np.array(scores)
# check if we have more than 1 score...
if len(scores) > 1:
scores = scores - scores.mean()
scores = scores / max(scores.std(), 1e-8)
item["scores"] = scores
if item["overrides"] is not None:
for i in range(len(item["overrides"])):
if item["overrides"][i].get("set_advantage_to_zero", False):
item["scores"][i] = 0
for i in range(len(item["tokens"])):
lengths.append(
math.ceil((len(item["tokens"][i]) - 1) / (good_multiple))
* good_multiple
)
label_item = np.concatenate(
[
np.array(item["masks"][i]),
np.full(
max(0, token_setup_len - len(item["tokens"][i])),
-100,
dtype=np.int32,
),
]
)
item["tokens"][i] = np.concatenate(
[
np.array(item["tokens"][i]),
np.zeros(
max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32
),
]
)
input_ids.append(item["tokens"][i][:-1])
labels.append(label_item[1:])
advantages.append(item["scores"][i])
# per-sample override -> group generation_params -> group_overrides - > 1.0
# need to update docs since this lets you set the temperature for each sample from the override
t = 1.0
if (
item.get("overrides")
and i < len(item["overrides"])
and isinstance(item["overrides"][i], dict)
and ("temperature" in item["overrides"][i])
):
t = float(item["overrides"][i]["temperature"])
elif item.get("generation_params") and (
"temperature" in item["generation_params"]
):
t = float(item["generation_params"]["temperature"])
elif item.get("group_overrides") and (
"temperature" in item["group_overrides"]
):
t = float(item["group_overrides"]["temperature"])
temperatures.append(t)
# combine all lists into tensors
token_batches = []
label_batches = []
advantage_batches = []
temperature_batches = []
for i in range(len(input_ids) // batch_size):
token_batches.append(
torch.tensor(
np.stack(input_ids[i * batch_size : (i + 1) * batch_size], axis=0)
)
)
label_batches.append(
torch.tensor(
np.stack(labels[i * batch_size : (i + 1) * batch_size], axis=0)
)
)
advantage_batches.append(
torch.tensor(
np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0)
).view(-1, 1)
)
# Temperatures: one per sample, shaped for broadcasting to [B, 1, 1]
temperature_batches.append(
torch.tensor(
np.array(
temperatures[i * batch_size : (i + 1) * batch_size],
dtype=np.float32,
)
).view(-1, 1, 1)
)
return token_batches, label_batches, advantage_batches, temperature_batches
def get_data(
batch_size: int, seq_len: int
) -> List[
Tuple[
List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
]
]:
"""
getting data from the api
"""
batches = []
while True:
data = get_batch()
if data["batch"] is not None:
# Save the batch
with open("temp.json", "w", encoding="utf-8") as f:
json.dump(data, f)
# In case the inference runs ahead of the training, we loop until we don't have any more data
batches.append(pad_data_to_good_offset(data, batch_size))
elif len(batches) > 0:
# Return the batches
return batches
else:
time.sleep(1)
# =============================================================================
# Common Training Helpers (shared across all modes)
# =============================================================================
def setup_wandb(config: TrainingConfig) -> bool:
"""
Initialize Weights & Biases logging if enabled.
Args:
config: Training configuration
Returns:
True if wandb is active, False otherwise
"""
if not config.use_wandb:
return False
if not config.wandb_project:
print("Warning: wandb_project not set, disabling wandb.")
return False
# Generate random group name if not provided
if not config.wandb_group:
config.wandb_group = "".join(
random.choices(string.ascii_letters + string.digits, k=8)
)
try:
wandb.init(
project=config.wandb_project,
group=config.wandb_group,
config=config.dict(),
)
print(
f"Wandb logging enabled. Run: {wandb.run.name} "
f"(Project: {config.wandb_project})"
)
return True
except Exception as e:
print(f"Error initializing wandb: {e}. Disabling wandb.")
return False
def load_model_and_tokenizer(
config: TrainingConfig,
bridge: Optional["VLLMWeightBridge"] = None,
) -> Tuple[torch.nn.Module, "AutoTokenizer"]:
"""
Load or attach to model based on weight_bridge_mode.
Args:
config: Training configuration
bridge: Optional weight bridge for shared_vllm mode
Returns:
Tuple of (model, tokenizer)
"""
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
if config.weight_bridge_mode == "shared_vllm" and bridge is not None:
# Shared vLLM mode: load model, weights will be broadcast via NCCL
print("[Setup] Loading model for shared vLLM mode...")
if config.use_shared_memory:
print("[Setup] NCCL shared memory mode - updates broadcast to vLLM daemon")
else:
print("[Setup] HTTP notification mode - vLLM notified of updates")
model = AutoModelForCausalLM.from_pretrained(
config.model_name, torch_dtype=torch.bfloat16
)
model.to(config.device)
elif config.weight_bridge_mode == "lora_only":
model = _load_model_with_lora(config)
else:
print("[Setup] Loading model for legacy mode...")
model = AutoModelForCausalLM.from_pretrained(
config.model_name, torch_dtype=torch.bfloat16
)
model.to(config.device)
# Enable gradient checkpointing (saves memory)
# For LoRA, use PEFT's method; for others, use standard method
if config.weight_bridge_mode == "lora_only":
# PEFT models need gradient_checkpointing enabled on base model
# and require use_reentrant=False for proper gradient flow
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
else:
# Standard gradient checkpointing
model.gradient_checkpointing_enable()
model.train()
return model, tokenizer
def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
"""
Load base model and wrap with LoRA adapters.
Args:
config: Training configuration with LoRA settings
Returns:
PEFT model with LoRA adapters applied
"""
if not PEFT_AVAILABLE:
raise RuntimeError(
"PEFT library not available. Install with: pip install peft"
)
print("[Setup] Loading base model for LoRA mode...")
base_model = AutoModelForCausalLM.from_pretrained(
config.model_name, torch_dtype=torch.bfloat16
)
base_model.to(config.device)
# Determine target modules
target_modules = config.lora_target_modules
if target_modules is None:
# Default modules for most transformer models
target_modules = ["q_proj", "v_proj"]
print(f"Applying LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
print(f"Target modules: {target_modules}")
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
target_modules=target_modules,
bias="none",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
return model
def save_lora_checkpoint(
model: torch.nn.Module,
save_path: str,
step: int,
is_final: bool = False,
) -> str:
"""
Save LoRA adapter checkpoint.
Args:
model: PEFT model with LoRA adapters
save_path: Base directory for checkpoints
step: Current training step
is_final: Whether this is the final checkpoint
Returns:
Path where adapter was saved
"""
if is_final:
adapter_path = os.path.join(save_path, "final_adapter")
else:
adapter_path = os.path.join(save_path, f"adapter_step_{step}")
print(f" Saving LoRA adapter to {adapter_path}...")
if os.path.exists(adapter_path):
shutil.rmtree(adapter_path)
os.makedirs(adapter_path, exist_ok=True)
# Save only the adapter weights (much smaller than full model)
model.save_pretrained(adapter_path)
print(" Adapter saved.")
return adapter_path
def compute_grpo_loss(
model: torch.nn.Module,
tokens: torch.Tensor,
labels: torch.Tensor,
advantages: torch.Tensor,
temperatures: torch.Tensor,
gradient_accumulation_steps: int,
) -> Tuple[torch.Tensor, dict]:
"""
Compute GRPO loss for a single micro-batch.
Args:
model: The model to compute loss for
tokens: Input token IDs [batch, seq_len]
labels: Target labels [batch, seq_len]
advantages: Advantage values [batch, 1]
temperatures: Temperature values [batch, 1, 1]
gradient_accumulation_steps: Number of accumulation steps
Returns:
Tuple of (loss tensor, metrics dict)
"""
# Forward pass
outputs = model(tokens)
logits = outputs.logits
# Temperature scaling
t = temperatures.to(logits.device, logits.dtype)
t = torch.where(t <= 0, torch.ones_like(t), t)
logits = logits / t
# Log probabilities per token
logp_per_token = -F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
reduction="none",
ignore_index=-100,
).view(labels.shape)
# Masking based on labels != -100
mask = (labels != -100).float()
# Compute metrics (no grad needed)
with torch.no_grad():
pos = (advantages > 0).float()
neg = (advantages <= 0).float()
mask_float = mask.to(logp_per_token.dtype)
mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8)
avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
pos_logp = (logp_per_token * pos).mean().item()
neg_logp = (logp_per_token * neg).mean().item()
# GRPO loss
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
grpo_loss = (
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
* advantages.to(logp_per_token.device)
).mean() / gradient_accumulation_steps
metrics = {
"pos_logp": pos_logp,
"neg_logp": neg_logp,
"avg_logp": avg_logp,
"pos_count": pos.sum().item(),
"neg_count": neg.sum().item(),
}
return grpo_loss, metrics
def run_training_step(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
token_batches: List[torch.Tensor],
label_batches: List[torch.Tensor],
advantage_batches: List[torch.Tensor],
temperature_batches: List[torch.Tensor],
config: TrainingConfig,
) -> dict:
"""
Run a single training step (forward, backward, optimizer step).
Args:
model: The model to train
optimizer: The optimizer
token_batches: List of token tensors
label_batches: List of label tensors
advantage_batches: List of advantage tensors
temperature_batches: List of temperature tensors
config: Training configuration
Returns:
Dict of training metrics for this step
"""
total_loss = 0.0
total_pos_logp = 0.0
total_neg_logp = 0.0
total_pos = 0.0
total_neg = 0.0
# Accumulate gradients over micro-batches
for tokens, labels, advantages, temperatures in zip(
token_batches, label_batches, advantage_batches, temperature_batches
):
tokens = tokens.to(config.device)
labels = labels.to(config.device)
advantages = advantages.to(config.device)
loss, metrics = compute_grpo_loss(
model, tokens, labels, advantages, temperatures,
config.gradient_accumulation_steps
)
loss.backward()
total_loss += loss.item()
total_pos_logp += metrics["pos_logp"]
total_neg_logp += metrics["neg_logp"]
total_pos += metrics["pos_count"]
total_neg += metrics["neg_count"]
# Gradient clipping and optimizer step
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
# Normalize metrics
if total_pos > 0:
total_pos_logp /= total_pos
if total_neg > 0:
total_neg_logp /= total_neg
return {
"loss": total_loss,
"grad_norm": grad_norm.item(),
"pos_logp": total_pos_logp,
"neg_logp": total_neg_logp,
}
def save_checkpoint(
model: torch.nn.Module,
tokenizer: "AutoTokenizer",
save_path: str,
step: int,
is_final: bool = False,
) -> str:
"""
Save model checkpoint.
Args:
model: Model to save
tokenizer: Tokenizer to save
save_path: Base directory for checkpoints
step: Current training step
is_final: Whether this is the final checkpoint
Returns:
Path where checkpoint was saved
"""
if is_final:
checkpoint_path = os.path.join(save_path, "final_model")
else:
checkpoint_path = os.path.join(save_path, f"step_{step}")
print(f" Saving checkpoint to {checkpoint_path}...")
if os.path.exists(checkpoint_path):
shutil.rmtree(checkpoint_path)
os.makedirs(checkpoint_path, exist_ok=True)
model.save_pretrained(checkpoint_path)
tokenizer.save_pretrained(checkpoint_path)
print(" Checkpoint saved.")
return checkpoint_path
def log_metrics(
metrics: dict,
step: int,
use_wandb: bool,
extra_metrics: Optional[dict] = None,
) -> None:
"""
Log training metrics to console and optionally wandb.
Args:
metrics: Dict of metrics from training step
step: Current step number
use_wandb: Whether to log to wandb
extra_metrics: Optional additional metrics to log
"""
# Console output with timing info
timing_str = ""
if "step_time" in metrics:
timing_str += f", Step time: {metrics['step_time']:.2f}s"
if "sync_time" in metrics and metrics["sync_time"] > 0:
timing_str += f", Sync time: {metrics['sync_time']:.2f}s"
if "data_fetch_time" in metrics:
timing_str += f", Data fetch: {metrics['data_fetch_time']:.2f}s"
if "gpu_memory_gb" in metrics:
timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB"
print(f" Loss: {metrics['loss']:.4f}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")
if use_wandb:
log_dict = {
"train/loss": metrics["loss"],
"train/grad_norm": metrics["grad_norm"],
"train/pos_logp": metrics["pos_logp"],
"train/neg_logp": metrics["neg_logp"],
}
# Add timing metrics if present
if "step_time" in metrics:
log_dict["train/step_time"] = metrics["step_time"]
if "sync_time" in metrics:
log_dict["train/sync_time"] = metrics["sync_time"]
if "data_fetch_time" in metrics:
log_dict["train/data_fetch_time"] = metrics["data_fetch_time"]
if "gpu_memory_gb" in metrics:
log_dict["train/gpu_memory_gb"] = metrics["gpu_memory_gb"]
if "gpu_memory_reserved_gb" in metrics:
log_dict["train/gpu_memory_reserved_gb"] = metrics["gpu_memory_reserved_gb"]
if extra_metrics:
log_dict.update(extra_metrics)
wandb.log(log_dict, step=step)
def finalize_training(
use_wandb: bool,
training_start_time: Optional[float] = None,
mode: str = "unknown",
total_steps: int = 0,
benchmark_stats: Optional[dict] = None,
) -> None:
"""Clean up after training and log benchmark summary.
Args:
use_wandb: Whether wandb is enabled
training_start_time: Start time of training
mode: Training mode name
total_steps: Total steps completed
benchmark_stats: Dict with lists of per-step metrics:
- step_times: List of step durations
- sync_times: List of sync durations
- data_fetch_times: List of data fetch durations
- gpu_memories: List of GPU memory readings (GB)
"""
print("\nTraining finished.")
# Default empty stats
if benchmark_stats is None:
benchmark_stats = {}
# Log benchmark summary
if training_start_time is not None:
total_time = time.time() - training_start_time
peak_gpu_mem_gb = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0
# Calculate averages from collected stats
step_times = benchmark_stats.get("step_times", [])
sync_times = benchmark_stats.get("sync_times", [])
data_fetch_times = benchmark_stats.get("data_fetch_times", [])
gpu_memories = benchmark_stats.get("gpu_memories", [])
avg_step_time = sum(step_times) / len(step_times) if step_times else 0
total_step_time = sum(step_times)
avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0
total_sync_time = sum(sync_times)
avg_data_fetch = sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0
total_data_fetch = sum(data_fetch_times)
avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0
print(f"\n{'='*70}")
print(f"BENCHMARK SUMMARY ({mode})")
print(f"{'='*70}")
print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
print(f" Total steps: {total_steps}")
print(f" ")
print(f" TIMING BREAKDOWN:")
print(f" Avg step time: {avg_step_time:.2f}s")
print(f" Total step time: {total_step_time:.2f}s")
print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)")
print(f" Total sync time: {total_sync_time:.2f}s")
print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
print(f" Total data fetch time: {total_data_fetch:.2f}s")
print(f" ")
print(f" MEMORY:")
print(f" Peak GPU memory: {peak_gpu_mem_gb:.2f} GB")
print(f" Avg GPU memory: {avg_gpu_mem:.2f} GB")
print(f"{'='*70}\n")
if use_wandb:
# Total time metrics
wandb.summary["benchmark/total_time_seconds"] = total_time
wandb.summary["benchmark/total_time_minutes"] = total_time / 60
wandb.summary["benchmark/mode"] = mode
wandb.summary["benchmark/total_steps"] = total_steps
# Step timing metrics
wandb.summary["benchmark/avg_step_time_seconds"] = avg_step_time
wandb.summary["benchmark/total_step_time_seconds"] = total_step_time
# Sync timing metrics
wandb.summary["benchmark/avg_sync_time_seconds"] = avg_sync_time
wandb.summary["benchmark/total_sync_time_seconds"] = total_sync_time
wandb.summary["benchmark/num_syncs"] = len(sync_times)
# Data fetch timing metrics
wandb.summary["benchmark/avg_data_fetch_time_seconds"] = avg_data_fetch
wandb.summary["benchmark/total_data_fetch_time_seconds"] = total_data_fetch
# Memory metrics
wandb.summary["benchmark/peak_gpu_memory_gb"] = peak_gpu_mem_gb
wandb.summary["benchmark/avg_gpu_memory_gb"] = avg_gpu_mem
if use_wandb:
wandb.finish()
def train(config: TrainingConfig):
"""
Legacy GRPO training with periodic vLLM restarts.
This mode saves checkpoints to disk and restarts vLLM to pick up new weights.
Use weight_bridge_mode='shared_vllm' for in-place weight updates without restarts.
"""
global vllm_process
training_start_time = time.time()
# === Setup ===
use_wandb = setup_wandb(config)
model, tokenizer = load_model_and_tokenizer(config)
optimizer = AdamW(model.parameters(), lr=config.lr)
print(f"\n{'='*60}")
print("LEGACY MODE (checkpoint + vLLM restart)")
print(f"{'='*60}")
print(f"Training for {config.training_steps} steps on {config.device}")
print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
print(f"Save path: {config.save_path}")
print(f"{'='*60}\n")
os.makedirs(config.save_path, exist_ok=True)
register_trainer(config)
# Launch initial vLLM server
vllm_process = _launch_vllm_server(config, config.model_name)
# === Benchmark tracking ===
benchmark_stats = {
"step_times": [],
"sync_times": [],
"data_fetch_times": [],
"gpu_memories": [],
}
# === Training Loop ===
batches = []
for step in range(config.training_steps):
print(f"\nStep {step+1}/{config.training_steps}")
# Track data fetch time
data_fetch_start = time.time()
if len(batches) == 0:
batches = get_data(config.batch_size, config.seq_len)
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
# Terminate vLLM before training step (to free GPU memory)
should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
if should_sync:
_terminate_vllm_process()
# Track step time
step_start = time.time()
# Run training step using common helper
metrics = run_training_step(
model, optimizer,
token_batches, label_batches, advantage_batches, temperature_batches,
config
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
# Track GPU memory
if torch.cuda.is_available():
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
else:
gpu_mem_gb = 0
gpu_mem_reserved_gb = 0
# Track sync time
sync_time = 0
if should_sync:
sync_start = time.time()
checkpoint_path = save_checkpoint(model, tokenizer, config.save_path, step + 1)
torch.cuda.empty_cache()
vllm_process = _launch_vllm_server(config, checkpoint_path)
sync_time = time.time() - sync_start
benchmark_stats["sync_times"].append(sync_time)
# Add timing metrics
metrics["step_time"] = step_time
metrics["sync_time"] = sync_time
metrics["data_fetch_time"] = data_fetch_time
metrics["gpu_memory_gb"] = gpu_mem_gb
metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb
# Log metrics
log_metrics(metrics, step + 1, use_wandb, {
"train/learning_rate": optimizer.param_groups[0]["lr"],
})
# Check for unexpected vLLM termination
_check_vllm_process_health()
# === Cleanup ===
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats)
# =============================================================================
# vLLM Process Management (Legacy Mode Only)
# =============================================================================
def _launch_vllm_server(config: TrainingConfig, model_path: str) -> Optional[subprocess.Popen]:
"""Launch a vLLM server process using our custom vllm_api_server.py.
Uses the custom server instead of standard vLLM because:
- Standard vLLM only has /v1/completions (OpenAI-compatible)
- Our custom server has /generate endpoint needed by VLLMServer class
- This allows proper tokens_and_logprobs_completion support
"""
global vllm_process
# Use our custom vllm_api_server.py instead of standard vLLM
# This provides the /generate endpoint that VLLMServer needs
script_dir = os.path.dirname(os.path.abspath(__file__))
custom_server_path = os.path.join(script_dir, "vllm_api_server.py")
vllm_command = [
"python", custom_server_path,
"--model", model_path,
"--port", str(config.vllm_port),
"--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
]
# Add served-model-name if using checkpoint path
if model_path != config.model_name:
vllm_command.extend(["--served-model-name", config.model_name])
print(f" Launching vLLM: {' '.join(vllm_command)}")
try:
proc = subprocess.Popen(vllm_command)
print(f" vLLM launched with PID: {proc.pid}")
# Check for immediate startup errors
try:
proc.communicate(timeout=2)
if proc.returncode is not None and proc.returncode != 0:
print(" WARNING: vLLM failed to start")
return None
except subprocess.TimeoutExpired:
print(" vLLM process started (check logs for details)")
return proc
except FileNotFoundError:
print(" ERROR: vLLM not found. Is it installed?")
return None
except Exception as e:
print(f" ERROR launching vLLM: {e}")
return None
def _terminate_vllm_process() -> None:
"""Terminate the running vLLM process if any."""
global vllm_process
if vllm_process is None:
return
print(" Terminating vLLM process...")
vllm_process.terminate()
try:
vllm_process.wait(timeout=5)
except subprocess.TimeoutExpired:
print(" vLLM did not terminate gracefully, killing...")
vllm_process.kill()
vllm_process.wait()
vllm_process = None
def _check_vllm_process_health() -> None:
"""Check if vLLM process terminated unexpectedly (legacy mode)."""
global vllm_process
if vllm_process is not None and vllm_process.poll() is not None:
print(f" WARNING: vLLM terminated unexpectedly (code: {vllm_process.returncode})")
vllm_process = None
def train_shared_vllm(config: TrainingConfig):
"""
GRPO training with shared vLLM weights.
Instead of saving checkpoints and restarting vLLM, this mode:
1. Joins the same distributed group as vLLM
2. Attaches to vLLM's weight tensors directly
3. optimizer.step() modifies vLLM's weights in-place
4. vLLM immediately uses updated weights (no restart!)
"""
if not BRIDGE_AVAILABLE:
raise RuntimeError(
"vLLM weight bridge not available. "
"Ensure vllm_weight_bridge.py is in the same directory."
)
training_start_time = time.time()
# === Setup ===
use_wandb = setup_wandb(config)
print(f"\n{'='*60}")
if config.use_shared_memory:
print("SHARED VLLM MODE (NCCL BROADCAST)")
print(">>> Weights broadcast to vLLM via NCCL!")
else:
print("SHARED VLLM MODE (HTTP notifications)")
print(f"{'='*60}")
print(f"Model: {config.model_name}")
print(f"Shared Memory: {config.use_shared_memory}")
print(f"Distributed: rank={config.trainer_rank}/{config.world_size}")
print(f"Init method: {config.init_method}")
print(f"Inference nodes: {config.num_inference_nodes}")
print(f"Save path: {config.save_path}")
print(f"{'='*60}\n")
# Initialize weight bridge
print("[1/3] Initializing weight bridge...")
bridge = create_bridge_from_training_config(config)
# Load model with bridge attachment
print("[2/3] Loading model with shared weights...")
model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
optimizer = AdamW(model.parameters(), lr=config.lr)
# For NCCL mode, build mapping between trainer's and vLLM's param names
if config.use_shared_memory:
bridge.build_param_mapping(model)
print(f"[3/3] Starting training for {config.training_steps} steps")
print("NOTE: vLLM sees weight updates immediately after each step!")
print("-" * 60)
os.makedirs(config.save_path, exist_ok=True)
# Check Atropos API and register BEFORE training loop
print("\n[Setup] Connecting to Atropos API...")
if not check_atropos_api(timeout=30):
raise RuntimeError(
"Atropos API server not reachable. "
"Please start it with: run-api"
)
register_trainer(config)
# === Benchmark tracking ===
benchmark_stats = {
"step_times": [],
"sync_times": [], # For shared mode, this is the notify_update time
"data_fetch_times": [],
"gpu_memories": [],
}
# === Training Loop ===
batches = []
for step in range(config.training_steps):
print(f"\nStep {step+1}/{config.training_steps}")
# Track data fetch time
data_fetch_start = time.time()
if len(batches) == 0:
batches = get_data(config.batch_size, config.seq_len)
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
# Track step time
step_start = time.time()
# Run training step using common helper
metrics = run_training_step(
model, optimizer,
token_batches, label_batches, advantage_batches, temperature_batches,
config
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
# Track GPU memory
if torch.cuda.is_available():
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
else:
gpu_mem_gb = 0
gpu_mem_reserved_gb = 0
# Sync weights with vLLM
sync_start = time.time()
if config.use_shared_memory:
# NCCL broadcast mode - weights sent directly to vLLM daemon
bridge.broadcast_weights(model)
print(f" [SHARED] Weights broadcast via NCCL - step {step+1} (sync: {(time.time()-sync_start)*1000:.1f}ms)")
else:
# HTTP notification mode - just notify
bridge.notify_update()
print(f" [SHARED] Update notification sent - step {step+1}")
sync_time = time.time() - sync_start
benchmark_stats["sync_times"].append(sync_time)
# Add timing metrics
metrics["step_time"] = step_time
metrics["sync_time"] = sync_time
metrics["data_fetch_time"] = data_fetch_time
metrics["gpu_memory_gb"] = gpu_mem_gb
metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb
# Log metrics
log_metrics(metrics, step + 1, use_wandb, {
"train/learning_rate": optimizer.param_groups[0]["lr"],
"bridge/update_count": step + 1,
})
# Periodic checkpoint save (for recovery, not for vLLM sync)
if (step + 1) % config.vllm_restart_interval == 0:
save_checkpoint(model, tokenizer, config.save_path, step + 1)
# === Cleanup ===
bridge.cleanup()
save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats)
def _check_vllm_health(port: int) -> bool:
"""Check if external vLLM server is running and healthy."""
try:
response = requests.get(f"http://localhost:{port}/health", timeout=5)
return response.status_code == 200
except Exception:
return False
def _hotswap_lora_adapter(port: int, adapter_path: str) -> bool:
"""
Request vLLM to hot-swap to a new LoRA adapter.
Args:
port: vLLM server port
adapter_path: Path to the saved adapter directory
Returns:
True if successful, False otherwise
"""
try:
response = requests.post(
f"http://localhost:{port}/lora/load",
json={"adapter_path": adapter_path},
timeout=30,
)
if response.status_code == 200:
print(f" [LORA] Hot-swapped adapter: {adapter_path}")
return True
else:
print(f" [LORA] Hot-swap failed: {response.text}")
return False
except Exception as e:
print(f" [LORA] Hot-swap request failed: {e}")
return False
def train_lora(config: TrainingConfig):
"""
GRPO training with LoRA adapters.
This mode keeps the base model frozen and only trains LoRA adapter weights.
REQUIRES: External vLLM server running via vllm_api_server.py
Benefits:
- Much faster training (fewer parameters)
- Smaller checkpoint sizes (adapter only, not full model)
- Adapters can be hot-swapped in vLLM via /lora/load endpoint
"""
if not PEFT_AVAILABLE:
raise RuntimeError(
"PEFT library required for LoRA mode. Install with: pip install peft"
)
training_start_time = time.time()
# === Setup ===
use_wandb = setup_wandb(config)
print(f"\n{'='*60}")
print("LORA MODE (adapter-only training)")
print(f"{'='*60}")
print(f"Base model: {config.model_name}")
print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
print(f"Save path: {config.save_path}")
print(f"vLLM port: {config.vllm_port}")
print(f"{'='*60}\n")
# Check that external vLLM is running
print("[1/3] Checking external vLLM server...")
if not _check_vllm_health(config.vllm_port):
print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
print("\nLoRA mode requires an external vLLM server. Start it first:")
print(f" python example_trainer/vllm_api_server.py \\")
print(f" --model {config.model_name} \\")
print(f" --port {config.vllm_port} \\")
print(f" --gpu-memory-utilization 0.45")
raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
print(f"vLLM server healthy on port {config.vllm_port}")
# Load model with LoRA adapters
print("[2/3] Loading model with LoRA adapters...")
model, tokenizer = load_model_and_tokenizer(config)
# Only optimize LoRA parameters (base model is frozen)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=config.lr)
print(f"[3/3] Starting training for {config.training_steps} steps")
print("-" * 60)
os.makedirs(config.save_path, exist_ok=True)
register_trainer(config)
# NOTE: No vLLM launch here - using external vLLM server
# === Benchmark tracking ===
benchmark_stats = {
"step_times": [],
"sync_times": [], # For LoRA mode, this is adapter save + hot-swap time
"data_fetch_times": [],
"gpu_memories": [],
}
# === Training Loop ===
batches = []
for step in range(config.training_steps):
print(f"\nStep {step+1}/{config.training_steps}")
# Track data fetch time
data_fetch_start = time.time()
if len(batches) == 0:
batches = get_data(config.batch_size, config.seq_len)
token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
data_fetch_time = time.time() - data_fetch_start
benchmark_stats["data_fetch_times"].append(data_fetch_time)
# Track step time
step_start = time.time()
# Run training step
metrics = run_training_step(
model, optimizer,
token_batches, label_batches, advantage_batches, temperature_batches,
config
)
step_time = time.time() - step_start
benchmark_stats["step_times"].append(step_time)
# Track GPU memory
if torch.cuda.is_available():
gpu_mem_gb = torch.cuda.memory_allocated() / 1e9
gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9
benchmark_stats["gpu_memories"].append(gpu_mem_gb)
else:
gpu_mem_gb = 0
gpu_mem_reserved_gb = 0
# Track sync time (adapter save + hot-swap)
sync_time = 0
should_sync = (step + 1) % config.vllm_restart_interval == 0
if should_sync:
sync_start = time.time()
adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)
# Try to hot-swap the adapter in vLLM (non-blocking, best effort)
_hotswap_lora_adapter(config.vllm_port, adapter_path)
sync_time = time.time() - sync_start
benchmark_stats["sync_times"].append(sync_time)
# Add timing metrics
metrics["step_time"] = step_time
metrics["sync_time"] = sync_time
metrics["data_fetch_time"] = data_fetch_time
metrics["gpu_memory_gb"] = gpu_mem_gb
metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb
# Log metrics
log_metrics(metrics, step + 1, use_wandb, {
"train/learning_rate": optimizer.param_groups[0]["lr"],
"lora/trainable_params": sum(p.numel() for p in trainable_params),
})
# === Cleanup ===
# NOTE: No vLLM termination - external server keeps running
# Save final adapter (track this sync time too)
final_sync_start = time.time()
final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True)
# Hot-swap to final adapter
_hotswap_lora_adapter(config.vllm_port, final_adapter_path)
final_sync_time = time.time() - final_sync_start
benchmark_stats["sync_times"].append(final_sync_time)
finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats)
# Also save tokenizer for convenience
tokenizer_path = os.path.join(config.save_path, "tokenizer")
tokenizer.save_pretrained(tokenizer_path)
print(f"Tokenizer saved to {tokenizer_path}")
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments for the GRPO trainer."""
parser = argparse.ArgumentParser(
description="GRPO Trainer with optional shared-weight vLLM integration",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# --- Core training arguments ---
parser.add_argument(
"--model-name",
type=str,
required=True,
help="HuggingFace model identifier (e.g., 'Qwen/Qwen2.5-1.5B-Instruct')",
)
parser.add_argument(
"--lr",
type=float,
default=1e-5,
help="Learning rate for the optimizer",
)
parser.add_argument(
"--training-steps",
type=int,
default=10,
help="Number of training steps to run",
)
parser.add_argument(
"--batch-size",
type=int,
default=2,
help="Batch size for training",
)
parser.add_argument(
"--seq-len",
type=int,
default=2048,
help="Maximum sequence length",
)
parser.add_argument(
"--gradient-accumulation-steps",
type=int,
default=32,
help="Number of gradient accumulation steps",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to train on (cuda/cpu)",
)
parser.add_argument(
"--save-path",
type=str,
default="trained_model_checkpoints",
help="Directory to save model checkpoints",
)
# --- vLLM arguments ---
parser.add_argument(
"--vllm-restart-interval",
type=int,
default=3,
help="Restart vLLM every N training steps (legacy mode only)",
)
parser.add_argument(
"--vllm-port",
type=int,
default=9001,
help="Port for the vLLM server",
)
parser.add_argument(
"--vllm-gpu-memory-utilization",
type=float,
default=0.45,
help="GPU memory utilization for vLLM server (0.0-1.0)",
)
# --- Wandb arguments ---
parser.add_argument(
"--use-wandb",
action="store_true",
help="Enable Weights & Biases logging",
)
parser.add_argument(
"--wandb-project",
type=str,
default=None,
help="Wandb project name",
)
parser.add_argument(
"--wandb-group",
type=str,
default=None,
help="Wandb group name",
)
# --- Pipeline / weight bridge arguments ---
parser.add_argument(
"--weight-bridge-mode",
type=str,
choices=["shared_vllm", "lora_only", "none"],
default="none",
help=(
"Weight sync mode: "
"'shared_vllm' = attach to vLLM shared memory, "
"'lora_only' = train LoRA adapters only, "
"'none' = legacy restart-based sync"
),
)
parser.add_argument(
"--trainer-rank",
type=int,
default=0,
help="Rank of this trainer in the distributed group",
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Total processes in the distributed group",
)
parser.add_argument(
"--init-method",
type=str,
default="env://",
help="PyTorch distributed init method (e.g., 'env://', 'tcp://host:port')",
)
parser.add_argument(
"--num-inference-nodes",
type=int,
default=0,
help="Number of inference nodes to coordinate with (0 = single-node local)",
)
# --- LoRA arguments ---
parser.add_argument(
"--lora-r",
type=int,
default=16,
help="LoRA rank (dimension of low-rank matrices)",
)
parser.add_argument(
"--lora-alpha",
type=int,
default=32,
help="LoRA alpha (scaling factor, typically 2x rank)",
)
parser.add_argument(
"--lora-dropout",
type=float,
default=0.05,
help="Dropout probability for LoRA layers",
)
parser.add_argument(
"--lora-target-modules",
type=str,
nargs="+",
default=None,
help="Module names to apply LoRA to (default: q_proj v_proj)",
)
# --- Shared memory arguments ---
parser.add_argument(
"--use-shared-memory",
action="store_true",
help=(
"Enable NCCL shared memory weight updates (shared_vllm mode only). "
"Weights are broadcast to vLLM's daemon via NCCL. "
"vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
),
)
return parser.parse_args()
def config_from_args(args: argparse.Namespace) -> TrainingConfig:
"""Build a TrainingConfig from parsed CLI arguments."""
return TrainingConfig(
model_name=args.model_name,
lr=args.lr,
training_steps=args.training_steps,
batch_size=args.batch_size,
seq_len=args.seq_len,
gradient_accumulation_steps=args.gradient_accumulation_steps,
device=args.device,
save_path=args.save_path,
vllm_restart_interval=args.vllm_restart_interval,
vllm_port=args.vllm_port,
vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
use_wandb=args.use_wandb,
wandb_project=args.wandb_project,
wandb_group=args.wandb_group,
weight_bridge_mode=args.weight_bridge_mode,
trainer_rank=args.trainer_rank,
world_size=args.world_size,
init_method=args.init_method,
num_inference_nodes=args.num_inference_nodes,
lora_r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
lora_target_modules=args.lora_target_modules,
use_shared_memory=getattr(args, 'use_shared_memory', False),
)
# Example usage (optional, can be run from another script)
if __name__ == "__main__":
args = parse_args()
training_config = config_from_args(args)
print(f"Weight bridge mode: {training_config.weight_bridge_mode}")
if training_config.weight_bridge_mode == "shared_vllm":
# Shared vLLM mode: attach to vLLM's weights, update in-place
train_shared_vllm(training_config)
elif training_config.weight_bridge_mode == "lora_only":
# LoRA mode: freeze base model, train adapters only
train_lora(training_config)
else:
# Legacy mode: periodic checkpoint saves + vLLM restarts
train(training_config)