mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
1611 lines
54 KiB
Python
1611 lines
54 KiB
Python
import argparse
|
|
import atexit
|
|
import json
|
|
import math
|
|
import os
|
|
import random
|
|
import shutil
|
|
import string
|
|
import subprocess
|
|
import time
|
|
from typing import List, Literal, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import requests
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import wandb # Added for logging
|
|
from pydantic import BaseModel, Field
|
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
|
from torch.optim import AdamW
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
# Import weight bridge for shared vLLM mode
|
|
try:
|
|
from example_trainer.vllm_weight_bridge import (
|
|
BridgeConfig,
|
|
VLLMWeightBridge,
|
|
create_bridge_from_training_config,
|
|
)
|
|
BRIDGE_AVAILABLE = True
|
|
except ImportError:
|
|
BRIDGE_AVAILABLE = False
|
|
|
|
# Import PEFT for LoRA training
|
|
try:
|
|
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
|
|
PEFT_AVAILABLE = True
|
|
except ImportError:
|
|
PEFT_AVAILABLE = False
|
|
|
|
# Global variable to keep track of the vLLM process
vllm_process = None


def cleanup_vllm():
    """Stop the tracked vLLM subprocess, escalating to SIGKILL on timeout."""
    global vllm_process
    if not vllm_process:
        return
    print("\nTerminating vLLM process...")
    vllm_process.terminate()
    try:
        # Give the server a short window to shut down gracefully.
        vllm_process.wait(timeout=5)
        print("vLLM process terminated.")
    except subprocess.TimeoutExpired:
        print("vLLM process did not terminate gracefully, killing.")
        vllm_process.kill()
        vllm_process.wait()
        print("vLLM process killed.")
    vllm_process = None


# Register the cleanup function to be called on script exit
atexit.register(cleanup_vllm)
|
|
|
|
|
|
class TrainingConfig(BaseModel):
    """
    Training details, model, etc

    Pydantic model holding every knob for the example GRPO trainer:
    optimizer and batch-shape settings, vLLM server management, wandb
    logging, weight-bridge synchronization, and LoRA hyperparameters.
    """

    model_name: str = Field(..., description="Name of the base model to train")
    lr: float = Field(1e-5, description="Learning rate for the optimizer")
    training_steps: int = Field(
        10, description="Number of training steps"
    )  # Renamed from epochs
    batch_size: int = Field(
        2, description="Batch size for training (will be handled by get_data)"
    )
    seq_len: int = Field(2048, description="Sequence length for training")
    gradient_accumulation_steps: int = Field(
        32, description="Number of gradient accumulation steps"
    )
    # NOTE: the default is resolved once, when this class is defined,
    # not at each run — fine for a single-process trainer.
    device: str = Field(
        "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on"
    )
    save_path: str = Field(
        "trained_model_checkpoints", description="Base path to save model checkpoints"
    )
    vllm_restart_interval: int = Field(
        3, description="Restart vLLM every N training steps"
    )
    vllm_port: int = Field(9001, description="Port for the vLLM server")
    vllm_gpu_memory_utilization: float = Field(
        0.45, description="GPU memory utilization for vLLM server (0.0-1.0)"
    )

    # Wandb configuration
    use_wandb: bool = Field(
        False, description="Whether to use Weights & Biases for logging"
    )
    wandb_project: Optional[str] = Field(None, description="Wandb project name")
    wandb_group: Optional[str] = Field(None, description="Wandb group name")

    # Pipeline / weight bridge configuration
    weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field(
        "none",
        description=(
            "How to synchronize weights with inference server. "
            "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
            "'lora_only': keep base model frozen, train/swap LoRA adapters. "
            "'none': legacy mode, restart vLLM with new checkpoint files."
        ),
    )
    trainer_rank: int = Field(
        0,
        description="Rank of this trainer in the distributed group (for shared_vllm mode)",
    )
    world_size: int = Field(
        1,
        description="Total processes in the distributed group (for shared_vllm mode)",
    )
    init_method: str = Field(
        "env://",
        description=(
            "PyTorch distributed init method URL. "
            "Use 'env://' to read MASTER_ADDR/MASTER_PORT from environment, "
            "or 'tcp://host:port' for explicit rendezvous."
        ),
    )
    num_inference_nodes: int = Field(
        0,
        description=(
            "Number of inference nodes (vLLM servers) to coordinate with. "
            "0 means single-node local mode."
        ),
    )

    # LoRA configuration (for lora_only mode)
    lora_r: int = Field(16, description="LoRA rank (dimension of low-rank matrices)")
    lora_alpha: int = Field(32, description="LoRA alpha (scaling factor)")
    lora_dropout: float = Field(0.05, description="Dropout probability for LoRA layers")
    lora_target_modules: Optional[List[str]] = Field(
        None,
        description=(
            "List of module names to apply LoRA to. "
            "If None, defaults to ['q_proj', 'v_proj'] for most models."
        ),
    )

    # Shared memory mode (for shared_vllm mode - NCCL weight broadcast)
    use_shared_memory: bool = Field(
        False,
        description=(
            "Enable shared memory weight updates via NCCL. "
            "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1. "
            "Weight updates are broadcast to vLLM's daemon process."
        ),
    )
|
|
|
|
|
|
def check_atropos_api(timeout: float = 30.0) -> bool:
    """
    Check if the Atropos API server is reachable.

    Polls http://localhost:8000/info once per second until it answers
    with HTTP 200 or the deadline passes.

    Args:
        timeout: Maximum time (seconds) to wait for the server

    Returns:
        True if server is reachable, False if the deadline expired
    """
    # Fix: the original re-imported the stdlib time module as `_time`
    # inside the function even though `time` is imported at module level.
    start = time.time()
    while time.time() - start < timeout:
        try:
            response = requests.get("http://localhost:8000/info", timeout=2)
            if response.status_code == 200:
                print("[Trainer] ✓ Atropos API server is reachable")
                return True
        except requests.exceptions.ConnectionError:
            # Server not up yet - keep polling silently.
            pass
        except Exception as e:
            print(f"[Trainer] Waiting for Atropos API... ({e})")
        time.sleep(1)

    print("[Trainer] ⚠ Warning: Atropos API server not reachable")
    return False
|
|
|
|
|
|
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
def register_trainer(config: TrainingConfig):
    """
    Register the trainer with the Atropos API.

    Verifies registration succeeded before returning; any HTTP failure or
    missing uuid raises so tenacity can retry.
    """
    # wandb fields are required strings - use empty string if None
    payload = {
        "wandb_group": config.wandb_group or "",
        "wandb_project": config.wandb_project or "",
        "batch_size": config.batch_size * config.gradient_accumulation_steps,
        "max_token_len": config.seq_len,
        "starting_step": 0,
        "checkpoint_dir": config.save_path,
        "save_checkpoint_interval": config.training_steps,
        "num_steps": config.training_steps,
    }
    response = requests.post(
        "http://localhost:8000/register",
        json=payload,
        timeout=10,
    )

    # Check for HTTP errors
    response.raise_for_status()

    # Verify we got a valid response with UUID
    body = response.json()
    if "uuid" not in body:
        raise RuntimeError(f"Registration failed: {body}")

    print(f"[Trainer] ✓ Registered with Atropos API (uuid: {body['uuid']})")
