diff --git a/example_trainer/README.md b/example_trainer/README.md
index a1f69c2a..81cd89c4 100644
--- a/example_trainer/README.md
+++ b/example_trainer/README.md
@@ -2,516 +2,507 @@
 A modular training framework for fine-tuning language models with **Group Relative Policy Optimization (GRPO)**, designed to work with the Atropos environment system.

-## 📁 Module Structure
+## Module Structure

 ```
 example_trainer/
 ├── grpo.py              # CLI entry point (dispatches to trainers)
+├── run.py               # Unified launcher for shared_vllm mode
 ├── config.py            # TrainingConfig dataclass
+├── cli.py               # CLI argument parsing (single source of truth)
 ├── api.py               # Atropos API communication
 ├── data.py              # Data fetching & preprocessing
 ├── model.py             # Model loading & CUDA IPC shared memory
-├── training.py          # Loss computation & training step
+├── training.py          # GRPO loss computation & training step
 ├── checkpointing.py     # Save models & LoRA adapters
 ├── vllm_manager.py      # vLLM process management
 ├── trainers.py          # Training mode implementations
-├── cli.py               # CLI argument parsing
-├── vllm_api_server.py   # Custom vLLM server with IPC support
-├── vllm_patching/       # B200/Blackwell GPU patches
-│   └── patched_gpu_runner.py
-└── scripts/             # Helper scripts
-    ├── run_comparison.sh
-    ├── run_concurrent_tests.sh
+├── vllm_api_server.py   # Custom vLLM server (streamlined for training)
+├── vllm_patching/       # CUDA IPC patches for weight sharing (overrides stock vLLM behavior)
+│   └── patched_gpu_runner.py
+└── scripts/             # Helper scripts
     ├── test_lora_mode.sh
     └── test_single_copy_mode.sh
 ```

 ---
-
-## 🔄 Full System Architecture
-
-The Atropos training system consists of 4 components that must run together:
+```
+GRPO Training Loop
+1. Generate multiple responses to the same prompt
+2. Score each response (reward)
+3. Compute ADVANTAGE = reward - mean(rewards)
+4. Train: increase probability of above-average responses,
+   decrease probability of below-average responses
 ```
-┌─────────────────────────────────────────────────────────────────────────────┐
-│                         ATROPOS TRAINING SYSTEM                             │
-└─────────────────────────────────────────────────────────────────────────────┘
-  ┌─────────────┐      ┌──────────────────┐      ┌─────────────────┐
-  │    vLLM     │◄────►│   Environment    │─────►│     run-api     │
-  │   Server    │      │  (gsm8k_server)  │      │   (Trajectory   │
-  │ (Inference) │      │  (Process Env)   │      │  Handler API)   │
-  └─────────────┘      └──────────────────┘      └────────┬────────┘
-         ▲                                                │
-         │                                                │
-         │        ┌───────────────────────────────────────┘
-         │        │
-         │        ▼
-         │  ┌─────────────┐
-         └──│    GRPO     │
-            │   Trainer   │
-            │  (grpo.py)  │
-            └─────────────┘
+### Key Concepts
+
+| Concept | What It Means |
+|---------|---------------|
+| **Advantage** | How much better/worse than average a response was |
+| **Importance Sampling** | Corrects for policy drift during training |
+| **KL Penalty** | Prevents the model from changing too drastically from base |
+| **Clipping** | Limits update magnitude for stability |
+
+
+## System Architecture

+```
 Data Flow:
-1. run-api     : Central API that receives trajectories and serves batches
-2. Environment : Generates prompts, calls vLLM, scores responses → sends to run-api
-3. Trainer     : Fetches batches from run-api → trains model → updates weights
-4. 
vLLM : Serves inference for environment (and gets weight updates) -``` - -### Components Explained - -| Component | Command | Port | Purpose | -|-----------|---------|------|---------| -| **run-api** | `run-api` | 8000 | Central trajectory handler API | -| **Environment** | `gsm8k_server.py serve` | (internal) | Generates rollouts, scores them | -| **vLLM** | `vllm_api_server.py` | 9001 | Model inference | -| **Trainer** | `grpo.py` | (client) | Fetches batches, trains model | - ---- - -## 🎯 Three Training Modes - -| Mode | Description | vLLM Setup | Best For | -|------|-------------|------------|----------| -| **Legacy** (`none`) | Trainer manages vLLM, restarts with new checkpoints | Auto-managed | Simple setup, different GPUs | -| **Shared vLLM** (`shared_vllm`) | Single-copy mode via CUDA IPC - no model duplication! | External, `VLLM_ENABLE_SHARED_WEIGHTS=1` | Same GPU, max efficiency | -| **LoRA** (`lora_only`) | Train adapters only, hot-swap in vLLM | External, `--enable-lora` | Fast training, small checkpoints | - ---- - -## 🚀 Quick Start - -### Prerequisites - -```bash -# Install dependencies -pip install torch transformers peft vllm wandb requests tenacity pydantic - -# Set environment variables -export LOGDIR=/tmp/atropos_test -export MODEL=Qwen/Qwen2.5-3B-Instruct -mkdir -p $LOGDIR +1. Environment generates prompts → calls vLLM → scores responses +2. Environment sends trajectories to run-api +3. Trainer fetches batches from run-api +4. Trainer updates model weights +5. (shared_vllm) vLLM sees updates immediately via CUDA IPC + (lora_only) Trainer pushes adapter to vLLM periodically ``` --- -## 📖 Detailed Usage for Each Mode +## Three Training Modes -### Mode 1: Legacy (Checkpoint + Restart) +| Mode | Description | Memory | Best For | +|------|-------------|--------|----------| +| **shared_vllm** | Single-copy via CUDA IPC | 1x model | Same GPU, maximum efficiency | +| **lora_only** | Train adapters, hot-swap | 1x + small adapter | Fast iteration, small checkpoints | +| **legacy** | Full model, restart vLLM | 2x model | Different GPUs, simple setup | -The simplest mode. Trainer manages vLLM internally. +### Recommendation -```bash -# Terminal 1: Start the central API server (handles trajectories) -run-api --port 8000 - -# Terminal 2: Start the environment server (generates rollouts) -python -u environments/gsm8k_server.py serve \ - --env.tokenizer_name $MODEL \ - --env.use_wandb=False \ - --openai.model_name $MODEL \ - --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm - -# Terminal 3: Run training (trainer will launch its own vLLM) -CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \ - --model-name $MODEL \ - --weight-bridge-mode none \ - --vllm-port 9001 \ - --atropos-url http://localhost:8000 \ - --training-steps 20 \ - --batch-size 2 \ - --save-path $LOGDIR/checkpoints_legacy \ - --benchmark -``` - -### Mode 2: Shared vLLM (Single-Copy CUDA IPC) - -Zero model duplication - trainer and vLLM share the exact same GPU memory! - -```bash -# Terminal 1: Start the central API server -run-api --port 8000 - -# Terminal 2: Start vLLM with shared weights enabled -# IMPORTANT: --enforce-eager is REQUIRED to disable CUDA graphs -# Without it, weight updates won't be visible to inference! 
-VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \
-CUDA_VISIBLE_DEVICES=0 python example_trainer/vllm_api_server.py \
-    --model $MODEL \
-    --port 9001 \
-    --gpu-memory-utilization 0.45 \
-    --enforce-eager
-
-# Terminal 3: Start the environment server
-python -u environments/gsm8k_server.py serve \
-    --env.tokenizer_name $MODEL \
-    --env.use_wandb=False \
-    --openai.model_name $MODEL \
-    --openai.base_url http://localhost:9001/v1 \
-    --openai.server_type vllm
-
-# Terminal 4: Run training (attaches to vLLM's tensors)
-CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
-    --model-name $MODEL \
-    --weight-bridge-mode shared_vllm \
-    --vllm-port 9001 \
-    --vllm-config-path $LOGDIR/vllm_bridge_config.json \
-    --atropos-url http://localhost:8000 \
-    --training-steps 20 \
-    --batch-size 2 \
-    --save-path $LOGDIR/checkpoints_shared \
-    --benchmark
-```
-
-### Mode 3: LoRA (Adapter Training)
-
-Fast training with hot-swappable adapters.
-
-```bash
-# Terminal 1: Start the central API server
-run-api --port 8000
-
-# Terminal 2: Start vLLM with LoRA support
-CUDA_VISIBLE_DEVICES=0 python example_trainer/vllm_api_server.py \
-    --model $MODEL \
-    --port 9001 \
-    --gpu-memory-utilization 0.45 \
-    --enable-lora \
-    --max-lora-rank 32 \
-    --enforce-eager
-
-# Terminal 3: Start the environment server
-python -u environments/gsm8k_server.py serve \
-    --env.tokenizer_name $MODEL \
-    --env.use_wandb=False \
-    --openai.model_name $MODEL \
-    --openai.base_url http://localhost:9001/v1 \
-    --openai.server_type vllm
-
-# Terminal 4: Run LoRA training
-CUDA_VISIBLE_DEVICES=1 python -m example_trainer.grpo \
-    --model-name $MODEL \
-    --weight-bridge-mode lora_only \
-    --vllm-port 9001 \
-    --atropos-url http://localhost:8000 \
-    --lora-r 16 \
-    --lora-alpha 32 \
-    --training-steps 20 \
-    --batch-size 2 \
-    --save-path $LOGDIR/checkpoints_lora \
-    --benchmark
-```
+**Start with `lora_only`** - it's the easiest to set up and debug. Move to `shared_vllm` for production training when you need maximum efficiency on a single GPU. Note that multi-GPU training is not supported.

 ---

-## 🔬 Run All 3 Modes in Parallel (8-GPU Comparison)
+## Quick Start: LoRA Training (Recommended)

-Use this setup to compare training efficiency across all three modes on a single 8-GPU node.
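+Why LoRA first? Instead of updating the full weights, LoRA trains a pair of small low-rank matrices per targeted projection, which is why checkpoints stay tiny and adapters can be hot-swapped into vLLM. As a rough illustration only (a hypothetical sketch built from the `--lora-r`/`--lora-alpha`/`--lora-dropout` flags documented below, not the trainer's actual code):
+
+```python
+# Hypothetical sketch of the PEFT setup implied by the LoRA CLI flags.
+from peft import LoraConfig, TaskType, get_peft_model
+from transformers import AutoModelForCausalLM
+
+base_model = AutoModelForCausalLM.from_pretrained(
+    "NousResearch/Hermes-3-Llama-3.1-8B"
+)
+lora_cfg = LoraConfig(
+    task_type=TaskType.CAUSAL_LM,
+    r=16,               # --lora-r
+    lora_alpha=32,      # --lora-alpha
+    lora_dropout=0.05,  # --lora-dropout
+    target_modules=["q_proj", "v_proj"],
+)
+model = get_peft_model(base_model, lora_cfg)
+model.print_trainable_parameters()  # only adapter weights require grad
+```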
+### Step 1: Install Dependencies
+
+Install the dependencies listed in `requirements.txt`.

-### GPU & Port Allocation
-
-| Mode | GPUs | vLLM Port | API Port | Env Port |
-|------|------|-----------|----------|----------|
-| Legacy | 0-1 | 9001 | 8001 | (internal) |
-| Shared vLLM | 2-3 | 9002 | 8002 | (internal) |
-| LoRA | 4-5 | 9003 | 8003 | (internal) |
-
-### Quick Start: Use the Comparison Script
+### Step 2: Start All Components

+**Terminal 1: API Server**
 ```bash
-cd /path/to/atropos
-
-# Run comparison with default 50 steps
-./example_trainer/scripts/run_comparison.sh
-
-# Or specify steps
-./example_trainer/scripts/run_comparison.sh 100
-```
-
-### Manual Parallel Execution
-
-If you prefer to run each mode manually in separate terminal sessions:
-
-```bash
-# Setup
-export MODEL="Qwen/Qwen2.5-3B-Instruct"
-export LOGDIR=/tmp/atropos_test
-mkdir -p $LOGDIR
-
-# =============================================
-# LEGACY MODE (Terminals 1-3)
-# =============================================
-
-# Terminal 1: API server for legacy
-run-api --port 8001
-
-# Terminal 2: Environment server
-python -u environments/gsm8k_server.py serve \
-    --env.tokenizer_name $MODEL \
-    --env.use_wandb=False \
-    --env.rollout_server_url http://localhost:8001 \
-    --openai.model_name $MODEL \
-    --openai.base_url http://localhost:9001/v1 \
-    --openai.server_type vllm \
-    --slurm false
-
-# Terminal 3: Trainer (manages its own vLLM)
-CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \
-    --model-name $MODEL \
-    --weight-bridge-mode none \
-    --vllm-port 9001 \
-    --atropos-url http://localhost:8001 \
-    --training-steps 50 \
-    --save-path $LOGDIR/checkpoints_legacy \
-    --benchmark
-
-# =============================================
-# SHARED VLLM MODE (Terminals 4-7)
-# =============================================
-
-# Terminal 4: API server for shared
 run-api --port 8002
 ```

-# Terminal 5: vLLM server with shared weights
-VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \
-CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \
-    --model $MODEL --port 9002 --gpu-memory-utilization 0.45
+**Terminal 2: vLLM Server**
+```bash
+python -m example_trainer.vllm_api_server \
+    --model NousResearch/Hermes-3-Llama-3.1-8B \
+    --port 9001 \
+    --gpu-memory-utilization 0.5 \
+    --max-model-len 4096 \
+    --dtype bfloat16 \
+    --enable-lora \
+    --enforce-eager
+```

-# Terminal 6: Environment server
-python -u environments/gsm8k_server.py serve \
-    --env.tokenizer_name $MODEL \
-    --env.use_wandb=False \
-    --env.rollout_server_url http://localhost:8002 \
-    --openai.model_name $MODEL \
-    --openai.base_url http://localhost:9002/v1 \
-    --openai.server_type vllm \
-    --slurm false
+**Terminal 3: Environment**
+```bash
+python environments/gsm8k_server.py serve \
+    --env.group_size 4 \
+    --env.max_num 200 \
+    --slurm.num_requests_per_time_interval 16 \
+    --slurm.time_interval 10 \
+    --openai.api_key "dummy" \
+    --openai.base_url "http://localhost:9001" \
+    --openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \
+    --openai.server_type vllm
+```

-# Terminal 7: Trainer (attaches to vLLM)
-CUDA_VISIBLE_DEVICES=2 python -m example_trainer.grpo \
-    --model-name $MODEL \
-    --weight-bridge-mode shared_vllm \
-    --vllm-port 9002 \
-    --vllm-config-path $LOGDIR/vllm_bridge_config.json \
-    --atropos-url http://localhost:8002 \
-    --training-steps 50 \
-    --save-path $LOGDIR/checkpoints_shared \
-    --benchmark
-
-# =============================================
-# LORA MODE (Terminals 8-11)
-# 
============================================= - -# Terminal 8: API server for LoRA -run-api --port 8003 - -# Terminal 9: vLLM server with LoRA -CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \ - --model $MODEL --port 9003 --gpu-memory-utilization 0.45 \ - --enable-lora --max-lora-rank 32 --enforce-eager - -# Terminal 10: Environment server -python -u environments/gsm8k_server.py serve \ - --env.tokenizer_name $MODEL \ - --env.use_wandb=False \ - --env.rollout_server_url http://localhost:8003 \ - --openai.model_name $MODEL \ - --openai.base_url http://localhost:9003/v1 \ - --openai.server_type vllm \ - --slurm false - -# Terminal 11: Trainer -CUDA_VISIBLE_DEVICES=5 python -m example_trainer.grpo \ - --model-name $MODEL \ +**Terminal 4: Trainer** +```bash +python -m example_trainer.grpo \ + --model-name NousResearch/Hermes-3-Llama-3.1-8B \ --weight-bridge-mode lora_only \ - --vllm-port 9003 \ - --atropos-url http://localhost:8003 \ - --lora-r 16 --lora-alpha 32 \ - --training-steps 50 \ - --save-path $LOGDIR/checkpoints_lora \ - --benchmark + --vllm-port 9001 \ + --atropos-url "http://localhost:8002" \ + --batch-size 4 \ + --gradient-accumulation-steps 4 \ + --learning-rate 1e-5 \ + --training-steps 30 \ + --kl-coef 0.1 \ + --clip-eps 0.2 \ + --vllm-restart-interval 5 \ + --save-path ./lora_checkpoints \ + --wandb-project "grpo-training" ``` ---- - -## 📊 Understanding the Benchmark Output - -Each trainer outputs a benchmark summary at the end: - -``` -====================================================================== -BENCHMARK SUMMARY (shared_vllm) -====================================================================== - Total training time: 168.65s (2.81 min) - Total steps: 50 - - TIMING BREAKDOWN: - Avg step time: 11.95s - Total step time: 59.76s - Avg sync time: 0.00s (x0 syncs) <-- No syncs in shared mode! - Total sync time: 0.00s - Avg data fetch time: 10.90s - Total data fetch time: 54.52s - - MEMORY: - Peak GPU memory: 31.44 GB - Avg GPU memory: 18.88 GB -====================================================================== -``` - -**Key metrics to compare:** - -| Metric | Legacy | Shared vLLM | LoRA | -|--------|--------|-------------|------| -| **Sync time** | High (restart vLLM) | 0 (in-place update) | Low (adapter swap) | -| **GPU memory** | 2x model | 1x model | 1x + adapter | -| **Step time** | ~10-15s | ~10-15s | ~5-10s | -| **Checkpoint size** | ~6GB | ~6GB | ~50MB | - ---- - -## 🛠 CLI Reference +### Startup Order ```bash -python -m example_trainer.grpo --help +# 1. Start API +# 2. Wait 5s, start vLLM +# 3. Wait for vLLM to load (check: curl http://localhost:9001/health) +# 4. Start environment +# 5. Start trainer ``` -### Key Arguments +--- + +## Shared vLLM Mode (Advanced) + +Single-copy mode shares GPU memory between vLLM and the trainer - zero model duplication! 
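+The trick is CUDA IPC: a GPU allocation made in one process can be mapped into another process's address space, so both see the very same memory. The snippet below is a minimal, self-contained demo of that mechanism using `torch.multiprocessing` (which ships CUDA tensors between processes as IPC handles). It is only an illustration of the idea - this repo instead exports raw IPC handles through `vllm_bridge_config.json`:
+
+```python
+# Demo: one CUDA tensor, two processes, zero copies (illustrative, not repo code).
+import torch
+import torch.multiprocessing as mp
+
+def reader(weights: torch.Tensor, updated) -> None:
+    updated.wait()  # wait until the parent has written
+    # The child reads the parent's in-place update through the IPC mapping.
+    print("child sees:", weights[0].item())  # prints 1.0
+
+if __name__ == "__main__":
+    mp.set_start_method("spawn")   # required when mixing CUDA and multiprocessing
+    weights = torch.zeros(4, device="cuda")
+    updated = mp.Event()
+    child = mp.Process(target=reader, args=(weights, updated))
+    child.start()
+    weights += 1.0                 # stand-in for an optimizer step
+    torch.cuda.synchronize()       # make the write visible before signaling
+    updated.set()
+    child.join()
+```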
+ +### How It Works + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ SINGLE GPU (CUDA IPC) │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ Model Weights (ONE copy in GPU memory) │ │ +│ │ (accessible via CUDA IPC handles) │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ ▲ ▲ │ +│ │ Reads (inference) │ Writes │ +│ ┌────────┴────────┐ ┌───────────┴───────────┐ │ +│ │ vLLM Worker │ │ Trainer Process │ │ +│ │ │ │ (attached via IPC) │ │ +│ └─────────────────┘ └───────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +### Running Shared vLLM Mode + +**Terminal 1: API** +```bash +run-api --port 8002 +``` + +**Terminal 2: vLLM with Shared Weights** +```bash +VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=/tmp/grpo_training \ +python -m example_trainer.vllm_api_server \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --port 9001 \ + --gpu-memory-utilization 0.45 \ + --enforce-eager +``` + +**Terminal 3: Environment** +```bash +python environments/gsm8k_server.py serve \ + --openai.base_url "http://localhost:9001" \ + --openai.model_name "NousResearch/Hermes-3-Llama-3.1-8B" \ + --openai.server_type vllm +``` + +**Terminal 4: Trainer** +```bash +python -m example_trainer.grpo \ + --model-name NousResearch/Hermes-3-Llama-3.1-8B \ + --weight-bridge-mode shared_vllm \ + --vllm-port 9001 \ + --vllm-config-path /tmp/grpo_training/vllm_bridge_config.json \ + --atropos-url "http://localhost:8002" \ + --kl-coef 0.1 \ + --clip-eps 0.2 +``` + +### Or Use the Unified Launcher + +```bash +# Single command starts both vLLM and trainer +VLLM_ENABLE_SHARED_WEIGHTS=1 python -m example_trainer.run \ + --model-name NousResearch/Hermes-3-Llama-3.1-8B \ + --atropos-url "http://localhost:8002" \ + --training-steps 30 +``` + +--- + +## Best Practices & Lessons Learned + +### 1. Always Use `--enforce-eager` with Shared Weights + +**Why:** CUDA graphs "bake" weights at compile time. Without eager mode, vLLM won't see weight updates! + +```bash +# WRONG - weight updates won't be visible to inference +python vllm_api_server.py --model $MODEL + +# CORRECT - disables CUDA graphs +python vllm_api_server.py --model $MODEL --enforce-eager +``` + +### 2. Use `--openai.server_type vllm` for Training + +The gsm8k environment needs logprobs for GRPO. Only `server_type=vllm` uses the `/generate` endpoint which returns logprobs. + +```bash +# CORRECT - gets logprobs for training +--openai.server_type vllm + +# WRONG for training - no logprobs +--openai.server_type openai +``` + +### 3. KL Coefficient and Clipping Are Essential + +Without these, training will collapse (reward hacking): + +```bash +--kl-coef 0.1 # Prevents policy from drifting too far +--clip-eps 0.2 # Limits update magnitude +``` + +**Symptoms of missing KL/clipping:** +- Accuracy drops dramatically (e.g., 59% → 7%) +- Loss goes to very negative values +- Model outputs become repetitive/degenerate + +### 4. Memory Budgeting for Large Models + +| Model Size | GPU Memory | Recommended Settings | +|------------|------------|----------------------| +| 8B | 80GB | `--gpu-memory-utilization 0.5` | +| 14B | 80GB | `--gpu-memory-utilization 0.45`, `--batch-size 2` | +| 24B | 192GB (B200) | `--gpu-memory-utilization 0.30`, `--optimizer adafactor` | + +### 5. 
Start with Small Batch Sizes
+
+```bash
+# Start conservative, increase if no OOM
+--batch-size 2 --gradient-accumulation-steps 8   # Effective batch = 16
+```
+
+---
+
+## Tensor Mapping (vLLM ↔ HuggingFace)
+
+### The Problem
+
+vLLM fuses certain layers for efficiency, but HuggingFace keeps them separate:
+
+```
+HuggingFace Model:                 vLLM Model:
+├── q_proj    [4096, 4096]        ├── qkv_proj [6144, 4096]   ← FUSED!
+├── k_proj    [1024, 4096]        │   (contains q, k, v concatenated)
+├── v_proj    [1024, 4096]        │
+│                                 │
+├── gate_proj [14336, 4096]       ├── gate_up_proj [28672, 4096]   ← FUSED!
+├── up_proj   [14336, 4096]       │   (contains gate and up concatenated)
+```
+
+### How We Solve It
+
+The trainer creates **views** into vLLM's fused tensors:
+
+```python
+# vLLM has: qkv_proj.weight [6144, 4096]
+# We need: q_proj [4096], k_proj [1024], v_proj [1024]
+
+# Get sizes from model config
+q_size = num_heads * head_dim      # e.g., 4096
+k_size = num_kv_heads * head_dim   # e.g., 1024
+v_size = num_kv_heads * head_dim   # e.g., 1024
+
+# Create views (no copy!) - assign to .data so the module reuses vLLM's storage
+hf_model.q_proj.weight.data = vllm_qkv[0:4096, :]      # First chunk
+hf_model.k_proj.weight.data = vllm_qkv[4096:5120, :]   # Second chunk
+hf_model.v_proj.weight.data = vllm_qkv[5120:6144, :]   # Third chunk
+```
+
+### Key Insight: Views Share Memory
+
+```python
+# These point to the SAME GPU memory:
+trainer_q_proj.data_ptr() == vllm_qkv_proj.data_ptr()  # True!
+
+# So when the optimizer updates trainer weights:
+optimizer.step()  # Updates trainer_q_proj
+
+# vLLM sees the change immediately (same memory)!
+```
+
+### The Config File
+
+vLLM exports tensor mappings to `vllm_bridge_config.json`:
+
+```json
+{
+  "model": "NousResearch/Hermes-3-Llama-3.1-8B",
+  "param_mappings": {
+    "model.layers.0.self_attn.qkv_proj.weight": {
+      "ipc_handle": "base64_encoded_cuda_ipc_handle",
+      "shape": [6144, 4096],
+      "dtype": "bfloat16"
+    }
+  }
+}
+```
+
+---
+
+## ❓ FAQ
+
+### Q: Why isn't vLLM seeing my weight updates?
+
+**A:** CUDA graphs are caching the old weights. Add `--enforce-eager`:
+
+```bash
+python vllm_api_server.py --model $MODEL --enforce-eager
+```
+
+### Q: How do I debug logprob alignment issues?
+
+**A:** Look for these log messages:
+```
+[WARNING] ref_logprobs at generated positions avg 0.85 (should be negative!)
+```
+
+This means inference logprobs aren't being passed correctly. Check that:
+1. The environment uses `--openai.server_type vllm`
+2. vLLM returns logprobs (check the `/generate` response)
+
+### Q: Why does vLLM v1 engine fail with CUDA fork errors?
+
+**A:** vLLM v1 uses multiprocessing that conflicts with CUDA initialization. We default to the v0 engine:
+
+```python
+# vllm_api_server.py automatically sets:
+os.environ.setdefault("VLLM_USE_V1", "0")
+```
+
+---
+
+## Troubleshooting
+
+### "Atropos API not reachable"
+
+```bash
+# Start the API server first
+run-api --port 8002
+```
+
+### "404 Not Found" on /generate
+
+You're using a vLLM server that doesn't expose `/generate`. Use our custom server:
+
+```bash
+python -m example_trainer.vllm_api_server ...       # Has /generate
+# NOT: python -m vllm.entrypoints.openai.api_server # Only has /v1/*
+```
+
+### "Cannot re-initialize CUDA in forked subprocess"
+
+vLLM v1 engine issue. We disable it by default, but if you see this:
+
+```bash
+VLLM_USE_V1=0 python -m example_trainer.vllm_api_server ...
+```
+
+### "LogProb Alignment: MISMATCH!"
+
+Weight updates aren't visible to inference. 
Fix: + +```bash +# Add --enforce-eager to vLLM +python vllm_api_server.py --model $MODEL --enforce-eager +``` + +### OOM (Out of Memory) + +Reduce memory usage: + +```bash +--gpu-memory-utilization 0.4 # Less vLLM memory +--batch-size 2 # Smaller batches +--gradient-accumulation-steps 8 # Compensate with accumulation +--seq-len 1024 # Shorter sequences +--optimizer adafactor # Uses less memory than AdamW +``` + +### "FlexibleArgumentParser" import error + +vLLM version incompatibility. Our server handles this automatically, but make sure you're using: + +```bash +python -m example_trainer.vllm_api_server # NOT direct vllm commands +``` + +### Training is slow / no batches + +1. Check vLLM is running: `curl http://localhost:9001/health` +2. Check API is running: `curl http://localhost:8002/info` +3. Check environment is connected and generating rollouts + +--- + +## 📊 Monitoring Training + +### Key Metrics to Watch + +| Metric | Healthy Range | Problem If... | +|--------|---------------|---------------| +| `mean_ratio` | 0.8 - 1.2 | Far from 1.0 = policy changed too much | +| `mean_kl` | 0.01 - 0.1 | > 0.5 = policy drifting | +| `clipped_fraction` | < 0.3 | > 0.5 = learning rate too high | +| `loss` | Gradually decreasing | Exploding or very negative | + +### WandB Logging + +```bash +--use-wandb \ +--wandb-project "my-grpo-training" \ +--wandb-run-name "hermes-8b-gsm8k" +``` + +--- + +## CLI Reference + +### Essential Arguments | Argument | Default | Description | |----------|---------|-------------| | `--model-name` | (required) | HuggingFace model ID | -| `--weight-bridge-mode` | `none` | `none`, `shared_vllm`, or `lora_only` | +| `--weight-bridge-mode` | `none` | `shared_vllm`, `lora_only`, or `none` | | `--training-steps` | 10 | Number of training steps | -| `--batch-size` | 2 | Batch size | -| `--vllm-port` | 9001 | vLLM server port | -| `--atropos-url` | `http://localhost:8000` | Atropos API server URL | -| `--save-path` | `trained_model_checkpoints` | Checkpoint directory | -| `--benchmark` | false | Show timing stats | -| `--debug-loading` | false | Verbose model loading output | +| `--batch-size` | 2 | Micro-batch size | +| `--gradient-accumulation-steps` | 1 | Effective batch = batch × accum | -### LoRA-specific Arguments +### GRPO Hyperparameters + +| Argument | Default | Description | +|----------|---------|-------------| +| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) | +| `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] | +| `--learning-rate` | 1e-6 | Learning rate | + +### LoRA Arguments | Argument | Default | Description | |----------|---------|-------------| | `--lora-r` | 16 | LoRA rank | -| `--lora-alpha` | 32 | LoRA alpha (scaling) | +| `--lora-alpha` | 32 | LoRA scaling factor | | `--lora-dropout` | 0.05 | LoRA dropout | -| `--lora-target-modules` | `q_proj v_proj` | Modules to apply LoRA | -### Single-Copy Mode Arguments +### vLLM Arguments | Argument | Default | Description | |----------|---------|-------------| -| `--single-copy` | false | Enable CUDA IPC mode | -| `--vllm-config-path` | auto-detect | Path to `vllm_bridge_config.json` | +| `--vllm-port` | 9001 | vLLM server port | +| `--vllm-config-path` | auto | Path to bridge config (shared mode) | +| `--gpu-memory-utilization` | 0.9 | vLLM GPU memory fraction | --- -## 🐛 Troubleshooting +## Module Documentation -### "Atropos API not reachable" -```bash -# Make sure run-api is running -run-api --port 8000 -``` +| Module | Purpose | +|--------|---------| +| `grpo.py` | CLI entry point, 
dispatches to training modes | +| `run.py` | Unified launcher for shared_vllm mode | +| `cli.py` | Single source of truth for all CLI arguments | +| `config.py` | `TrainingConfig` Pydantic model | +| `api.py` | Communication with Atropos API | +| `data.py` | Batch preprocessing, logprob extraction | +| `model.py` | Model loading, CUDA IPC attachment, tensor mapping | +| `training.py` | GRPO loss computation | +| `trainers.py` | Mode-specific training loops | +| `vllm_api_server.py` | Streamlined vLLM server for training | +| `vllm_manager.py` | vLLM process lifecycle management | +| `checkpointing.py` | Save/load checkpoints and adapters | -### "vLLM server not running" (LoRA mode) -```bash -# LoRA mode requires external vLLM with --enable-lora -python example_trainer/vllm_api_server.py \ - --model $MODEL --port 9001 --enable-lora --enforce-eager -``` - -### "Could not find vllm_bridge_config.json" (Shared mode) -```bash -# Make sure vLLM was started with VLLM_ENABLE_SHARED_WEIGHTS=1 and LOGDIR set -VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=/tmp/atropos python example_trainer/vllm_api_server.py ... -``` - -### "LogProb Alignment: MISMATCH!" in shared_vllm mode -If you see `[MISMATCH!]` in the logprob alignment output, inference and training are seeing different weights. This is usually caused by **CUDA graphs**. - -**Symptom:** `inference_mean` stays constant while `training_mean` changes. The `diff` increases over time. - -**Fix:** Add `--enforce-eager` when starting vLLM: -```bash -VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \ -python example_trainer/vllm_api_server.py \ - --model $MODEL --port 9001 --enforce-eager # <-- REQUIRED! -``` - -**Why:** CUDA graphs "bake" model weights into compiled graphs at startup. Updates to the underlying tensors are NOT reflected in inference. Using `--enforce-eager` disables CUDA graphs, so vLLM reads from the shared tensors on every forward pass. - -### "Triton compilation error" on B200/Blackwell GPUs -The patched vLLM server (`vllm_api_server.py`) automatically applies B200 fixes. If using standard vLLM, add `--enforce-eager`. - -### Port already in use -```bash -# Kill existing processes -pkill -f "run-api" -pkill -f "vllm_api_server.py" -pkill -f "gsm8k_server.py" -``` - -### No batches available / trainer hangs -```bash -# Ensure the environment server is connected to the correct API and vLLM -# Check that vLLM is running and environment can reach it -curl http://localhost:9001/health -curl http://localhost:8000/info -``` - ---- - -## 📚 Module Documentation - -### `config.py` -Contains `TrainingConfig` - all training parameters as a Pydantic model. 
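+For intuition about what `training.py` computes, here is a condensed, self-contained sketch of the clipped GRPO objective described under Key Concepts above. Names and the sequence-level simplification are illustrative only (the real code works per token with masking and gradient accumulation); this is not the module's actual API:
+
+```python
+# Sketch: clipped GRPO loss for one group of G responses to the same prompt.
+import torch
+
+def grpo_loss(logp, logp_old, rewards, kl_coef=0.1, clip_eps=0.2):
+    # logp / logp_old: [G] sequence logprobs under current / sampling policy
+    adv = rewards - rewards.mean()        # group-relative advantage
+    ratio = torch.exp(logp - logp_old)    # importance sampling ratio
+    clipped = ratio.clamp(1 - clip_eps, 1 + clip_eps)
+    # Pessimistic surrogate: min for positive advantages, max for negative
+    surrogate = torch.where(
+        adv >= 0,
+        torch.min(ratio * adv, clipped * adv),
+        torch.max(ratio * adv, clipped * adv),
+    )
+    # Schulman's non-negative KL estimator: (pi_ref/pi) - log(pi_ref/pi) - 1
+    log_ratio = logp - logp_old
+    kl = torch.exp(-log_ratio) + log_ratio - 1
+    return (-surrogate + kl_coef * kl).mean()
+
+rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])           # two correct, two wrong
+logp_old = torch.tensor([-12.0, -15.0, -14.0, -11.0])
+loss = grpo_loss(logp_old.clone(), logp_old, rewards)  # step 0: ratio == 1, KL == 0
+```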
- -### `api.py` -- `check_atropos_api()` - Wait for run-api server -- `register_trainer()` - Register with Atropos -- `get_batch()` - Fetch training batch from run-api - -### `data.py` -- `pad_data_to_good_offset()` - Pad sequences to GPU-friendly lengths -- `get_data()` - Fetch and preprocess batches - -### `model.py` -- `load_model_and_tokenizer()` - Load model based on mode -- `_attach_to_vllm_shared_tensors()` - CUDA IPC attachment -- `_create_vllm_to_hf_mapping()` - Handle QKV/Gate-Up fusion - -### `training.py` -- `compute_grpo_loss()` - GRPO loss computation -- `run_training_step()` - Single step with gradient accumulation -- `log_metrics()` - Console and WandB logging -- `finalize_training()` - Cleanup and summary - -### `checkpointing.py` -- `save_checkpoint()` - Save full model -- `save_lora_checkpoint()` - Save LoRA adapter only - -### `vllm_manager.py` -- `launch_vllm_server()` - Start vLLM process -- `terminate_vllm_process()` - Stop vLLM -- `hotswap_lora_adapter()` - Hot-swap LoRA in vLLM - -### `trainers.py` -- `train_legacy()` - Checkpoint + restart mode -- `train_shared_vllm()` - Single-copy CUDA IPC mode -- `train_lora()` - Adapter training mode - -### `cli.py` -- `parse_args()` - Argparse setup -- `config_from_args()` - Convert args to TrainingConfig - ---- - -## 📝 License - -MIT License diff --git a/example_trainer/training.py b/example_trainer/training.py index 5ab7f2f6..3d6d74d1 100644 --- a/example_trainer/training.py +++ b/example_trainer/training.py @@ -12,7 +12,6 @@ import string import time from typing import Dict, List, Optional, Tuple -import numpy as np import torch import torch.nn.functional as F import wandb @@ -131,7 +130,7 @@ def compute_grpo_loss( # Move inference logprobs to correct device/dtype ref_logprobs = inference_logprobs.to(logp_per_token.device, logp_per_token.dtype) - # DEBUG: Check if inference logprobs look valid + # NOTE: inference_logprobs uses 1.0 for masked (prompt) positions, actual negative values for generated with torch.no_grad(): # Only look at generated positions (where mask == 1) @@ -146,7 +145,7 @@ def compute_grpo_loss( elif abs(ref_at_generated - train_at_generated) > 2.0: print(f" [DEBUG] Logprob gap (may be OK for first step): ref={ref_at_generated:.3f}, train={train_at_generated:.3f}") - # Compute importance sampling ratio: π(a|s) / π_old(a|s) = exp(log π - log π_old) + # Compute importance sampling ratio: policy(a|s) / policy_old(a|s) = exp(log policy - log policy_old) log_ratio = logp_per_token - ref_logprobs ratio = torch.exp(log_ratio) @@ -159,7 +158,7 @@ def compute_grpo_loss( # Pessimistic bound: min for positive advantages, max for negative # This is equivalent to: -min(ratio * A, clipped_ratio * A) when A > 0 - # -max(ratio * A, clipped_ratio * A) when A < 0 + # -max(ratio * A, clipped_ratio * A) when A < 0 policy_loss_per_token = -torch.where( adv_expanded >= 0, torch.min(surr1, surr2), @@ -171,14 +170,6 @@ def compute_grpo_loss( # KL penalty: encourage staying close to reference policy # Using Schulman's unbiased KL estimator from the DeepSeek GRPO paper (Equation 4): - # D_KL(π_θ || π_ref) = (π_ref / π_θ) - log(π_ref / π_θ) - 1 - # - # In terms of log probabilities: - # log_ratio = log π_θ - log π_ref (what we computed above) - # ratio_ref_over_pi = exp(-log_ratio) = π_ref / π_θ - # kl = ratio_ref_over_pi - log(ratio_ref_over_pi) - 1 - # = exp(-log_ratio) + log_ratio - 1 - # # This estimator is guaranteed to be non-negative (unlike squared log-ratio). 
if kl_coef > 0: # Schulman's unbiased KL estimator: (π_ref/π) - log(π_ref/π) - 1 @@ -211,26 +202,18 @@ def compute_grpo_loss( ).view(labels.shape) training_logprobs_flat = raw_logp_per_token[mask.bool()].detach() else: - # Fallback: REINFORCE-style (no reference policy) - # This is what the original code did - NOT recommended! - print(" [WARNING] No reference logprobs - using REINFORCE (may cause reward hacking!)") - - # Simple policy gradient: -log(π) * A - policy_loss = ((-logp_per_token * mask * adv_expanded).sum(dim=-1) / mask_sum).mean() - total_loss = policy_loss / gradient_accumulation_steps - kl_penalty = torch.tensor(0.0, device=logp_per_token.device) - - with torch.no_grad(): - clipped_fraction = torch.tensor(0.0) - mean_ratio = torch.tensor(1.0) - mean_kl = torch.tensor(0.0) - raw_logp_per_token = -F.cross_entropy( - outputs.logits.view(-1, outputs.logits.size(-1)), - labels.view(-1), - reduction="none", - ignore_index=-100, - ).view(labels.shape) - training_logprobs_flat = raw_logp_per_token[mask.bool()].detach() + # Fail loudly + raise ValueError( + "GRPO requires inference_logprobs for importance sampling!\n" + "\n" + "This error means the environment isn't providing logprobs. To fix:\n" + " 1. Use --openai.server_type vllm (not 'openai')\n" + " 2. Ensure vLLM is returning logprobs in /generate response\n" + " 3. Check that gsm8k_server is configured correctly\n" + "\n" + "Without inference logprobs, training will cause reward hacking.\n" + "If you REALLY want vanilla REINFORCE (not recommended), set use_reference_logprobs=False" + ) # === Compute Additional Metrics === with torch.no_grad(): @@ -262,67 +245,6 @@ def compute_grpo_loss( return total_loss, metrics -def compute_logprob_alignment( - inference_logprobs: List[np.ndarray], - training_logprobs: List[torch.Tensor], - debug: bool = False, -) -> Dict[str, float]: - """ - Compute alignment stats between inference and training logprobs. - - At initialization (step 0), these should match closely if the model - weights are correctly shared between training and inference. 
- - Args: - inference_logprobs: Logprobs from vLLM inference (numpy arrays) - training_logprobs: Logprobs computed during training forward pass (PyTorch tensors, bfloat16 supported) - debug: If True, print detailed debugging info - - Returns: - Dict of alignment statistics - """ - if not inference_logprobs or not training_logprobs: - return {} - - # Process inference logprobs (numpy) - inf_flat = np.concatenate(inference_logprobs) - # Filter out placeholder values (1.0 or 0.0 used for prompt tokens) - inf_mask = (inf_flat != 1.0) & (inf_flat != 0.0) - inf_filtered = inf_flat[inf_mask] - - # Process training logprobs (PyTorch - supports bfloat16 natively) - train_flat = torch.cat(training_logprobs) - - if debug: - print(f" [DEBUG] Inference: {len(inf_flat)} total, {len(inf_filtered)} after filter") - print(f" [DEBUG] Training: {train_flat.numel()} logprobs") - if len(inf_filtered) > 0: - print(f" [DEBUG] Inf sample (first 5): {inf_filtered[:5]}") - if train_flat.numel() > 0: - print(f" [DEBUG] Train sample (first 5): {train_flat[:5].tolist()}") - - # Compute stats using PyTorch for training (keeps bfloat16 precision) - stats = {} - - if len(inf_filtered) > 0: - stats["logprobs/inference_mean"] = float(np.mean(inf_filtered)) - stats["logprobs/inference_std"] = float(np.std(inf_filtered)) - - if train_flat.numel() > 0: - # PyTorch operations - fully support bfloat16 - stats["logprobs/training_mean"] = train_flat.mean().item() - stats["logprobs/training_std"] = train_flat.std().item() - - # Compute diff (for tracking, not validation) - # NOTE: Per-token comparison is NOT reliable here because inference and training - # logprobs come from different batch orderings and can't be aligned token-by-token. - # The real-time test at startup is the proper alignment validation. - if "logprobs/inference_mean" in stats and "logprobs/training_mean" in stats: - stats["logprobs/diff"] = stats["logprobs/inference_mean"] - stats["logprobs/training_mean"] - - return stats - - def run_training_step( model: torch.nn.Module, optimizer: torch.optim.Optimizer,