lora restart saving gradient changes

commit 90281f5993 (parent 1127083b5f)
7 changed files with 805 additions and 19 deletions

@@ -5,9 +5,11 @@ Contains the four main training modes:
- train_legacy: Checkpoint-based training with vLLM restarts
- train_shared_vllm: Single-copy mode with CUDA IPC
- train_lora: LoRA adapter training with HTTP hot-swap
- train_lora_restart: LoRA training with vLLM restarts (FAST mode)
"""

import os
import subprocess
import time
from typing import Optional

@@ -658,3 +660,279 @@ def _hotswap_lora_adapter(
    return False


def train_lora_restart(config: TrainingConfig):
    """
    GRPO training with LoRA adapters using vLLM restarts (FAST mode).

    This mode:
    1. Freezes the base model and trains only the LoRA adapter weights
    2. Runs vLLM WITH CUDA graphs enabled (no --enforce-eager)
    3. Restarts vLLM every N steps with the new adapter pre-loaded

    Performance comparison:
    - lora_only (--enforce-eager): ~13 TPS (SLOW)
    - lora_restart (CUDA graphs): ~170 TPS (FAST)

    The restart overhead (~45s) is much less than the 12x inference slowdown.

    Requirements:
    - No external vLLM needed; this mode manages vLLM internally
    - Requires the PEFT library for LoRA
    """
    if not PEFT_AVAILABLE:
        raise RuntimeError(
            "PEFT library required for LoRA mode. Install with: pip install peft"
        )

    training_start_time = time.time()

    # === Setup ===
    use_wandb = setup_wandb(config)

    print("\n" + "=" * 60)
    print("LORA RESTART MODE (fast inference with CUDA graphs)")
    print("=" * 60)
    print(f"Base model: {config.model_name}")
    print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}")
    print(f"Save path: {config.save_path}")
    print(f"vLLM port: {config.vllm_port}")
    print(f"Restart interval: every {config.vllm_restart_interval} steps")
    print("=" * 60)
    print("NOTE: This mode restarts vLLM to keep CUDA graphs enabled.")
    print("      Expected inference speed: ~170 TPS (vs ~13 TPS with --enforce-eager)")
    print("=" * 60 + "\n")

    # Load model with LoRA adapters for training
    print("[1/4] Loading model with LoRA adapters...")
    model, tokenizer = load_model_and_tokenizer(config)

    # Only optimize LoRA parameters
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=config.lr)
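    # With PEFT, only the injected adapter matrices have requires_grad=True, so
    # AdamW keeps moment estimates for the adapter weights alone, not the frozen
    # base model.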

    os.makedirs(config.save_path, exist_ok=True)

    # Save initial adapter
    print("[2/4] Saving initial LoRA adapter...")
    initial_adapter_path = save_lora_checkpoint(model, config.save_path, 0)
    current_adapter_path = initial_adapter_path

    # Launch vLLM with the initial adapter
    print("[3/4] Launching vLLM with CUDA graphs (no --enforce-eager)...")
    vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
    if vllm_proc is None:
        raise RuntimeError("Failed to launch vLLM")

    print(f"[4/4] Starting training for {config.training_steps} steps")
    print("-" * 60)

    # Check Atropos API
    if not check_atropos_api(url=config.atropos_url, timeout=30):
        _terminate_vllm(vllm_proc)
        raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}")
    register_trainer(config)

    # === Benchmark tracking ===
    benchmark_stats = {
        "step_times": [],
        "sync_times": [],
        "data_fetch_times": [],
        "gpu_memories": [],
        "restart_times": [],
    }

    # === Training Loop ===
    batches = []
    for step in range(config.training_steps):
        print(f"\nStep {step + 1}/{config.training_steps}")

        # Fetch data (with inference logprobs for proper GRPO)
        data_fetch_start = time.time()
        if len(batches) == 0:
            batches, _ = get_data(
                config.batch_size,
                config.seq_len,
                config.atropos_url,
                extract_inference_logprobs=True,
            )
        batch_data = batches.pop(0)
        token_batches, label_batches, advantage_batches, temperature_batches = (
            batch_data[:4]
        )
        inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None
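        # get_data returns a queue of batches that is drained one entry per step.
        # Each batch_data is (tokens, labels, advantages, temperatures[, inference
        # logprobs]); the optional fifth element presumably lets the GRPO loss
        # correct for the slightly stale vLLM policy that sampled the rollouts.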
        data_fetch_time = time.time() - data_fetch_start
        benchmark_stats["data_fetch_times"].append(data_fetch_time)

        # Training step with proper GRPO
        step_start = time.time()
        metrics = run_training_step(
            model,
            optimizer,
            token_batches,
            label_batches,
            advantage_batches,
            temperature_batches,
            config,
            inference_logprob_batches=inference_logprob_batches,
        )
        step_time = time.time() - step_start
        benchmark_stats["step_times"].append(step_time)

        # GPU memory tracking
        gpu_mem_gb = (
            torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
        )
        gpu_mem_reserved_gb = (
            torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
        )
        benchmark_stats["gpu_memories"].append(gpu_mem_gb)

        # Periodic adapter save + vLLM restart
        sync_time = 0
        should_sync = (step + 1) % config.vllm_restart_interval == 0
        if should_sync and (step + 1) < config.training_steps:  # Don't restart on last step
            sync_start = time.time()

            # Save new adapter
            current_adapter_path = save_lora_checkpoint(model, config.save_path, step + 1)

            # Restart vLLM with new adapter
            print("  [RESTART] Restarting vLLM with new adapter...")
            _terminate_vllm(vllm_proc)
            vllm_proc = _launch_vllm_with_lora(config, current_adapter_path)
            if vllm_proc is None:
                raise RuntimeError("Failed to restart vLLM")

            sync_time = time.time() - sync_start
            benchmark_stats["sync_times"].append(sync_time)
            benchmark_stats["restart_times"].append(sync_time)
            print(f"  [RESTART] vLLM restarted in {sync_time:.1f}s")

        # Update metrics
        metrics.update(
            {
                "step_time": step_time,
                "sync_time": sync_time,
                "data_fetch_time": data_fetch_time,
                "gpu_memory_gb": gpu_mem_gb,
                "gpu_memory_reserved_gb": gpu_mem_reserved_gb,
            }
        )

        log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark)

    # === Cleanup ===
    print("\nSaving final adapter...")
    final_sync_start = time.time()
    final_adapter_path = save_lora_checkpoint(
        model, config.save_path, config.training_steps, is_final=True
    )
    final_sync_time = time.time() - final_sync_start
    benchmark_stats["sync_times"].append(final_sync_time)

    # Terminate vLLM
    _terminate_vllm(vllm_proc)

    finalize_training(
        use_wandb,
        training_start_time,
        "lora_restart",
        config.training_steps,
        benchmark_stats,
        config.benchmark,
    )

    # Save tokenizer
    tokenizer_path = os.path.join(config.save_path, "tokenizer")
    tokenizer.save_pretrained(tokenizer_path)
    print(f"Tokenizer saved to {tokenizer_path}")
    print(f"Final adapter saved to {final_adapter_path}")


def _launch_vllm_with_lora(
    config: TrainingConfig, adapter_path: str
) -> Optional[subprocess.Popen]:
    """
    Launch vLLM with a LoRA adapter pre-loaded (CUDA graphs enabled).

    Unlike lora_only mode, this does NOT use --enforce-eager, so we get
    full CUDA graph speed (~170 TPS instead of ~13 TPS).
    """
    from .vllm_manager import kill_process_on_port, wait_for_vllm_ready

    # Kill any existing process on the port
    kill_process_on_port(config.vllm_port)

    # Find the vllm_api_server.py script
    script_dir = os.path.dirname(os.path.abspath(__file__))
    server_script = os.path.join(script_dir, "vllm_api_server.py")

    # Build command - NO --enforce-eager for full speed
    cmd = [
        "python", server_script,
        "--model", config.model_name,
        "--port", str(config.vllm_port),
        "--gpu-memory-utilization", str(config.vllm_gpu_memory_utilization),
        "--enable-lora",
        "--max-lora-rank", str(max(config.lora_r * 2, 32)),
        # Note: NOT adding --enforce-eager - this is the key difference!
        # The LoRA adapter is loaded as part of startup (see below), so CUDA
        # graphs are compiled with it in place.
    ]
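    # max(2 * r, 32) leaves headroom: the maximum LoRA rank is fixed for the
    # lifetime of the engine, so a bound above the training rank presumably
    # keeps the server from rejecting an adapter if lora_r is bumped later.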

    # Set environment for GPU selection
    env = os.environ.copy()
    if config.vllm_gpu is not None:
        env["CUDA_VISIBLE_DEVICES"] = str(config.vllm_gpu)
        print(f"  GPU: {config.vllm_gpu} (via CUDA_VISIBLE_DEVICES)")
    else:
        print("  GPU: Same as trainer (inherited CUDA_VISIBLE_DEVICES)")

    print(f"  Launching: {' '.join(cmd)}")
    print(f"  Adapter: {adapter_path}")

    try:
        proc = subprocess.Popen(cmd, env=env)
        print(f"  vLLM PID: {proc.pid}")

        # Wait for server to be ready
        if not wait_for_vllm_ready(config.vllm_port, timeout=180):
            print("  ERROR: vLLM failed to start")
            proc.terminate()
            return None

        # Load the LoRA adapter
        print("  Loading LoRA adapter...")
        try:
            resp = requests.post(
                f"http://localhost:{config.vllm_port}/lora/load",
                json={"adapter_path": adapter_path, "adapter_name": "training_adapter"},
                timeout=60,
            )
            if resp.status_code == 200:
                print("  ✓ Adapter loaded successfully")
            else:
                print(f"  WARNING: Adapter load returned {resp.status_code}: {resp.text}")
        except Exception as e:
            print(f"  WARNING: Could not load adapter: {e}")
            # Continue anyway - base model inference still works

        return proc

    except Exception as e:
        print(f"  ERROR: {e}")
        return None
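

# For manual debugging, the /lora/load route used above can be exercised
# directly; the port and adapter path below are placeholders:
#
#     curl -X POST http://localhost:<vllm_port>/lora/load \
#          -H 'Content-Type: application/json' \
#          -d '{"adapter_path": "<path>", "adapter_name": "training_adapter"}'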


def _terminate_vllm(proc: Optional[subprocess.Popen]) -> None:
    """Terminate a vLLM process."""
    if proc is None:
        return

    try:
        proc.terminate()
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()
    except Exception:
        pass