|
|
|
|
|
|
@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30))
def get_batch():
    """Fetch one batch payload from the Atropos API, retrying on failure."""
    payload = requests.get("http://localhost:8000/batch", timeout=10).json()

    # The API answers with a status field when the trainer is not registered.
    if payload.get("status") == "error":
        raise RuntimeError(f"Atropos API error: {payload.get('message', 'Unknown error')}")

    return payload
|
|
|
|
|
|
def pad_data_to_good_offset(data, batch_size: int):
    """
    Convert one Atropos batch payload into padded training tensors.

    For each group in ``data["batch"]``: z-normalizes the group's scores
    (when the group has more than one sample), zeroes any advantages
    flagged via per-sample overrides, pads every token sequence so the
    causal (shifted-by-one) length is a multiple of 64, and finally
    slices the flat sample lists into micro-batches of ``batch_size``.

    Args:
        data: Decoded /batch payload; each item is assumed to carry
            "tokens", "masks", "scores" and optional "overrides" /
            "generation_params" / "group_overrides" keys — TODO confirm
            against the Atropos API schema.
        batch_size: Number of samples per returned micro-batch.

    Returns:
        Tuple of (token_batches, label_batches, advantage_batches,
        temperature_batches), each a list of tensors. Advantages come
        out shaped [B, 1] and temperatures [B, 1, 1] for broadcasting.

    Note:
        The integer division in the final loop silently drops any
        trailing samples that do not fill a whole micro-batch —
        presumably acceptable for this example trainer.
    """
    # Longest raw sequence across every group in the payload.
    max_token_len = max(
        [max([len(x) for x in item["tokens"]]) for item in data["batch"]]
    )
    # usually 64 is a good choice to ensure nonweird scaling behavior on GPUS
    # so we pad to the nearest multiple of 64
    good_multiple = 64
    if (max_token_len - 1) % (good_multiple) != 0:
        max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple
        token_setup_len = (
            max_token_len + 1
        )  # add 1 so we can make it causal at the proper length
    else:
        token_setup_len = max_token_len
        max_token_len = (
            max_token_len - 1
        )  # since it's causal we need to remove the last bit...
    # pad all tokens to max_token_len and add to lists
    input_ids = list()
    labels = list()
    advantages = list()
    lengths = list()  # per-sample rounded lengths; collected but not returned
    temperatures = list()
    for item in data["batch"]:
        scores = item["scores"]
        scores = np.array(scores)
        # check if we have more than 1 score...
        if len(scores) > 1:
            # GRPO-style group normalization: zero mean, (clamped) unit std.
            scores = scores - scores.mean()
            scores = scores / max(scores.std(), 1e-8)
        item["scores"] = scores
        if item["overrides"] is not None:
            # Per-sample kill switch: force the advantage to zero.
            for i in range(len(item["overrides"])):
                if item["overrides"][i].get("set_advantage_to_zero", False):
                    item["scores"][i] = 0
        for i in range(len(item["tokens"])):
            lengths.append(
                math.ceil((len(item["tokens"][i]) - 1) / (good_multiple))
                * good_multiple
            )
            # Labels: the mask sequence padded with -100 (the ignore_index
            # used later by cross_entropy) out to the full setup length.
            label_item = np.concatenate(
                [
                    np.array(item["masks"][i]),
                    np.full(
                        max(0, token_setup_len - len(item["tokens"][i])),
                        -100,
                        dtype=np.int32,
                    ),
                ]
            )
            # Tokens: zero-padded in place to the full setup length.
            item["tokens"][i] = np.concatenate(
                [
                    np.array(item["tokens"][i]),
                    np.zeros(
                        max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32
                    ),
                ]
            )
            # Causal shift: inputs drop the last token, labels drop the first.
            input_ids.append(item["tokens"][i][:-1])
            labels.append(label_item[1:])
            advantages.append(item["scores"][i])
            # per-sample override -> group generation_params -> group_overrides - > 1.0
            # need to update docs since this lets you set the temperature for each sample from the override
            t = 1.0
            if (
                item.get("overrides")
                and i < len(item["overrides"])
                and isinstance(item["overrides"][i], dict)
                and ("temperature" in item["overrides"][i])
            ):
                t = float(item["overrides"][i]["temperature"])
            elif item.get("generation_params") and (
                "temperature" in item["generation_params"]
            ):
                t = float(item["generation_params"]["temperature"])
            elif item.get("group_overrides") and (
                "temperature" in item["group_overrides"]
            ):
                t = float(item["group_overrides"]["temperature"])
            temperatures.append(t)
    # combine all lists into tensors
    token_batches = []
    label_batches = []
    advantage_batches = []
    temperature_batches = []
    # NOTE: integer division drops any incomplete trailing micro-batch.
    for i in range(len(input_ids) // batch_size):
        token_batches.append(
            torch.tensor(
                np.stack(input_ids[i * batch_size : (i + 1) * batch_size], axis=0)
            )
        )
        label_batches.append(
            torch.tensor(
                np.stack(labels[i * batch_size : (i + 1) * batch_size], axis=0)
            )
        )
        advantage_batches.append(
            torch.tensor(
                np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0)
            ).view(-1, 1)
        )
        # Temperatures: one per sample, shaped for broadcasting to [B, 1, 1]
        temperature_batches.append(
            torch.tensor(
                np.array(
                    temperatures[i * batch_size : (i + 1) * batch_size],
                    dtype=np.float32,
                )
            ).view(-1, 1, 1)
        )

    return token_batches, label_batches, advantage_batches, temperature_batches
|
|
|
|
|
|
def get_data(
    batch_size: int, seq_len: int
) -> List[
    Tuple[
        List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
    ]
]:
    """
    Poll the Atropos API until data arrives, then drain everything queued.

    Every non-empty payload is dumped to temp.json (debugging aid) and
    converted into padded tensor micro-batches. Once the API reports no
    more data, whatever was collected is returned; while nothing has
    arrived yet, the loop sleeps and retries.
    """
    collected = []
    while True:
        payload = get_batch()
        if payload["batch"] is not None:
            # Keep a copy of the raw payload on disk for inspection.
            with open("temp.json", "w", encoding="utf-8") as f:
                json.dump(payload, f)
            # Inference may run ahead of training, so keep draining
            # until the server has nothing buffered.
            collected.append(pad_data_to_good_offset(payload, batch_size))
        elif collected:
            return collected
        else:
            # No data yet - wait for the environments to produce some.
            time.sleep(1)
