A modular training framework for fine-tuning language models with **Group Relative Policy Optimization (GRPO)**, designed to work with the Atropos environment system.

## 📁 Module Structure

```
example_trainer/
├── grpo.py               # CLI entry point (dispatches to trainers)
├── run.py                # Unified launcher for shared_vllm mode
├── config.py             # TrainingConfig (all training parameters)
├── cli.py                # CLI argument parsing (single source of truth)
├── api.py                # Atropos API communication
├── data.py               # Data fetching & preprocessing
├── model.py              # Model loading & CUDA IPC shared memory
├── training.py           # GRPO loss computation & training step
├── checkpointing.py      # Save models & LoRA adapters
├── vllm_manager.py       # vLLM process management
├── trainers.py           # Training mode implementations
├── vllm_api_server.py    # Custom vLLM server with CUDA IPC support (streamlined for training)
├── vllm_patching/        # CUDA IPC weight-sharing & B200/Blackwell patches over standard vLLM
│   └── patched_gpu_runner.py
└── scripts/              # Helper scripts
    ├── run_comparison.sh
    ├── run_concurrent_tests.sh
    ├── test_lora_mode.sh
    └── test_single_copy_mode.sh
```

---

## 🔄 Full System Architecture

The Atropos training system consists of 4 components that must run together. At its core, the trainer runs the GRPO loop:

1. Generate multiple responses to the same prompt
2. Score each response (reward)
3. Compute advantage = reward - mean(rewards)
4. Train: increase the probability of above-average responses, decrease the probability of below-average responses

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                           ATROPOS TRAINING SYSTEM                           │
└─────────────────────────────────────────────────────────────────────────────┘

┌─────────────┐      ┌──────────────────┐      ┌─────────────────┐
│    vLLM     │◄────►│   Environment    │─────►│     run-api     │
│   Server    │      │  (gsm8k_server)  │      │   (Trajectory   │
│ (Inference) │      │  (Process Env)   │      │  Handler API)   │
└─────────────┘      └──────────────────┘      └────────┬────────┘
       ▲                                                │
       │                                                │
       │        ┌───────────────────────────────────────┘
       │        │
       │        ▼
       │  ┌─────────────┐
       └──│    GRPO     │
          │   Trainer   │
          │  (grpo.py)  │
          └─────────────┘

Data Flow:
1. run-api     : Central API that receives trajectories and serves batches
2. Environment : Generates prompts, calls vLLM, scores responses → sends to run-api
3. Trainer     : Fetches batches from run-api → trains model → updates weights
4. vLLM        : Serves inference for the environment (and receives weight updates)
```

### Key Concepts

| Concept | What It Means |
|---------|---------------|
| **Advantage** | How much better or worse than average a response was |
| **Importance Sampling** | Corrects for policy drift during training |
| **KL Penalty** | Prevents the model from changing too drastically from the base |
| **Clipping** | Limits update magnitude for stability |
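
The advantage computation (step 3 of the loop above) is simple enough to show directly. A minimal PyTorch sketch, assuming rewards are grouped per prompt (illustrative only; the real code lives in `training.py`):

```python
# Sketch of the GRPO advantage from the loop above:
# advantage = reward - mean(rewards), computed per prompt group.
# (Illustrative, not the trainer's exact code.)
import torch

def grpo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """rewards: [num_prompts, group_size], one row per prompt's response group."""
    return rewards - rewards.mean(dim=1, keepdim=True)

rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0]])  # 4 responses, 2 scored correct
print(grpo_advantages(rewards))  # tensor([[ 0.5000, -0.5000, -0.5000,  0.5000]])
```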

### Components Explained

| Component | Command | Port | Purpose |
|-----------|---------|------|---------|
| **run-api** | `run-api` | 8000 | Central trajectory handler API |
| **Environment** | `gsm8k_server.py serve` | (internal) | Generates rollouts, scores them |
| **vLLM** | `vllm_api_server.py` | 9001 | Model inference |
| **Trainer** | `grpo.py` | (client) | Fetches batches, trains model |

---

## 🎯 Three Training Modes

| Mode | Description | vLLM Setup | Memory | Best For |
|------|-------------|------------|--------|----------|
| **Legacy** (`none`) | Trainer manages vLLM, restarts with new checkpoints | Auto-managed | 2x model | Simple setup, different GPUs |
| **Shared vLLM** (`shared_vllm`) | Single-copy mode via CUDA IPC, no model duplication | External, `VLLM_ENABLE_SHARED_WEIGHTS=1` | 1x model | Same GPU, maximum efficiency |
| **LoRA** (`lora_only`) | Train adapters only, hot-swap in vLLM | External, `--enable-lora` | 1x + small adapter | Fast training, small checkpoints |

---
## 🚀 Quick Start

### Prerequisites

```bash
# Install dependencies (also listed in requirements.txt)
pip install torch transformers peft vllm wandb requests tenacity pydantic

# Set environment variables
export LOGDIR=/tmp/atropos_test
export MODEL=Qwen/Qwen2.5-3B-Instruct
mkdir -p $LOGDIR
```

### Startup Order

1. Start the API server (`run-api`)
2. Wait ~5s, then start vLLM
3. Wait for vLLM to finish loading (check: `curl http://localhost:9001/health`)
4. Start the environment server
5. Start the trainer

Once everything is running, the loop looks like this:

1. Environment generates prompts → calls vLLM → scores responses
2. Environment sends trajectories to run-api
3. Trainer fetches batches from run-api
4. Trainer updates model weights
5. (shared_vllm) vLLM sees updates immediately via CUDA IPC; (lora_only) the trainer pushes the adapter to vLLM periodically
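
Optionally, you can probe the components before starting the trainer. A minimal sketch (illustrative helper, not part of the trainer; the `/health` and `/info` endpoints and default ports are the ones used throughout this README):

```python
# Minimal readiness probe for run-api and vLLM before launching the trainer.
import requests

for name, url in [("run-api", "http://localhost:8000/info"),
                  ("vLLM", "http://localhost:9001/health")]:
    try:
        status = requests.get(url, timeout=5).status_code
        print(f"{name}: HTTP {status}")
    except requests.exceptions.ConnectionError:
        print(f"{name}: not reachable yet")
```

---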
## 📖 Detailed Usage for Each Mode

### Mode 1: Legacy (Checkpoint + Restart)

The simplest mode: the trainer manages vLLM internally, restarting it with each new checkpoint.

```bash
# Terminal 1: Start the central API server (handles trajectories)
run-api --port 8000

# Terminal 2: Start the environment server (generates rollouts)
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name $MODEL \
    --env.use_wandb=False \
    --openai.model_name $MODEL \
    --openai.base_url http://localhost:9001/v1 \
    --openai.server_type vllm

# Terminal 3: Run training (the trainer launches its own vLLM)
CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
    --model-name $MODEL \
    --weight-bridge-mode none \
    --vllm-port 9001 \
    --atropos-url http://localhost:8000 \
    --training-steps 20 \
    --batch-size 2 \
    --save-path $LOGDIR/checkpoints_legacy \
    --benchmark
```

### Mode 2: Shared vLLM (Single-Copy CUDA IPC)

Zero model duplication: the trainer and vLLM share the exact same GPU memory!

```bash
# Terminal 1: Start the central API server
run-api --port 8000

# Terminal 2: Start vLLM with shared weights enabled
# IMPORTANT: --enforce-eager is REQUIRED to disable CUDA graphs.
# Without it, weight updates won't be visible to inference!
VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \
CUDA_VISIBLE_DEVICES=0 python example_trainer/vllm_api_server.py \
    --model $MODEL \
    --port 9001 \
    --gpu-memory-utilization 0.45 \
    --enforce-eager

# Terminal 3: Start the environment server
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name $MODEL \
    --env.use_wandb=False \
    --openai.model_name $MODEL \
    --openai.base_url http://localhost:9001/v1 \
    --openai.server_type vllm

# Terminal 4: Run training (attaches to vLLM's tensors)
CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \
    --model-name $MODEL \
    --weight-bridge-mode shared_vllm \
    --vllm-port 9001 \
    --vllm-config-path $LOGDIR/vllm_bridge_config.json \
    --atropos-url http://localhost:8000 \
    --training-steps 20 \
    --batch-size 2 \
    --save-path $LOGDIR/checkpoints_shared \
    --benchmark
```
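
Once vLLM is up in shared mode, it exports the bridge config described in "The Config File" section below. A quick sanity check (an illustrative helper; the file layout follows that section):

```python
# Optional sanity check: confirm vLLM exported its CUDA IPC tensor map.
# (Illustrative; keys follow the vllm_bridge_config.json example later in this README.)
import json
import os

path = os.path.join(os.environ.get("LOGDIR", "/tmp/atropos_test"), "vllm_bridge_config.json")
with open(path) as f:
    cfg = json.load(f)
print(f"{len(cfg['param_mappings'])} shared tensors exported for {cfg['model']}")
```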

### Mode 3: LoRA (Adapter Training)

Fast training with hot-swappable adapters.

```bash
# Terminal 1: Start the central API server
run-api --port 8000

# Terminal 2: Start vLLM with LoRA support
CUDA_VISIBLE_DEVICES=0 python example_trainer/vllm_api_server.py \
    --model $MODEL \
    --port 9001 \
    --gpu-memory-utilization 0.45 \
    --enable-lora \
    --max-lora-rank 32 \
    --enforce-eager

# Terminal 3: Start the environment server
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name $MODEL \
    --env.use_wandb=False \
    --openai.model_name $MODEL \
    --openai.base_url http://localhost:9001/v1 \
    --openai.server_type vllm

# Terminal 4: Run LoRA training
CUDA_VISIBLE_DEVICES=1 python -m example_trainer.grpo \
    --model-name $MODEL \
    --weight-bridge-mode lora_only \
    --vllm-port 9001 \
    --atropos-url http://localhost:8000 \
    --lora-r 16 \
    --lora-alpha 32 \
    --training-steps 20 \
    --batch-size 2 \
    --save-path $LOGDIR/checkpoints_lora \
    --benchmark
```
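
Under the hood, the LoRA flags correspond to a standard `peft` adapter config. A minimal sketch of the mapping, assuming standard `peft` usage (the trainer's actual wiring lives in `model.py` / `trainers.py`):

```python
# How the CLI's LoRA flags map onto a peft LoraConfig (a sketch under the
# assumption of standard peft usage; not the trainer's exact code).
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
lora_cfg = LoraConfig(
    r=16,                                 # --lora-r
    lora_alpha=32,                        # --lora-alpha
    lora_dropout=0.05,                    # --lora-dropout
    target_modules=["q_proj", "v_proj"],  # --lora-target-modules
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # adapters are a tiny fraction of the full weights
```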

### Recommendation

**Start with `lora_only`**: it's the easiest to set up and debug. Move to `shared_vllm` for production training when you need maximum efficiency on single-GPU runs. Note that multi-GPU training is not supported.

---
## 🔬 Run All 3 Modes in Parallel (8-GPU Comparison)

Use this setup to compare training efficiency across all three modes on a single 8-GPU node.

### GPU & Port Allocation

| Mode | GPUs | vLLM Port | API Port | Env Port |
|------|------|-----------|----------|----------|
| Legacy | 0-1 | 9001 | 8001 | (internal) |
| Shared vLLM | 2-3 | 9002 | 8002 | (internal) |
| LoRA | 4-5 | 9003 | 8003 | (internal) |

### Quick Start: Use the Comparison Script

```bash
cd /path/to/atropos

# Run comparison with default 50 steps
./example_trainer/scripts/run_comparison.sh

# Or specify steps
./example_trainer/scripts/run_comparison.sh 100
```

### Manual Parallel Execution

If you prefer to run each mode manually in separate terminal sessions:

```bash
# Setup
export MODEL="Qwen/Qwen2.5-3B-Instruct"
export LOGDIR=/tmp/atropos_test
mkdir -p $LOGDIR

# =============================================
# LEGACY MODE (Terminals 1-3)
# =============================================

# Terminal 1: API server for legacy
run-api --port 8001

# Terminal 2: Environment server
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name $MODEL \
    --env.use_wandb=False \
    --env.rollout_server_url http://localhost:8001 \
    --openai.model_name $MODEL \
    --openai.base_url http://localhost:9001/v1 \
    --openai.server_type vllm \
    --slurm false

# Terminal 3: Trainer (manages its own vLLM)
CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \
    --model-name $MODEL \
    --weight-bridge-mode none \
    --vllm-port 9001 \
    --atropos-url http://localhost:8001 \
    --training-steps 50 \
    --save-path $LOGDIR/checkpoints_legacy \
    --benchmark

# =============================================
# SHARED VLLM MODE (Terminals 4-7)
# =============================================

# Terminal 4: API server for shared
run-api --port 8002

# Terminal 5: vLLM server with shared weights
# (--enforce-eager is required for shared weights; see Mode 2 above)
VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \
CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \
    --model $MODEL --port 9002 --gpu-memory-utilization 0.45 \
    --enforce-eager

# Terminal 6: Environment server
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name $MODEL \
    --env.use_wandb=False \
    --env.rollout_server_url http://localhost:8002 \
    --openai.model_name $MODEL \
    --openai.base_url http://localhost:9002/v1 \
    --openai.server_type vllm \
    --slurm false

# Terminal 7: Trainer (attaches to vLLM)
CUDA_VISIBLE_DEVICES=2 python -m example_trainer.grpo \
    --model-name $MODEL \
    --weight-bridge-mode shared_vllm \
    --vllm-port 9002 \
    --vllm-config-path $LOGDIR/vllm_bridge_config.json \
    --atropos-url http://localhost:8002 \
    --training-steps 50 \
    --save-path $LOGDIR/checkpoints_shared \
    --benchmark

# =============================================
# LORA MODE (Terminals 8-11)
# =============================================

# Terminal 8: API server for LoRA
run-api --port 8003

# Terminal 9: vLLM server with LoRA
CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \
    --model $MODEL --port 9003 --gpu-memory-utilization 0.45 \
    --enable-lora --max-lora-rank 32 --enforce-eager

# Terminal 10: Environment server
python -u environments/gsm8k_server.py serve \
    --env.tokenizer_name $MODEL \
    --env.use_wandb=False \
    --env.rollout_server_url http://localhost:8003 \
    --openai.model_name $MODEL \
    --openai.base_url http://localhost:9003/v1 \
    --openai.server_type vllm \
    --slurm false

# Terminal 11: Trainer
CUDA_VISIBLE_DEVICES=5 python -m example_trainer.grpo \
    --model-name $MODEL \
    --weight-bridge-mode lora_only \
    --vllm-port 9003 \
    --atropos-url http://localhost:8003 \
    --lora-r 16 --lora-alpha 32 \
    --training-steps 50 \
    --save-path $LOGDIR/checkpoints_lora \
    --benchmark
```

---
## 📊 Understanding the Benchmark Output
|
||||
|
||||
Each trainer outputs a benchmark summary at the end:
|
||||
|
||||
```
|
||||
======================================================================
|
||||
BENCHMARK SUMMARY (shared_vllm)
|
||||
======================================================================
|
||||
Total training time: 168.65s (2.81 min)
|
||||
Total steps: 50
|
||||
|
||||
TIMING BREAKDOWN:
|
||||
Avg step time: 11.95s
|
||||
Total step time: 59.76s
|
||||
Avg sync time: 0.00s (x0 syncs) <-- No syncs in shared mode!
|
||||
Total sync time: 0.00s
|
||||
Avg data fetch time: 10.90s
|
||||
Total data fetch time: 54.52s
|
||||
|
||||
MEMORY:
|
||||
Peak GPU memory: 31.44 GB
|
||||
Avg GPU memory: 18.88 GB
|
||||
======================================================================
|
||||
```
|
||||
|
||||
**Key metrics to compare:**
|
||||
|
||||
| Metric | Legacy | Shared vLLM | LoRA |
|
||||
|--------|--------|-------------|------|
|
||||
| **Sync time** | High (restart vLLM) | 0 (in-place update) | Low (adapter swap) |
|
||||
| **GPU memory** | 2x model | 1x model | 1x + adapter |
|
||||
| **Step time** | ~10-15s | ~10-15s | ~5-10s |
|
||||
| **Checkpoint size** | ~6GB | ~6GB | ~50MB |
|
||||
|
||||
---
|
||||
|
||||
## 🔗 Shared vLLM Mode (Advanced)

Single-copy mode shares GPU memory between vLLM and the trainer: zero model duplication!

### How It Works

```
┌─────────────────────────────────────────────────────────────────────┐
│                        SINGLE GPU (CUDA IPC)                        │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │             Model Weights (ONE copy in GPU memory)            │  │
│  │               (accessible via CUDA IPC handles)               │  │
│  └───────────────────────────────────────────────────────────────┘  │
│           ▲                               ▲                         │
│           │ Reads (inference)             │ Writes                  │
│  ┌────────┴────────┐          ┌───────────┴───────────┐             │
│  │   vLLM Worker   │          │    Trainer Process    │             │
│  │                 │          │   (attached via IPC)  │             │
│  └─────────────────┘          └───────────────────────┘             │
└─────────────────────────────────────────────────────────────────────┘
```

### Running Shared vLLM Mode

The terminal-by-terminal walkthrough is the same as Mode 2 in the detailed usage section above.

### Or Use the Unified Launcher

```bash
# Single command starts both vLLM and the trainer
VLLM_ENABLE_SHARED_WEIGHTS=1 python -m example_trainer.run \
    --model-name NousResearch/Hermes-3-Llama-3.1-8B \
    --atropos-url "http://localhost:8002" \
    --training-steps 30
```

---
## Best Practices & Lessons Learned
|
||||
|
||||
### 1. Always Use `--enforce-eager` with Shared Weights
|
||||
|
||||
**Why:** CUDA graphs "bake" weights at compile time. Without eager mode, vLLM won't see weight updates!
|
||||
|
||||
```bash
|
||||
# WRONG - weight updates won't be visible to inference
|
||||
python vllm_api_server.py --model $MODEL
|
||||
|
||||
# CORRECT - disables CUDA graphs
|
||||
python vllm_api_server.py --model $MODEL --enforce-eager
|
||||
```
|
||||
|
||||
### 2. Use `--openai.server_type vllm` for Training
|
||||
|
||||
The gsm8k environment needs logprobs for GRPO. Only `server_type=vllm` uses the `/generate` endpoint which returns logprobs.
|
||||
|
||||
```bash
|
||||
# CORRECT - gets logprobs for training
|
||||
--openai.server_type vllm
|
||||
|
||||
# WRONG for training - no logprobs
|
||||
--openai.server_type openai
|
||||
```
|
||||
|
||||
### 3. KL Coefficient and Clipping Are Essential
|
||||
|
||||
Without these, training will collapse (reward hacking):
|
||||
|
||||
```bash
|
||||
--kl-coef 0.1 # Prevents policy from drifting too far
|
||||
--clip-eps 0.2 # Limits update magnitude
|
||||
```
|
||||
|
||||
**Symptoms of missing KL/clipping:**
|
||||
- Accuracy drops dramatically (e.g., 59% → 7%)
|
||||
- Loss goes to very negative values
|
||||
- Model outputs become repetitive/degenerate
|
||||
|
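
For concreteness, here's a minimal sketch of how these two knobs enter a GRPO-style loss (illustrative only; the actual implementation is `compute_grpo_loss()` in `training.py`):

```python
# Illustrative only: how --clip-eps and --kl-coef enter the loss.
import torch

def grpo_loss(logp_new, logp_old, logp_ref, advantages, clip_eps=0.2, kl_coef=0.1):
    ratio = torch.exp(logp_new - logp_old)               # importance sampling ratio
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()  # PPO-style clipped surrogate
    kl_penalty = (logp_new - logp_ref).mean()            # crude estimate of drift from base
    return policy_loss + kl_coef * kl_penalty
```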

### 4. Memory Budgeting for Large Models

| Model Size | GPU Memory | Recommended Settings |
|------------|------------|----------------------|
| 8B | 80GB | `--gpu-memory-utilization 0.5` |
| 14B | 80GB | `--gpu-memory-utilization 0.45`, `--batch-size 2` |
| 24B | 192GB (B200) | `--gpu-memory-utilization 0.30`, `--optimizer adafactor` |

### 5. Start with Small Batch Sizes

```bash
# Start conservative, increase if no OOM
--batch-size 2 --gradient-accumulation-steps 8  # Effective batch = 16
```

---
## 🔀 Tensor Mapping (vLLM ↔ HuggingFace)

### The Problem

vLLM fuses certain layers for efficiency, but HuggingFace keeps them separate:

```
HuggingFace Model:              vLLM Model:
├── q_proj    [4096, 4096]      ├── qkv_proj     [6144, 4096]   ← FUSED!
├── k_proj    [1024, 4096]      │   (contains q, k, v concatenated)
├── v_proj    [1024, 4096]      │
│                               │
├── gate_proj [14336, 4096]     ├── gate_up_proj [28672, 4096]  ← FUSED!
├── up_proj   [14336, 4096]     │   (contains gate and up concatenated)
```

### How We Solve It

The trainer creates **views** into vLLM's fused tensors:

```python
# vLLM has:  qkv_proj.weight [6144, 4096]
# We need:   q_proj [4096, 4096], k_proj [1024, 4096], v_proj [1024, 4096]

# Get sizes from the model config
q_size = num_heads * head_dim     # e.g., 4096
k_size = num_kv_heads * head_dim  # e.g., 1024
v_size = num_kv_heads * head_dim  # e.g., 1024

# Create views (no copy!)
hf_model.q_proj.weight = vllm_qkv[0:4096, :]     # First chunk
hf_model.k_proj.weight = vllm_qkv[4096:5120, :]  # Second chunk
hf_model.v_proj.weight = vllm_qkv[5120:6144, :]  # Third chunk
```

### Key Insight: Views Share Memory

```python
# These point to the SAME GPU memory:
trainer_q_proj.data_ptr() == vllm_qkv_proj.data_ptr()  # True!

# So when the optimizer updates the trainer's weights:
optimizer.step()  # Updates trainer_q_proj

# vLLM sees the change immediately (same memory)!
```
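
This view-sharing behavior is easy to verify in isolation. A self-contained demonstration (CPU tensors here for simplicity; the same holds for the CUDA tensors used in practice):

```python
# Demonstrates that torch slicing creates views over the same storage,
# which is what makes the qkv_proj bridge above zero-copy.
import torch

qkv = torch.zeros(6144, 4096)
q = qkv[0:4096, :]                     # view into rows 0..4095, no copy
assert q.data_ptr() == qkv.data_ptr()  # same underlying memory

q.add_(1.0)              # simulate an optimizer update through the view...
print(qkv[0, 0].item())  # ...and the fused tensor sees it immediately: 1.0
```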

### The Config File

vLLM exports its tensor mappings to `vllm_bridge_config.json`:

```json
{
  "model": "NousResearch/Hermes-3-Llama-3.1-8B",
  "param_mappings": {
    "model.layers.0.self_attn.qkv_proj.weight": {
      "ipc_handle": "base64_encoded_cuda_ipc_handle",
      "shape": [6144, 4096],
      "dtype": "bfloat16"
    }
  }
}
```

---
## ❓ FAQ

### Q: Why isn't vLLM seeing my weight updates?

**A:** CUDA graphs are caching the old weights. Add `--enforce-eager`:

```bash
python vllm_api_server.py --model $MODEL --enforce-eager
```

### Q: How do I debug logprob alignment issues?

**A:** Look for log messages like:

```
[WARNING] ref_logprobs at generated positions avg 0.85 (should be negative!)
```

This means inference logprobs aren't being passed through correctly. Check that:

1. The environment uses `--openai.server_type vllm`
2. vLLM returns logprobs (check the `/generate` response)

### Q: Why does the vLLM v1 engine fail with CUDA fork errors?

**A:** vLLM v1 uses multiprocessing that conflicts with CUDA initialization, so we default to the v0 engine:

```python
# vllm_api_server.py automatically sets:
os.environ.setdefault("VLLM_USE_V1", "0")
```

---
## 🐛 Troubleshooting

### "Atropos API not reachable"

```bash
# Start the API server first
run-api --port 8000
```

### "404 Not Found" on /generate

You're using a vLLM server that doesn't expose `/generate`. Use the custom server:

```bash
python -m example_trainer.vllm_api_server ...        # Has /generate
# NOT: python -m vllm.entrypoints.openai.api_server  # Only has /v1/*
```

### "vLLM server not running" (LoRA mode)

LoRA mode requires an external vLLM server started with `--enable-lora`:

```bash
python example_trainer/vllm_api_server.py \
    --model $MODEL --port 9001 --enable-lora --enforce-eager
```

### "Could not find vllm_bridge_config.json" (Shared mode)

Make sure vLLM was started with `VLLM_ENABLE_SHARED_WEIGHTS=1` and `LOGDIR` set:

```bash
VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=/tmp/atropos python example_trainer/vllm_api_server.py ...
```

### "Cannot re-initialize CUDA in forked subprocess"

A vLLM v1 engine issue. The custom server disables v1 by default, but if you still see this:

```bash
VLLM_USE_V1=0 python -m example_trainer.vllm_api_server ...
```

### "LogProb Alignment: MISMATCH!" (shared_vllm mode)

If you see `[MISMATCH!]` in the logprob alignment output, inference and training are seeing different weights. This is usually caused by **CUDA graphs**.

**Symptom:** `inference_mean` stays constant while `training_mean` changes, and the `diff` increases over time.

**Fix:** Add `--enforce-eager` when starting vLLM:

```bash
VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \
python example_trainer/vllm_api_server.py \
    --model $MODEL --port 9001 --enforce-eager  # <-- REQUIRED!
```

**Why:** CUDA graphs "bake" model weights into compiled graphs at startup, so updates to the underlying tensors are not reflected in inference. `--enforce-eager` disables CUDA graphs, so vLLM reads from the shared tensors on every forward pass.

### "Triton compilation error" on B200/Blackwell GPUs

The patched vLLM server (`vllm_api_server.py`) applies B200 fixes automatically. If you're using standard vLLM, add `--enforce-eager`.

### OOM (Out of Memory)

Reduce memory usage:

```bash
--gpu-memory-utilization 0.4     # Less vLLM memory
--batch-size 2                   # Smaller batches
--gradient-accumulation-steps 8  # Compensate with accumulation
--seq-len 1024                   # Shorter sequences
--optimizer adafactor            # Uses less memory than AdamW
```

### "FlexibleArgumentParser" import error

A vLLM version incompatibility. The custom server handles this automatically; make sure you're using:

```bash
python -m example_trainer.vllm_api_server  # NOT direct vllm commands
```

### Port already in use

```bash
# Kill existing processes
pkill -f "run-api"
pkill -f "vllm_api_server.py"
pkill -f "gsm8k_server.py"
```

### Training is slow / no batches

1. Check vLLM is running: `curl http://localhost:9001/health`
2. Check the API is running: `curl http://localhost:8000/info`
3. Check the environment server is connected and generating rollouts

---
## 📊 Monitoring Training

### Key Metrics to Watch

| Metric | Healthy Range | Problem If... |
|--------|---------------|---------------|
| `mean_ratio` | 0.8 - 1.2 | Far from 1.0 = policy changed too much |
| `mean_kl` | 0.01 - 0.1 | > 0.5 = policy drifting |
| `clipped_fraction` | < 0.3 | > 0.5 = learning rate too high |
| `loss` | Gradually decreasing | Exploding or very negative |
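
These health metrics can be derived from a batch's new/old token logprobs. A minimal sketch (names follow the table above, not necessarily the trainer's internals; the trainer reports them via `log_metrics()` in `training.py`):

```python
# Sketch: deriving the table's health metrics from token logprobs (illustrative).
import torch

def health_metrics(logp_new, logp_old, clip_eps=0.2):
    ratio = torch.exp(logp_new - logp_old)
    clipped = (ratio < 1 - clip_eps) | (ratio > 1 + clip_eps)
    return {
        "mean_ratio": ratio.mean().item(),                  # healthy: ~0.8-1.2
        "mean_kl": (logp_old - logp_new).mean().item(),     # healthy: ~0.01-0.1
        "clipped_fraction": clipped.float().mean().item(),  # healthy: < 0.3
    }
```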

### WandB Logging

```bash
--use-wandb \
--wandb-project "my-grpo-training" \
--wandb-run-name "hermes-8b-gsm8k"
```

---
## 🛠 CLI Reference

```bash
python -m example_trainer.grpo --help
```

### Essential Arguments

| Argument | Default | Description |
|----------|---------|-------------|
| `--model-name` | (required) | HuggingFace model ID |
| `--weight-bridge-mode` | `none` | `none`, `shared_vllm`, or `lora_only` |
| `--training-steps` | 10 | Number of training steps |
| `--batch-size` | 2 | Micro-batch size |
| `--gradient-accumulation-steps` | 1 | Effective batch = batch × accum |
| `--atropos-url` | `http://localhost:8000` | Atropos API server URL |
| `--save-path` | `trained_model_checkpoints` | Checkpoint directory |
| `--benchmark` | false | Show timing stats |
| `--debug-loading` | false | Verbose model loading output |

### GRPO Hyperparameters

| Argument | Default | Description |
|----------|---------|-------------|
| `--kl-coef` | 0.1 | KL penalty strength (higher = more conservative) |
| `--clip-eps` | 0.2 | PPO clipping range [1-ε, 1+ε] |
| `--learning-rate` | 1e-6 | Learning rate |

### LoRA Arguments

| Argument | Default | Description |
|----------|---------|-------------|
| `--lora-r` | 16 | LoRA rank |
| `--lora-alpha` | 32 | LoRA alpha (scaling factor) |
| `--lora-dropout` | 0.05 | LoRA dropout |
| `--lora-target-modules` | `q_proj v_proj` | Modules to apply LoRA to |

### vLLM & Single-Copy Arguments

| Argument | Default | Description |
|----------|---------|-------------|
| `--vllm-port` | 9001 | vLLM server port |
| `--vllm-config-path` | auto-detect | Path to `vllm_bridge_config.json` (shared mode) |
| `--single-copy` | false | Enable CUDA IPC mode |
| `--gpu-memory-utilization` | 0.9 | vLLM GPU memory fraction |

---

## 📚 Module Documentation

### `config.py`
Contains `TrainingConfig` - all training parameters as a Pydantic model.

### `api.py`
- `check_atropos_api()` - Wait for the run-api server
- `register_trainer()` - Register with Atropos
- `get_batch()` - Fetch a training batch from run-api
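
Conceptually, these helpers poll run-api over HTTP. A hand-rolled sketch (the `/info` endpoint appears elsewhere in this README; treating `/batch` as the batch endpoint is an assumption here, not a documented contract; see `api.py` for the real code):

```python
# Conceptual sketch of api.py's polling loop (illustrative; endpoint names
# are assumptions, not a documented contract).
import time
import requests

def check_atropos_api(url="http://localhost:8000", retries=30):
    for _ in range(retries):
        try:
            requests.get(f"{url}/info", timeout=5)
            return True
        except requests.exceptions.ConnectionError:
            time.sleep(1)
    return False

def get_batch(url="http://localhost:8000"):
    while True:  # run-api returns a batch once enough rollouts are queued
        data = requests.get(f"{url}/batch", timeout=30).json()
        if data.get("batch") is not None:
            return data["batch"]
        time.sleep(1)
```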

### `data.py`
- `pad_data_to_good_offset()` - Pad sequences to GPU-friendly lengths
- `get_data()` - Fetch and preprocess batches

### `model.py`
- `load_model_and_tokenizer()` - Load the model based on mode
- `_attach_to_vllm_shared_tensors()` - CUDA IPC attachment
- `_create_vllm_to_hf_mapping()` - Handle QKV/Gate-Up fusion

### `training.py`
- `compute_grpo_loss()` - GRPO loss computation
- `run_training_step()` - Single step with gradient accumulation
- `log_metrics()` - Console and WandB logging
- `finalize_training()` - Cleanup and summary

### `checkpointing.py`
- `save_checkpoint()` - Save the full model
- `save_lora_checkpoint()` - Save the LoRA adapter only

### `vllm_manager.py`
- `launch_vllm_server()` - Start the vLLM process
- `terminate_vllm_process()` - Stop vLLM
- `hotswap_lora_adapter()` - Hot-swap a LoRA adapter in vLLM

### `trainers.py`
- `train_legacy()` - Checkpoint + restart mode
- `train_shared_vllm()` - Single-copy CUDA IPC mode
- `train_lora()` - Adapter training mode

### `cli.py`
- `parse_args()` - Argparse setup
- `config_from_args()` - Convert args to a `TrainingConfig`

---

## 📝 License

MIT License