|
|
|
|
|
|
# =============================================================================
|
|
# Common Training Helpers (shared across all modes)
|
|
# =============================================================================
|
|
|
|
|
|
def setup_wandb(config: TrainingConfig) -> bool:
    """
    Initialize Weights & Biases logging if enabled.

    Args:
        config: Training configuration

    Returns:
        True if wandb is active, False otherwise
    """
    if not config.use_wandb:
        return False

    if not config.wandb_project:
        print("Warning: wandb_project not set, disabling wandb.")
        return False

    # Generate random group name if not provided
    if not config.wandb_group:
        alphabet = string.ascii_letters + string.digits
        config.wandb_group = "".join(random.choices(alphabet, k=8))

    try:
        wandb.init(
            project=config.wandb_project,
            group=config.wandb_group,
            config=config.dict(),
        )
        print(
            f"Wandb logging enabled. Run: {wandb.run.name} "
            f"(Project: {config.wandb_project})"
        )
        return True
    except Exception as e:
        # Any wandb failure degrades gracefully to local-only logging.
        print(f"Error initializing wandb: {e}. Disabling wandb.")
        return False
|
|
|
|
|
|
def load_model_and_tokenizer(
    config: TrainingConfig,
    bridge: Optional["VLLMWeightBridge"] = None,
) -> Tuple[torch.nn.Module, "AutoTokenizer"]:
    """
    Load or attach to model based on weight_bridge_mode.

    Args:
        config: Training configuration
        bridge: Optional weight bridge for shared_vllm mode

    Returns:
        Tuple of (model, tokenizer)
    """
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    shared_mode = config.weight_bridge_mode == "shared_vllm" and bridge is not None

    if config.weight_bridge_mode == "lora_only":
        model = _load_model_with_lora(config)
    else:
        if shared_mode:
            # Shared vLLM mode: weights are broadcast via NCCL after loading.
            print("[Setup] Loading model for shared vLLM mode...")
            if config.use_shared_memory:
                print("[Setup] NCCL shared memory mode - updates broadcast to vLLM daemon")
            else:
                print("[Setup] HTTP notification mode - vLLM notified of updates")
        else:
            print("[Setup] Loading model for legacy mode...")
        model = AutoModelForCausalLM.from_pretrained(
            config.model_name, torch_dtype=torch.bfloat16
        )
        model.to(config.device)

    # Gradient checkpointing trades compute for memory.
    if config.weight_bridge_mode == "lora_only":
        # PEFT models need input grads enabled on the base model and
        # use_reentrant=False so gradients flow through frozen layers.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
    else:
        # Standard gradient checkpointing
        model.gradient_checkpointing_enable()

    model.train()

    return model, tokenizer
|
|
|
|
|
|
def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module:
    """
    Load base model and wrap with LoRA adapters.

    Args:
        config: Training configuration with LoRA settings

    Returns:
        PEFT model with LoRA adapters applied

    Raises:
        RuntimeError: If the peft package could not be imported.
    """
    if not PEFT_AVAILABLE:
        raise RuntimeError(
            "PEFT library not available. Install with: pip install peft"
        )

    print("[Setup] Loading base model for LoRA mode...")
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_name, torch_dtype=torch.bfloat16
    )
    base_model.to(config.device)

    # Determine target modules; fall back to the attention projections
    # most transformer architectures share.
    target_modules = config.lora_target_modules
    if target_modules is None:
        target_modules = ["q_proj", "v_proj"]

    print(f"Applying LoRA: r={config.lora_r}, alpha={config.lora_alpha}")
    print(f"Target modules: {target_modules}")

    peft_model = get_peft_model(
        base_model,
        LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=target_modules,
            bias="none",
        ),
    )
    peft_model.print_trainable_parameters()

    return peft_model
|
|
|
|
|
|
def save_lora_checkpoint(
    model: torch.nn.Module,
    save_path: str,
    step: int,
    is_final: bool = False,
) -> str:
    """
    Save LoRA adapter checkpoint (adapter weights only - far smaller
    than a full model checkpoint).

    Args:
        model: PEFT model with LoRA adapters
        save_path: Base directory for checkpoints
        step: Current training step
        is_final: Whether this is the final checkpoint

    Returns:
        Path where adapter was saved
    """
    subdir = "final_adapter" if is_final else f"adapter_step_{step}"
    adapter_path = os.path.join(save_path, subdir)

    print(f" Saving LoRA adapter to {adapter_path}...")

    # Start from a clean directory so stale adapter files never linger.
    if os.path.exists(adapter_path):
        shutil.rmtree(adapter_path)
    os.makedirs(adapter_path, exist_ok=True)

    model.save_pretrained(adapter_path)

    print(" Adapter saved.")
    return adapter_path
|
|
|
|
|
|
def compute_grpo_loss(
|
|
model: torch.nn.Module,
|
|
tokens: torch.Tensor,
|
|
labels: torch.Tensor,
|
|
advantages: torch.Tensor,
|
|
temperatures: torch.Tensor,
|
|
gradient_accumulation_steps: int,
|
|
) -> Tuple[torch.Tensor, dict]:
|
|
"""
|
|
Compute GRPO loss for a single micro-batch.
|
|
|
|
Args:
|
|
model: The model to compute loss for
|
|
tokens: Input token IDs [batch, seq_len]
|
|
labels: Target labels [batch, seq_len]
|
|
advantages: Advantage values [batch, 1]
|
|
temperatures: Temperature values [batch, 1, 1]
|
|
gradient_accumulation_steps: Number of accumulation steps
|
|
|
|
Returns:
|
|
Tuple of (loss tensor, metrics dict)
|
|
"""
|
|
# Forward pass
|
|
outputs = model(tokens)
|
|
logits = outputs.logits
|
|
|
|
# Temperature scaling
|
|
t = temperatures.to(logits.device, logits.dtype)
|
|
t = torch.where(t <= 0, torch.ones_like(t), t)
|
|
logits = logits / t
|
|
|
|
# Log probabilities per token
|
|
logp_per_token = -F.cross_entropy(
|
|
logits.view(-1, logits.size(-1)),
|
|
labels.view(-1),
|
|
reduction="none",
|
|
ignore_index=-100,
|
|
).view(labels.shape)
|
|
|
|
# Masking based on labels != -100
|
|
mask = (labels != -100).float()
|
|
|
|
# Compute metrics (no grad needed)
|
|
with torch.no_grad():
|
|
pos = (advantages > 0).float()
|
|
neg = (advantages <= 0).float()
|
|
mask_float = mask.to(logp_per_token.dtype)
|
|
mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8)
|
|
avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum
|
|
pos_logp = (logp_per_token * pos).mean().item()
|
|
neg_logp = (logp_per_token * neg).mean().item()
|
|
|
|
# GRPO loss
|
|
grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
|
|
grpo_loss = (
|
|
((-grpo_loss_term * mask).sum(-1) / mask.sum(-1))
|
|
* advantages.to(logp_per_token.device)
|
|
).mean() / gradient_accumulation_steps
|
|
|
|
metrics = {
|
|
"pos_logp": pos_logp,
|
|
"neg_logp": neg_logp,
|
|
"avg_logp": avg_logp,
|
|
"pos_count": pos.sum().item(),
|
|
"neg_count": neg.sum().item(),
|
|
}
|
|
|
|
return grpo_loss, metrics
|
|
|
|
|
|
def run_training_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    token_batches: List[torch.Tensor],
    label_batches: List[torch.Tensor],
    advantage_batches: List[torch.Tensor],
    temperature_batches: List[torch.Tensor],
    config: TrainingConfig,
) -> dict:
    """
    Run a single training step (forward, backward, optimizer step).

    Args:
        model: The model to train
        optimizer: The optimizer
        token_batches: List of token tensors
        label_batches: List of label tensors
        advantage_batches: List of advantage tensors
        temperature_batches: List of temperature tensors
        config: Training configuration

    Returns:
        Dict of training metrics for this step
    """
    loss_total = 0.0
    pos_logp_total = 0.0
    neg_logp_total = 0.0
    pos_count = 0.0
    neg_count = 0.0

    # Accumulate gradients over micro-batches; compute_grpo_loss already
    # divides by gradient_accumulation_steps.
    micro_batches = zip(
        token_batches, label_batches, advantage_batches, temperature_batches
    )
    for tok, lab, adv, temp in micro_batches:
        tok = tok.to(config.device)
        lab = lab.to(config.device)
        adv = adv.to(config.device)
        # temperatures are moved to the logits device inside the loss fn

        loss, micro_metrics = compute_grpo_loss(
            model, tok, lab, adv, temp,
            config.gradient_accumulation_steps
        )
        loss.backward()

        loss_total += loss.item()
        pos_logp_total += micro_metrics["pos_logp"]
        neg_logp_total += micro_metrics["neg_logp"]
        pos_count += micro_metrics["pos_count"]
        neg_count += micro_metrics["neg_count"]

    # Gradient clipping, then one optimizer update for the whole step.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()

    # Normalize the accumulated log-prob sums by sample counts.
    if pos_count > 0:
        pos_logp_total /= pos_count
    if neg_count > 0:
        neg_logp_total /= neg_count

    return {
        "loss": loss_total,
        "grad_norm": grad_norm.item(),
        "pos_logp": pos_logp_total,
        "neg_logp": neg_logp_total,
    }
|
|
|
|
|
|
def save_checkpoint(
    model: torch.nn.Module,
    tokenizer: "AutoTokenizer",
    save_path: str,
    step: int,
    is_final: bool = False,
) -> str:
    """
    Save a full model + tokenizer checkpoint.

    Args:
        model: Model to save
        tokenizer: Tokenizer to save
        save_path: Base directory for checkpoints
        step: Current training step
        is_final: Whether this is the final checkpoint

    Returns:
        Path where checkpoint was saved
    """
    subdir = "final_model" if is_final else f"step_{step}"
    checkpoint_path = os.path.join(save_path, subdir)

    print(f" Saving checkpoint to {checkpoint_path}...")

    # Replace any previous checkpoint at this path wholesale.
    if os.path.exists(checkpoint_path):
        shutil.rmtree(checkpoint_path)
    os.makedirs(checkpoint_path, exist_ok=True)

    model.save_pretrained(checkpoint_path)
    tokenizer.save_pretrained(checkpoint_path)

    print(" Checkpoint saved.")
    return checkpoint_path
|
|
|
|
|
|
def log_metrics(
    metrics: dict,
    step: int,
    use_wandb: bool,
    extra_metrics: Optional[dict] = None,
) -> None:
    """
    Log training metrics to console and optionally wandb.

    Args:
        metrics: Dict of metrics from training step
        step: Current step number
        use_wandb: Whether to log to wandb
        extra_metrics: Optional additional metrics to log
    """
    # Assemble the optional timing/memory suffix for the console line.
    suffix = []
    if "step_time" in metrics:
        suffix.append(f", Step time: {metrics['step_time']:.2f}s")
    if "sync_time" in metrics and metrics["sync_time"] > 0:
        suffix.append(f", Sync time: {metrics['sync_time']:.2f}s")
    if "data_fetch_time" in metrics:
        suffix.append(f", Data fetch: {metrics['data_fetch_time']:.2f}s")
    if "gpu_memory_gb" in metrics:
        suffix.append(f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB")
    timing_str = "".join(suffix)

    print(f" Loss: {metrics['loss']:.4f}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}")

    if not use_wandb:
        return

    log_dict = {
        "train/loss": metrics["loss"],
        "train/grad_norm": metrics["grad_norm"],
        "train/pos_logp": metrics["pos_logp"],
        "train/neg_logp": metrics["neg_logp"],
    }
    # Forward whichever timing/memory metrics were collected this step.
    optional_keys = (
        "step_time",
        "sync_time",
        "data_fetch_time",
        "gpu_memory_gb",
        "gpu_memory_reserved_gb",
    )
    for key in optional_keys:
        if key in metrics:
            log_dict[f"train/{key}"] = metrics[key]
    if extra_metrics:
        log_dict.update(extra_metrics)
    wandb.log(log_dict, step=step)
|
|
|
|
|
|
def finalize_training(
    use_wandb: bool,
    training_start_time: Optional[float] = None,
    mode: str = "unknown",
    total_steps: int = 0,
    benchmark_stats: Optional[dict] = None,
) -> None:
    """Clean up after training and log benchmark summary.

    Args:
        use_wandb: Whether wandb is enabled
        training_start_time: Start time of training
        mode: Training mode name
        total_steps: Total steps completed
        benchmark_stats: Dict with lists of per-step metrics:
            - step_times: List of step durations
            - sync_times: List of sync durations
            - data_fetch_times: List of data fetch durations
            - gpu_memories: List of GPU memory readings (GB)
    """
    print("\nTraining finished.")

    stats = benchmark_stats if benchmark_stats is not None else {}

    if training_start_time is not None:
        total_time = time.time() - training_start_time
        peak_gpu_mem_gb = (
            torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0
        )

        def _mean(values):
            # Average of a possibly-empty list; 0 when empty.
            return sum(values) / len(values) if values else 0

        step_times = stats.get("step_times", [])
        sync_times = stats.get("sync_times", [])
        data_fetch_times = stats.get("data_fetch_times", [])
        gpu_memories = stats.get("gpu_memories", [])

        avg_step_time = _mean(step_times)
        total_step_time = sum(step_times)
        avg_sync_time = _mean(sync_times)
        total_sync_time = sum(sync_times)
        avg_data_fetch = _mean(data_fetch_times)
        total_data_fetch = sum(data_fetch_times)
        avg_gpu_mem = _mean(gpu_memories)

        print(f"\n{'='*70}")
        print(f"BENCHMARK SUMMARY ({mode})")
        print(f"{'='*70}")
        print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)")
        print(f" Total steps: {total_steps}")
        print(f" ")
        print(f" TIMING BREAKDOWN:")
        print(f" Avg step time: {avg_step_time:.2f}s")
        print(f" Total step time: {total_step_time:.2f}s")
        print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)")
        print(f" Total sync time: {total_sync_time:.2f}s")
        print(f" Avg data fetch time: {avg_data_fetch:.2f}s")
        print(f" Total data fetch time: {total_data_fetch:.2f}s")
        print(f" ")
        print(f" MEMORY:")
        print(f" Peak GPU memory: {peak_gpu_mem_gb:.2f} GB")
        print(f" Avg GPU memory: {avg_gpu_mem:.2f} GB")
        print(f"{'='*70}\n")

        if use_wandb:
            # Mirror the whole summary into wandb's run summary.
            summary = {
                "benchmark/total_time_seconds": total_time,
                "benchmark/total_time_minutes": total_time / 60,
                "benchmark/mode": mode,
                "benchmark/total_steps": total_steps,
                "benchmark/avg_step_time_seconds": avg_step_time,
                "benchmark/total_step_time_seconds": total_step_time,
                "benchmark/avg_sync_time_seconds": avg_sync_time,
                "benchmark/total_sync_time_seconds": total_sync_time,
                "benchmark/num_syncs": len(sync_times),
                "benchmark/avg_data_fetch_time_seconds": avg_data_fetch,
                "benchmark/total_data_fetch_time_seconds": total_data_fetch,
                "benchmark/peak_gpu_memory_gb": peak_gpu_mem_gb,
                "benchmark/avg_gpu_memory_gb": avg_gpu_mem,
            }
            for key, value in summary.items():
                wandb.summary[key] = value

    if use_wandb:
        wandb.finish()
|
|
|
|
|
|
def train(config: TrainingConfig):
    """
    Legacy GRPO training with periodic vLLM restarts.

    This mode saves checkpoints to disk and restarts vLLM to pick up new weights.
    Use weight_bridge_mode='shared_vllm' for in-place weight updates without restarts.

    Args:
        config: Full trainer configuration (model, optimizer, vLLM, wandb,
            benchmark settings).
    """
    global vllm_process
    training_start_time = time.time()

    # === Setup ===
    use_wandb = setup_wandb(config)
    model, tokenizer = load_model_and_tokenizer(config)
    optimizer = AdamW(model.parameters(), lr=config.lr)

    print(f"\n{'='*60}")
    print("LEGACY MODE (checkpoint + vLLM restart)")
    print(f"{'='*60}")
    print(f"Training for {config.training_steps} steps on {config.device}")
    print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
    print(f"Save path: {config.save_path}")
    print(f"{'='*60}\n")

    os.makedirs(config.save_path, exist_ok=True)
    register_trainer(config)

    # Launch initial vLLM server
    vllm_process = _launch_vllm_server(config, config.model_name)

    # === Benchmark tracking ===
    benchmark_stats = {
        "step_times": [],
        "sync_times": [],
        "data_fetch_times": [],
        "gpu_memories": [],
    }

    # === Training Loop ===
    batches = []
    for step in range(config.training_steps):
        print(f"\nStep {step+1}/{config.training_steps}")

        # Track data fetch time (get_data blocks until the API has data)
        data_fetch_start = time.time()
        if len(batches) == 0:
            batches = get_data(config.batch_size, config.seq_len)
        token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
        data_fetch_time = time.time() - data_fetch_start
        benchmark_stats["data_fetch_times"].append(data_fetch_time)

        # Terminate vLLM before training step (to free GPU memory)
        # Syncs on every Nth step and always on the final step.
        should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
        if should_sync:
            _terminate_vllm_process()

        # Track step time
        step_start = time.time()

        # Run training step using common helper
        metrics = run_training_step(
            model, optimizer,
            token_batches, label_batches, advantage_batches, temperature_batches,
            config
        )

        step_time = time.time() - step_start
        benchmark_stats["step_times"].append(step_time)

        # Track GPU memory
        if torch.cuda.is_available():
            gpu_mem_gb = torch.cuda.memory_allocated() / 1e9
            gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9
            benchmark_stats["gpu_memories"].append(gpu_mem_gb)
        else:
            gpu_mem_gb = 0
            gpu_mem_reserved_gb = 0

        # Track sync time: checkpoint to disk, then relaunch vLLM on the
        # freshly written checkpoint directory.
        sync_time = 0
        if should_sync:
            sync_start = time.time()
            checkpoint_path = save_checkpoint(model, tokenizer, config.save_path, step + 1)
            torch.cuda.empty_cache()
            vllm_process = _launch_vllm_server(config, checkpoint_path)
            sync_time = time.time() - sync_start
            benchmark_stats["sync_times"].append(sync_time)

        # Add timing metrics
        metrics["step_time"] = step_time
        metrics["sync_time"] = sync_time
        metrics["data_fetch_time"] = data_fetch_time
        metrics["gpu_memory_gb"] = gpu_mem_gb
        metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb

        # Log metrics
        log_metrics(metrics, step + 1, use_wandb, {
            "train/learning_rate": optimizer.param_groups[0]["lr"],
        })

        # Check for unexpected vLLM termination
        _check_vllm_process_health()

    # === Cleanup ===
    save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
    finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats)
|
|
|
|
|
|
# =============================================================================
|
|
# vLLM Process Management (Legacy Mode Only)
|
|
# =============================================================================
|
|
|
|
|
|
def _launch_vllm_server(config: TrainingConfig, model_path: str) -> Optional[subprocess.Popen]:
    """Launch a vLLM server process using our custom vllm_api_server.py.

    Uses the custom server instead of standard vLLM because:
    - Standard vLLM only has /v1/completions (OpenAI-compatible)
    - Our custom server has /generate endpoint needed by VLLMServer class
    - This allows proper tokens_and_logprobs_completion support

    Args:
        config: Training configuration (port, GPU memory utilization, model name).
        model_path: Model to serve — either a checkpoint directory or the base
            HF model id.

    Returns:
        The launched subprocess handle, or None if the launch failed.
    """
    global vllm_process

    # Use our custom vllm_api_server.py instead of standard vLLM
    # This provides the /generate endpoint that VLLMServer needs
    script_dir = os.path.dirname(os.path.abspath(__file__))
    custom_server_path = os.path.join(script_dir, "vllm_api_server.py")

    vllm_command = [
        "python", custom_server_path,
        "--model", model_path,
        "--port", str(config.vllm_port),
        "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
    ]
    # Add served-model-name if using checkpoint path, so API clients can keep
    # addressing the model by its original name.
    if model_path != config.model_name:
        vllm_command.extend(["--served-model-name", config.model_name])

    print(f" Launching vLLM: {' '.join(vllm_command)}")

    try:
        proc = subprocess.Popen(vllm_command)
        print(f" vLLM launched with PID: {proc.pid}")

        # Brief startup probe. Use sleep + poll() rather than
        # communicate(timeout=...): communicate() closes the child's stdin,
        # which can disturb a server that is still initializing.
        time.sleep(2)
        if proc.poll() is not None and proc.returncode != 0:
            print(" WARNING: vLLM failed to start")
            return None
        print(" vLLM process started (check logs for details)")

        # BUGFIX: record the process in the module-level global so the
        # atexit cleanup_vllm() handler can terminate it. Previously the
        # `global` declaration above was dead code (never assigned).
        vllm_process = proc
        return proc

    except FileNotFoundError:
        print(" ERROR: vLLM not found. Is it installed?")
        return None
    except Exception as e:
        print(f" ERROR launching vLLM: {e}")
        return None
|
|
|
|
|
|
def _terminate_vllm_process() -> None:
    """Terminate the running vLLM process if any.

    Sends SIGTERM first; if the process does not exit within 5 seconds,
    escalates to SIGKILL. Clears the module-level handle afterwards.
    """
    global vllm_process

    proc = vllm_process
    if proc is None:
        return

    print(" Terminating vLLM process...")
    proc.terminate()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        # Graceful shutdown window expired — force-kill and reap.
        print(" vLLM did not terminate gracefully, killing...")
        proc.kill()
        proc.wait()
    vllm_process = None
|
|
|
|
|
|
def _check_vllm_process_health() -> None:
    """Check if vLLM process terminated unexpectedly (legacy mode)."""
    global vllm_process

    proc = vllm_process
    if proc is None:
        return
    if proc.poll() is None:
        # Still running — nothing to do.
        return
    print(f" WARNING: vLLM terminated unexpectedly (code: {proc.returncode})")
    vllm_process = None
|
|
|
|
|
|
def train_shared_vllm(config: TrainingConfig) -> None:
    """
    GRPO training with shared vLLM weights.

    Instead of saving checkpoints and restarting vLLM, this mode:
    1. Joins the same distributed group as vLLM
    2. Attaches to vLLM's weight tensors directly
    3. optimizer.step() modifies vLLM's weights in-place
    4. vLLM immediately uses updated weights (no restart!)

    Args:
        config: Full training configuration (model, distributed group,
            shared-memory flag, step counts, save path).

    Raises:
        RuntimeError: If the weight-bridge module could not be imported, or
            if the Atropos API server is unreachable.
    """
    if not BRIDGE_AVAILABLE:
        raise RuntimeError(
            "vLLM weight bridge not available. "
            "Ensure vllm_weight_bridge.py is in the same directory."
        )

    training_start_time = time.time()

    # === Setup ===
    use_wandb = setup_wandb(config)

    # Banner: which sync transport is active (NCCL broadcast vs HTTP notify).
    print(f"\n{'='*60}")
    if config.use_shared_memory:
        print("SHARED VLLM MODE (NCCL BROADCAST)")
        print(">>> Weights broadcast to vLLM via NCCL!")
    else:
        print("SHARED VLLM MODE (HTTP notifications)")
    print(f"{'='*60}")
    print(f"Model: {config.model_name}")
    print(f"Shared Memory: {config.use_shared_memory}")
    print(f"Distributed: rank={config.trainer_rank}/{config.world_size}")
    print(f"Init method: {config.init_method}")
    print(f"Inference nodes: {config.num_inference_nodes}")
    print(f"Save path: {config.save_path}")
    print(f"{'='*60}\n")

    # Initialize weight bridge (joins the distributed group shared with vLLM).
    print("[1/3] Initializing weight bridge...")
    bridge = create_bridge_from_training_config(config)

    # Load model with bridge attachment
    print("[2/3] Loading model with shared weights...")
    model, tokenizer = load_model_and_tokenizer(config, bridge=bridge)
    optimizer = AdamW(model.parameters(), lr=config.lr)

    # For NCCL mode, build mapping between trainer's and vLLM's param names
    if config.use_shared_memory:
        bridge.build_param_mapping(model)

    print(f"[3/3] Starting training for {config.training_steps} steps")
    print("NOTE: vLLM sees weight updates immediately after each step!")
    print("-" * 60)

    os.makedirs(config.save_path, exist_ok=True)

    # Check Atropos API and register BEFORE training loop
    print("\n[Setup] Connecting to Atropos API...")
    if not check_atropos_api(timeout=30):
        raise RuntimeError(
            "Atropos API server not reachable. "
            "Please start it with: run-api"
        )
    register_trainer(config)

    # === Benchmark tracking ===
    benchmark_stats = {
        "step_times": [],
        "sync_times": [],  # For shared mode, this is the notify_update time
        "data_fetch_times": [],
        "gpu_memories": [],
    }

    # === Training Loop ===
    # `batches` acts as a local queue refilled from the Atropos API whenever
    # it runs dry; each entry unpacks into one training step's tensors.
    batches = []
    for step in range(config.training_steps):
        print(f"\nStep {step+1}/{config.training_steps}")

        # Track data fetch time
        data_fetch_start = time.time()
        if len(batches) == 0:
            # NOTE(review): assumes get_data always returns a non-empty list
            # (pop(0) below would raise IndexError otherwise) — confirm.
            batches = get_data(config.batch_size, config.seq_len)
        token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
        data_fetch_time = time.time() - data_fetch_start
        benchmark_stats["data_fetch_times"].append(data_fetch_time)

        # Track step time
        step_start = time.time()

        # Run training step using common helper
        metrics = run_training_step(
            model, optimizer,
            token_batches, label_batches, advantage_batches, temperature_batches,
            config
        )

        step_time = time.time() - step_start
        benchmark_stats["step_times"].append(step_time)

        # Track GPU memory (allocated vs reserved, in GB).
        if torch.cuda.is_available():
            gpu_mem_gb = torch.cuda.memory_allocated() / 1e9
            gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9
            benchmark_stats["gpu_memories"].append(gpu_mem_gb)
        else:
            gpu_mem_gb = 0
            gpu_mem_reserved_gb = 0

        # Sync weights with vLLM — every step (this is the point of this mode).
        sync_start = time.time()
        if config.use_shared_memory:
            # NCCL broadcast mode - weights sent directly to vLLM daemon
            bridge.broadcast_weights(model)
            print(f" [SHARED] Weights broadcast via NCCL - step {step+1} (sync: {(time.time()-sync_start)*1000:.1f}ms)")
        else:
            # HTTP notification mode - just notify
            bridge.notify_update()
            print(f" [SHARED] Update notification sent - step {step+1}")
        sync_time = time.time() - sync_start
        benchmark_stats["sync_times"].append(sync_time)

        # Add timing metrics
        metrics["step_time"] = step_time
        metrics["sync_time"] = sync_time
        metrics["data_fetch_time"] = data_fetch_time
        metrics["gpu_memory_gb"] = gpu_mem_gb
        metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb

        # Log metrics
        log_metrics(metrics, step + 1, use_wandb, {
            "train/learning_rate": optimizer.param_groups[0]["lr"],
            "bridge/update_count": step + 1,
        })

        # Periodic checkpoint save (for recovery, not for vLLM sync)
        if (step + 1) % config.vllm_restart_interval == 0:
            save_checkpoint(model, tokenizer, config.save_path, step + 1)

    # === Cleanup ===
    bridge.cleanup()
    save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True)
    finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats)
|
|
|
|
|
|
def _check_vllm_health(port: int) -> bool:
    """Check if external vLLM server is running and healthy.

    Returns True only when GET /health answers with HTTP 200; any
    connection error or non-200 status counts as unhealthy.
    """
    try:
        status = requests.get(f"http://localhost:{port}/health", timeout=5).status_code
    except Exception:
        return False
    return status == 200
|
|
|
|
|
|
def _hotswap_lora_adapter(port: int, adapter_path: str) -> bool:
    """
    Request vLLM to hot-swap to a new LoRA adapter.

    Args:
        port: vLLM server port
        adapter_path: Path to the saved adapter directory

    Returns:
        True if successful, False otherwise
    """
    url = f"http://localhost:{port}/lora/load"
    try:
        response = requests.post(
            url,
            json={"adapter_path": adapter_path},
            timeout=30,
        )
    except Exception as e:
        # Network / timeout errors are reported but never raised — hot-swap
        # is best-effort.
        print(f" [LORA] Hot-swap request failed: {e}")
        return False

    if response.status_code != 200:
        print(f" [LORA] Hot-swap failed: {response.text}")
        return False
    print(f" [LORA] Hot-swapped adapter: {adapter_path}")
    return True
|
|
|
|
|
|
def train_lora(config: TrainingConfig) -> None:
    """
    GRPO training with LoRA adapters.

    This mode keeps the base model frozen and only trains LoRA adapter weights.

    REQUIRES: External vLLM server running via vllm_api_server.py

    Benefits:
    - Much faster training (fewer parameters)
    - Smaller checkpoint sizes (adapter only, not full model)
    - Adapters can be hot-swapped in vLLM via /lora/load endpoint

    Args:
        config: Full training configuration (base model, LoRA hyperparameters,
            vLLM port, step counts, save path).

    Raises:
        RuntimeError: If PEFT is not installed, or no external vLLM server is
            reachable on the configured port.
    """
    if not PEFT_AVAILABLE:
        raise RuntimeError(
            "PEFT library required for LoRA mode. Install with: pip install peft"
        )

    training_start_time = time.time()

    # === Setup ===
    use_wandb = setup_wandb(config)

    print(f"\n{'='*60}")
    print("LORA MODE (adapter-only training)")
    print(f"{'='*60}")
    print(f"Base model: {config.model_name}")
    print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
    print(f"Save path: {config.save_path}")
    print(f"vLLM port: {config.vllm_port}")
    print(f"{'='*60}\n")

    # Check that external vLLM is running — this mode never launches its own.
    print("[1/3] Checking external vLLM server...")
    if not _check_vllm_health(config.vllm_port):
        print(f"\nERROR: vLLM server not running on port {config.vllm_port}")
        print("\nLoRA mode requires an external vLLM server. Start it first:")
        print(f" python example_trainer/vllm_api_server.py \\")
        print(f" --model {config.model_name} \\")
        print(f" --port {config.vllm_port} \\")
        print(f" --gpu-memory-utilization 0.45")
        raise RuntimeError(f"External vLLM server required on port {config.vllm_port}")
    print(f"vLLM server healthy on port {config.vllm_port}")

    # Load model with LoRA adapters
    print("[2/3] Loading model with LoRA adapters...")
    model, tokenizer = load_model_and_tokenizer(config)

    # Only optimize LoRA parameters (base model is frozen)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=config.lr)

    print(f"[3/3] Starting training for {config.training_steps} steps")
    print("-" * 60)

    os.makedirs(config.save_path, exist_ok=True)
    register_trainer(config)

    # NOTE: No vLLM launch here - using external vLLM server

    # === Benchmark tracking ===
    benchmark_stats = {
        "step_times": [],
        "sync_times": [],  # For LoRA mode, this is adapter save + hot-swap time
        "data_fetch_times": [],
        "gpu_memories": [],
    }

    # === Training Loop ===
    # `batches` acts as a local queue refilled from the Atropos API whenever
    # it runs dry; each entry unpacks into one training step's tensors.
    batches = []
    for step in range(config.training_steps):
        print(f"\nStep {step+1}/{config.training_steps}")

        # Track data fetch time
        data_fetch_start = time.time()
        if len(batches) == 0:
            # NOTE(review): assumes get_data always returns a non-empty list
            # (pop(0) below would raise IndexError otherwise) — confirm.
            batches = get_data(config.batch_size, config.seq_len)
        token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
        data_fetch_time = time.time() - data_fetch_start
        benchmark_stats["data_fetch_times"].append(data_fetch_time)

        # Track step time
        step_start = time.time()

        # Run training step
        metrics = run_training_step(
            model, optimizer,
            token_batches, label_batches, advantage_batches, temperature_batches,
            config
        )

        step_time = time.time() - step_start
        benchmark_stats["step_times"].append(step_time)

        # Track GPU memory (allocated vs reserved, in GB).
        if torch.cuda.is_available():
            gpu_mem_gb = torch.cuda.memory_allocated() / 1e9
            gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9
            benchmark_stats["gpu_memories"].append(gpu_mem_gb)
        else:
            gpu_mem_gb = 0
            gpu_mem_reserved_gb = 0

        # Track sync time (adapter save + hot-swap); only every
        # vllm_restart_interval steps, unlike shared mode's per-step sync.
        sync_time = 0
        should_sync = (step + 1) % config.vllm_restart_interval == 0
        if should_sync:
            sync_start = time.time()
            adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)
            # Try to hot-swap the adapter in vLLM (non-blocking, best effort)
            _hotswap_lora_adapter(config.vllm_port, adapter_path)
            sync_time = time.time() - sync_start
            benchmark_stats["sync_times"].append(sync_time)

        # Add timing metrics
        metrics["step_time"] = step_time
        metrics["sync_time"] = sync_time
        metrics["data_fetch_time"] = data_fetch_time
        metrics["gpu_memory_gb"] = gpu_mem_gb
        metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb

        # Log metrics
        log_metrics(metrics, step + 1, use_wandb, {
            "train/learning_rate": optimizer.param_groups[0]["lr"],
            "lora/trainable_params": sum(p.numel() for p in trainable_params),
        })

    # === Cleanup ===
    # NOTE: No vLLM termination - external server keeps running

    # Save final adapter (track this sync time too)
    final_sync_start = time.time()
    final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True)

    # Hot-swap to final adapter
    _hotswap_lora_adapter(config.vllm_port, final_adapter_path)
    final_sync_time = time.time() - final_sync_start
    benchmark_stats["sync_times"].append(final_sync_time)

    finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats)

    # Also save tokenizer for convenience
    tokenizer_path = os.path.join(config.save_path, "tokenizer")
    tokenizer.save_pretrained(tokenizer_path)
    print(f"Tokenizer saved to {tokenizer_path}")
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the GRPO trainer.

    Returns:
        The parsed namespace; feed it to config_from_args() to build a
        TrainingConfig.
    """
    parser = argparse.ArgumentParser(
        description="GRPO Trainer with optional shared-weight vLLM integration",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # --- Core training arguments ---
    parser.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="HuggingFace model identifier (e.g., 'Qwen/Qwen2.5-1.5B-Instruct')",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-5,
        help="Learning rate for the optimizer",
    )
    parser.add_argument(
        "--training-steps",
        type=int,
        default=10,
        help="Number of training steps to run",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size for training",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=2048,
        help="Maximum sequence length",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=32,
        help="Number of gradient accumulation steps",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to train on (cuda/cpu)",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="trained_model_checkpoints",
        help="Directory to save model checkpoints",
    )

    # --- vLLM arguments ---
    parser.add_argument(
        "--vllm-restart-interval",
        type=int,
        default=3,
        help="Restart vLLM every N training steps (legacy mode only)",
    )
    parser.add_argument(
        "--vllm-port",
        type=int,
        default=9001,
        help="Port for the vLLM server",
    )
    parser.add_argument(
        "--vllm-gpu-memory-utilization",
        type=float,
        default=0.45,
        help="GPU memory utilization for vLLM server (0.0-1.0)",
    )

    # --- Wandb arguments ---
    parser.add_argument(
        "--use-wandb",
        action="store_true",
        help="Enable Weights & Biases logging",
    )
    parser.add_argument(
        "--wandb-project",
        type=str,
        default=None,
        help="Wandb project name",
    )
    parser.add_argument(
        "--wandb-group",
        type=str,
        default=None,
        help="Wandb group name",
    )

    # --- Pipeline / weight bridge arguments ---
    # This flag selects which train_* entry point runs (see __main__ dispatch).
    parser.add_argument(
        "--weight-bridge-mode",
        type=str,
        choices=["shared_vllm", "lora_only", "none"],
        default="none",
        help=(
            "Weight sync mode: "
            "'shared_vllm' = attach to vLLM shared memory, "
            "'lora_only' = train LoRA adapters only, "
            "'none' = legacy restart-based sync"
        ),
    )
    parser.add_argument(
        "--trainer-rank",
        type=int,
        default=0,
        help="Rank of this trainer in the distributed group",
    )
    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Total processes in the distributed group",
    )
    parser.add_argument(
        "--init-method",
        type=str,
        default="env://",
        help="PyTorch distributed init method (e.g., 'env://', 'tcp://host:port')",
    )
    parser.add_argument(
        "--num-inference-nodes",
        type=int,
        default=0,
        help="Number of inference nodes to coordinate with (0 = single-node local)",
    )

    # --- LoRA arguments ---
    parser.add_argument(
        "--lora-r",
        type=int,
        default=16,
        help="LoRA rank (dimension of low-rank matrices)",
    )
    parser.add_argument(
        "--lora-alpha",
        type=int,
        default=32,
        help="LoRA alpha (scaling factor, typically 2x rank)",
    )
    parser.add_argument(
        "--lora-dropout",
        type=float,
        default=0.05,
        help="Dropout probability for LoRA layers",
    )
    # NOTE(review): default is None here while the help text advertises
    # "q_proj v_proj" — presumably resolved downstream when building the
    # LoraConfig; confirm the fallback exists there.
    parser.add_argument(
        "--lora-target-modules",
        type=str,
        nargs="+",
        default=None,
        help="Module names to apply LoRA to (default: q_proj v_proj)",
    )

    # --- Shared memory arguments ---
    parser.add_argument(
        "--use-shared-memory",
        action="store_true",
        help=(
            "Enable NCCL shared memory weight updates (shared_vllm mode only). "
            "Weights are broadcast to vLLM's daemon via NCCL. "
            "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
        ),
    )

    return parser.parse_args()
|
|
|
|
|
|
def config_from_args(args: argparse.Namespace) -> TrainingConfig:
    """Build a TrainingConfig from parsed CLI arguments.

    Every TrainingConfig field is pulled from the identically-named
    attribute on the namespace (argparse converts --foo-bar to foo_bar).
    """
    field_names = (
        "model_name",
        "lr",
        "training_steps",
        "batch_size",
        "seq_len",
        "gradient_accumulation_steps",
        "device",
        "save_path",
        "vllm_restart_interval",
        "vllm_port",
        "vllm_gpu_memory_utilization",
        "use_wandb",
        "wandb_project",
        "wandb_group",
        "weight_bridge_mode",
        "trainer_rank",
        "world_size",
        "init_method",
        "num_inference_nodes",
        "lora_r",
        "lora_alpha",
        "lora_dropout",
        "lora_target_modules",
    )
    kwargs = {name: getattr(args, name) for name in field_names}
    # Defensive lookup: tolerate namespaces produced without the
    # --use-shared-memory flag defined.
    kwargs["use_shared_memory"] = getattr(args, "use_shared_memory", False)
    return TrainingConfig(**kwargs)
|
|
|
|
|
|
# Example usage (optional, can be run from another script)
if __name__ == "__main__":
    args = parse_args()
    training_config = config_from_args(args)

    print(f"Weight bridge mode: {training_config.weight_bridge_mode}")

    mode = training_config.weight_bridge_mode
    if mode == "shared_vllm":
        # Shared vLLM mode: attach to vLLM's weights, update in-place
        train_shared_vllm(training_config)
    elif mode == "lora_only":
        # LoRA mode: freeze base model, train adapters only
        train_lora(training_config)
    else:
        # Legacy mode: periodic checkpoint saves + vLLM restarts
        train(training_config)