diff --git a/example_trainer/README.md b/example_trainer/README.md index 46a36882..022c2680 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -1,1069 +1,485 @@ -# GRPO Example Trainer +# GRPO Trainer -This directory contains an example script (`grpo.py`) demonstrating how to integrate a custom training loop with the Atropos API for reinforcement learning using the GRPO (Group Relative Policy Optimization) algorithm. +A modular training framework for fine-tuning language models with **Group Relative Policy Optimization (GRPO)**, designed to work with the Atropos environment system. -## Training Modes +## πŸ“ Module Structure -The trainer supports three weight synchronization modes: - -| Mode | Description | Sync Latency | Best For | -|------|-------------|--------------|----------| -| **Legacy** (`none`) | Save checkpoints, restart vLLM | ~30-60 seconds | Simple setups, debugging | -| **Single-Copy** (`shared_vllm`) | Direct CUDA IPC - ONE model copy! | 0 ms | Production, memory efficiency | -| **LoRA** (`lora_only`) | Train adapters, hot-swap | ~1-5 seconds | Memory-constrained, fast iteration | +``` +example_trainer/ +β”œβ”€β”€ grpo.py # CLI entry point (dispatches to trainers) +β”œβ”€β”€ config.py # TrainingConfig dataclass +β”œβ”€β”€ api.py # Atropos API communication +β”œβ”€β”€ data.py # Data fetching & preprocessing +β”œβ”€β”€ model.py # Model loading & CUDA IPC shared memory +β”œβ”€β”€ training.py # Loss computation & training step +β”œβ”€β”€ checkpointing.py # Save models & LoRA adapters +β”œβ”€β”€ vllm_manager.py # vLLM process management +β”œβ”€β”€ trainers.py # Training mode implementations +β”œβ”€β”€ cli.py # CLI argument parsing +β”œβ”€β”€ vllm_api_server.py # Custom vLLM server with IPC support +β”œβ”€β”€ vllm_patching/ # B200/Blackwell GPU patches +β”‚ └── patched_gpu_runner.py +└── scripts/ # Helper scripts + β”œβ”€β”€ run_comparison.sh + β”œβ”€β”€ run_concurrent_tests.sh + β”œβ”€β”€ test_lora_mode.sh + └── test_single_copy_mode.sh +``` --- -## Model Compatibility +## πŸ”„ Full System Architecture -This training pipeline works with models that meet the following requirements: +The Atropos training system consists of 4 components that must run together: -### Required Compatibility +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ ATROPOS TRAINING SYSTEM β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -| Component | Requirement | -|-----------|-------------| -| **HuggingFace** | Must support `AutoModelForCausalLM` | -| **vLLM** | Must be in [vLLM's supported model list](https://docs.vllm.ai/en/latest/models/supported_models.html) | -| **Architecture** | Decoder-only (causal language model) | + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ vLLM │◄────►│ Environment │─────►│ run-api β”‚ + β”‚ Server β”‚ β”‚ (gsm8k_server) β”‚ β”‚ (Trajectory β”‚ + β”‚ (Inference)β”‚ β”‚ (Process Env) β”‚ β”‚ Handler API) β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ 
β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β–² β”‚ + β”‚ β”‚ + β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ β”‚ + β”‚ β–Ό + β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + └────────│ GRPO β”‚ + β”‚ Trainer β”‚ + β”‚ (grpo.py) β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -### βœ… Compatible Model Families +Data Flow: +1. run-api : Central API that receives trajectories and serves batches +2. Environment : Generates prompts, calls vLLM, scores responses β†’ sends to run-api +3. Trainer : Fetches batches from run-api β†’ trains model β†’ updates weights +4. vLLM : Serves inference for environment (and gets weight updates) +``` -- **Qwen** (Qwen2, Qwen2.5) -- **Llama** (Llama-2, Llama-3, Llama-3.1) -- **Mistral** (Mistral, Mixtral) -- **Phi** (Phi-2, Phi-3) -- **Gemma** (Gemma, Gemma-2) -- **DeepSeek** (DeepSeek-Coder, DeepSeek-V2) -- **Yi** (Yi, Yi-1.5) -- **StarCoder** (StarCoder2) +### Components Explained -### ❌ Not Compatible - -| Type | Reason | -|------|--------| -| Encoder-only (BERT, RoBERTa) | No causal language modeling head | -| Encoder-decoder (T5, BART) | Different architecture, not supported by vLLM | -| Non-HuggingFace models | Requires `AutoModelForCausalLM.from_pretrained()` | - -### Single-Copy Mode Constraints - -| Constraint | Reason | -|------------|--------| -| `tensor-parallel-size` must be 1 | Multi-GPU tensor parallelism not yet supported for IPC | -| Model must fit on single GPU | No model sharding in single-copy mode | -| Trainer and vLLM on same GPU(s) | CUDA IPC requires same device | - -> **Tip**: For models too large for a single GPU, use **LoRA mode** (`--weight-bridge-mode lora_only`) instead. +| Component | Command | Port | Purpose | +|-----------|---------|------|---------| +| **run-api** | `run-api` | 8000 | Central trajectory handler API | +| **Environment** | `gsm8k_server.py serve` | (internal) | Generates rollouts, scores them | +| **vLLM** | `vllm_api_server.py` | 9001 | Model inference | +| **Trainer** | `grpo.py` | (client) | Fetches batches, trains model | --- -## Quick Start with GSM8k (Single-Copy Mode) +## 🎯 Three Training Modes -This is the **recommended** production setup for maximum training throughput and memory efficiency. +| Mode | Description | vLLM Setup | Best For | +|------|-------------|------------|----------| +| **Legacy** (`none`) | Trainer manages vLLM, restarts with new checkpoints | Auto-managed | Simple setup, different GPUs | +| **Shared vLLM** (`shared_vllm`) | Single-copy mode via CUDA IPC - no model duplication! 
| External, `VLLM_ENABLE_SHARED_WEIGHTS=1` | Same GPU, max efficiency | +| **LoRA** (`lora_only`) | Train adapters only, hot-swap in vLLM | External, `--enable-lora` | Fast training, small checkpoints | + +--- + +## πŸš€ Quick Start ### Prerequisites ```bash # Install dependencies -pip install -r example_trainer/requirements.txt +pip install torch transformers peft vllm wandb requests tenacity pydantic -# Install GSM8k environment dependencies -pip install datasets latex2sympy2_extended math_verify -``` - -### Architecture Overview - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ SINGLE-COPY TRAINING ARCHITECTURE β”‚ -β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ GSM8k Env │───▢│ Atropos API │◀───│ GRPO Trainer β”‚ β”‚ -β”‚ β”‚ (problems) β”‚ β”‚ (batching) β”‚ β”‚ - Attached to vLLM's tensors β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ - optimizer.step() updates both β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ CUDA IPC β”‚ -β”‚ β”‚ β”‚ (same memory!) β”‚ -β”‚ β–Ό β–Ό β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ vLLM Inference Server (GPU 0) β”‚ β”‚ -β”‚ β”‚ - Model weights in GPU memory β”‚ β”‚ -β”‚ β”‚ - Trainer sees same tensors via IPC β”‚ β”‚ -β”‚ β”‚ - Generates rollouts for scoring β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -### How Single-Copy Mode Works - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ SAME GPU(s) β”‚ -β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ SHARED MODEL TENSORS β”‚ β”‚ -β”‚ β”‚ (only ONE copy in GPU memory!) 
β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β–² β–² β”‚ -β”‚ β”‚ Reads/Writes β”‚ Reads β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ Trainer β”‚ β”‚ vLLM β”‚ β”‚ -β”‚ β”‚ (gradients) β”‚ β”‚ (inference) β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ -β”‚ β”‚ optimizer.step() β”‚ -β”‚ β”‚ (updates shared tensors in-place) β”‚ -β”‚ β–Ό β”‚ -β”‚ vLLM immediately sees new weights! β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -- **Memory**: 1x model size (truly shared via CUDA IPC!) -- **Sync Latency**: 0ms (same memory, no copy needed) -- **Requirement**: Trainer and vLLM on SAME GPU(s) - ---- - -### Step-by-Step Guide - -**IMPORTANT: GPU Allocation** -- vLLM and Trainer run on the SAME GPU(s) -- Use `tensor-parallel-size 1` for single-copy mode (TP>1 not yet supported) - ---- - -#### Step 1: Kill Any Existing Processes - -```bash -pkill -9 -u $USER -f "vllm|grpo|python|run-api" 2>/dev/null; sleep 3 -``` - -#### Step 2: Setup Directory - -```bash -cd ~/atropos_stuff/atropos -rm -f vllm_bridge_config.json vllm.log trainer.log api.log gsm8k.log -``` - -#### Step 3: Set Environment Variables - -```bash -export VLLM_ENABLE_SHARED_WEIGHTS=1 -export NUM_INFERENCE_NODES=0 -export LOGDIR=. -``` - -#### Step 4: Start vLLM Server - -```bash -CUDA_VISIBLE_DEVICES=0 python -u example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-14B-Instruct \ - --tensor-parallel-size 1 \ - --port 9001 \ - > vllm.log 2>&1 & -echo "vLLM starting on GPU 0..." -``` - -#### Step 5: Wait for vLLM to Load - -```bash -tail -f vllm.log -``` - -Wait until you see: `Uvicorn running on http://0.0.0.0:9001` - -Then press **Ctrl+C** to stop tailing. - -#### Step 6: Verify IPC Handles Exported - -```bash -grep -E "IPC|Exported|single_copy" vllm.log -``` - -You should see: -``` -[vLLM Patch] Exported X IPC handles for single-copy mode -[vLLM Patch] βœ“ Exported 339 params to vllm_bridge_config.json -``` - -#### Step 7: Start an Environment (GSM8K in this case) - -```bash -python environments/gsm8k_server.py serve \ - --slurm False \ - --openai.model_name Qwen/Qwen2.5-14B-Instruct \ - --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm \ - --openai.api_key x \ - --env.tokenizer_name Qwen/Qwen2.5-14B-Instruct \ - --env.use_wandb False \ - > gsm8k.log 2>&1 & -echo "GSM8K environment started" -sleep 10 -``` - -#### Step 8: Start Trainer (Same GPU as vLLM!) - -```bash -CUDA_VISIBLE_DEVICES=0 LOGDIR=. python -u example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-14B-Instruct \ - --weight-bridge-mode shared_vllm \ - --training-steps 100 \ - 2>&1 | tee trainer.log -``` - -#### Step 9: Monitor Training - -```bash -tail -f trainer.log -``` - -You should see: -``` -[Setup] βœ“ Attached 195 tensors to vLLM's shared memory -[Setup] βœ“ Single-copy mode active - using vLLM's tensors directly! 
-[2/2] Starting training for 100 steps -Step 1/100 - [SINGLE-COPY] Weights updated in-place - step 1 -``` - ---- - -### Quick Copy-Paste (All-in-One) - -```bash -# Kill everything and setup -pkill -9 -u $USER -f "vllm|grpo|python" 2>/dev/null; sleep 3 -cd ~/atropos_stuff/atropos -rm -f vllm_bridge_config.json *.log - -# Environment variables -export VLLM_ENABLE_SHARED_WEIGHTS=1 NUM_INFERENCE_NODES=0 LOGDIR=. - -# Start vLLM -CUDA_VISIBLE_DEVICES=0 python -u example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-14B-Instruct --tensor-parallel-size 1 --port 9001 > vllm.log 2>&1 & -echo "Waiting 90s for vLLM..."; sleep 90 - -# Start GSM8k environment -python environments/gsm8k_server.py serve --slurm False \ - --openai.model_name Qwen/Qwen2.5-14B-Instruct \ - --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm --openai.api_key x \ - --env.tokenizer_name Qwen/Qwen2.5-14B-Instruct \ - --env.use_wandb False > gsm8k.log 2>&1 & -sleep 10 - -# Start trainer (same GPU!) -CUDA_VISIBLE_DEVICES=0 LOGDIR=. python -u example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-14B-Instruct \ - --weight-bridge-mode shared_vllm \ - --training-steps 100 \ - 2>&1 | tee trainer.log -``` - ---- - -## How Each Mode Works (Data Flow Diagrams) - -### Single-Copy Mode (`--weight-bridge-mode shared_vllm`) ⭐ RECOMMENDED - -**The Magic**: Trainer and vLLM share the EXACT SAME GPU memory via CUDA IPC. - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ SINGLE-COPY MODE - COMPLETE DATA FLOW β”‚ -β”‚ β”‚ -β”‚ STEP 1: GSM8k sends problem β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ GSM8k Server │──── "What is 15 Γ— 7?" β”€β”€β”€β”€β–Άβ”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ (Environment) β”‚ β”‚ Atropos API β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ (Batching) β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ -β”‚ STEP 2: Atropos forwards to vLLM β”‚ β”‚ -β”‚ β–Ό β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ GPU MEMORY β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ MODEL WEIGHTS (ONE COPY - SHARED!) 
β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ embed_tokens.weight, layers.*.qkv_proj, ..., lm_head.weight β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ (address: 0x7f8a12340000) β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β”‚ β–² β–² β”‚ β”‚ -β”‚ β”‚ β”‚ STEP 3: READ β”‚ STEP 6: WRITE β”‚ β”‚ -β”‚ β”‚ β”‚ (generate tokens) β”‚ (optimizer.step) β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ vLLM Server β”‚ β”‚ Trainer β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ (grpo.py) β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ Generates: β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ "15 Γ— 7 = 105" β”‚ β”‚ STEP 5: Compute β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ GRPO loss & β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ gradients β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β–²β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ STEP 4: Return completion β”‚ β”‚ -β”‚ β–Ό β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ GSM8k Server β”‚β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ (Scoring) β”‚ β”‚ -β”‚ β”‚ β”‚ Scores: "15 Γ— 7 = 105" βœ“ reward=1.0 β”‚ -β”‚ β”‚ β”‚ "15 Γ— 7 = 100" βœ— reward=0.0 β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ -β”‚ STEP 7: IMMEDIATE UPDATE β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ After optimizer.step(), vLLM's NEXT inference uses the NEW weights! β”‚ β”‚ -β”‚ β”‚ NO SYNC NEEDED - it's the same memory! β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -**Key Points:** -- βœ… ONE copy of weights in GPU memory -- βœ… 0ms sync latency (same memory!) -- βœ… Memory efficient (~1x model size) -- ⚠️ Requires same GPU for trainer and vLLM - ---- - -### LoRA Mode (`--weight-bridge-mode lora_only`) - -**The Idea**: Freeze base model, only train small adapter layers. Hot-swap adapters into vLLM. 
- -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ LORA MODE - COMPLETE DATA FLOW β”‚ -β”‚ β”‚ -β”‚ STEP 1: GSM8k sends problem β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ GSM8k Server │──── "What is 15 Γ— 7?" β”€β”€β”€β”€β–Άβ”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ (Environment) β”‚ β”‚ Atropos API β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ -β”‚ STEP 2: Forward to vLLM β–Ό β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ vLLM GPU MEMORY β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ BASE MODEL (frozen, ~6GB) β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ + LORA ADAPTER A (current, ~50MB) β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ STEP 3: Inference with base + adapter A β”‚ β”‚ -β”‚ β”‚ β–Ό β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ vLLM Server β”‚ ──── "15 Γ— 7 = 105" ────▢ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ TRAINER GPU MEMORY (separate!) β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ BASE MODEL (frozen, ~6GB) β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ + LORA ADAPTER B (training, ~50MB) ◀── gradients flow here only! 
β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ STEP 4-5: Receive rollout, compute loss, update adapter B β”‚ β”‚ -β”‚ β”‚ β–Ό β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ Trainer β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ (grpo.py) β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ -β”‚ β”‚ STEP 6: Every N steps, save adapter B to disk β”‚ -β”‚ β–Ό β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” STEP 7: POST /lora/load β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ adapter_step_N/ β”‚ ─────────────────────────────────▢│ vLLM Server β”‚ β”‚ -β”‚ β”‚ (50MB on disk) β”‚ β”‚ Swaps A β†’ B β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ -β”‚ STEP 8: Next inference uses NEW adapter B β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ Sync latency: 1-5 seconds (save to disk + HTTP load) β”‚ β”‚ -β”‚ β”‚ Memory: 2x base model + adapters β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - -**Key Points:** -- βœ… Small adapter files (~50MB vs ~28GB) -- βœ… Works on separate GPUs -- βœ… Easy to switch between adapters -- ⚠️ 1-5 second sync latency -- ⚠️ 2x base model memory (trainer + vLLM) - ---- - -### Legacy Mode (`--weight-bridge-mode none`) - -**The Simple Approach**: Save full checkpoints, restart vLLM to load new weights. - -> **Note**: In Legacy mode, the **trainer manages its own vLLM process**. Do NOT start vLLM separately - the trainer will automatically start, stop, and restart vLLM with updated checkpoints. - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ LEGACY MODE - COMPLETE DATA FLOW β”‚ -β”‚ β”‚ -β”‚ STEP 1: GSM8k sends problem β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ GSM8k Server │──── "What is 15 Γ— 7?" 
β”€β”€β”€β”€β–Άβ”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ (Environment) β”‚ β”‚ Atropos API β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ -β”‚ STEP 2: Forward to vLLM β–Ό β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ vLLM GPU MEMORY β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ FULL MODEL - Version 1 (~28GB) β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ STEP 3: Inference β”‚ β”‚ -β”‚ β”‚ β–Ό β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ vLLM Server β”‚ ──── "15 Γ— 7 = 105" ────▢ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ TRAINER GPU MEMORY (separate!) β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ FULL MODEL - Version 2 (~28GB + gradients + optimizer) β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β”‚ β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ STEP 4-5: Receive rollout, compute loss, update weights β”‚ β”‚ -β”‚ β”‚ β–Ό β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ Trainer β”‚ β”‚ β”‚ -β”‚ β”‚ β”‚ (grpo.py) β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ -β”‚ β”‚ STEP 6: Every N steps, save FULL checkpoint to disk (~28GB) β”‚ -β”‚ β–Ό β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ checkpoint/ β”‚ β”‚ -β”‚ β”‚ step_N/ β”‚ (28GB on disk!) 
β”‚ -β”‚ β”‚ - model.safetensors β”‚ -β”‚ β”‚ - config.json β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ β”‚ -β”‚ β”‚ STEP 7: RESTART vLLM with new checkpoint β”‚ -β”‚ β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ β”‚ 1. Kill vLLM process β”‚ β”‚ -β”‚ β”‚ β”‚ 2. Start new vLLM with --model checkpoint/step_N/ β”‚ β”‚ -β”‚ β”‚ β”‚ 3. Wait for model to load (~30-60 seconds) β”‚ β”‚ -β”‚ β”‚ β”‚ 4. Resume training β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β–Ό β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ vLLM GPU MEMORY (restarted) β”‚ β”‚ -β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚ -β”‚ β”‚ β”‚ FULL MODEL - Version 2 (loaded from checkpoint) β”‚ β”‚ β”‚ -β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β”‚ β”‚ -β”‚ STEP 8: Next inference uses updated model β”‚ -β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ -β”‚ β”‚ Sync latency: 30-60 seconds (save + restart + reload) β”‚ β”‚ -β”‚ β”‚ Memory: 2x full model β”‚ β”‚ -β”‚ β”‚ Disk: 28GB per checkpoint β”‚ β”‚ -β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` ---- - -## Mode Comparison Summary - -``` -β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” -β”‚ MODE COMPARISON AT A GLANCE β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ β”‚ SINGLE-COPY 
β”‚ LORA β”‚ LEGACY β”‚ -β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ -β”‚ Sync Latency β”‚ 0 ms ⚑ β”‚ 1-5 sec β”‚ 30-60 sec β”‚ -β”‚ GPU Memory β”‚ 1x model β”‚ 2x model β”‚ 2x model β”‚ -β”‚ Disk Space β”‚ 28GB/ckpt β”‚ 50MB/adapter β”‚ 28GB/ckpt β”‚ -β”‚ Complexity β”‚ Medium β”‚ Medium β”‚ Simple β”‚ -β”‚ Same GPU? β”‚ Required β”‚ Optional β”‚ Optional β”‚ -β”‚ Best For β”‚ Production β”‚ Experiments β”‚ Debugging β”‚ -β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ -``` - ---- - -## Alternative Mode Commands - -### Legacy Mode (Checkpoint + Restart) - -For simple setups or debugging. Saves checkpoints and restarts vLLM to load new weights. - -**IMPORTANT**: In Legacy mode, the **trainer manages its own vLLM process**. Do NOT start vLLM separately - the trainer will start, stop, and restart vLLM automatically as needed. - -```bash -# Step 1: Set environment +# Set environment variables export LOGDIR=/tmp/atropos_test +export MODEL=Qwen/Qwen2.5-3B-Instruct mkdir -p $LOGDIR +``` -# Step 2: Kill any existing processes -pkill -f "vllm_api_server" || true -pkill -f "gsm8k_server" || true -sleep 2 +--- -# Step 3: Start GSM8k environment (pointing to port 9001 where trainer will launch vLLM) -LOGDIR=$LOGDIR python -u environments/gsm8k_server.py serve \ - --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \ - --env.use_wandb false \ - --openai.model_name Qwen/Qwen2.5-3B-Instruct \ - --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm \ - > $LOGDIR/gsm8k_legacy.log 2>&1 & -sleep 5 +## πŸ“– Detailed Usage for Each Mode -# Step 4: Start trainer (it will launch vLLM automatically!) -CUDA_VISIBLE_DEVICES=0 python -u example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ +### Mode 1: Legacy (Checkpoint + Restart) + +The simplest mode. Trainer manages vLLM internally. + +```bash +# Terminal 1: Start the central API server (handles trajectories) +run-api --port 8000 + +# Terminal 2: Start the environment server (generates rollouts) +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ + --openai.base_url http://localhost:9001/v1 + +# Terminal 3: Run training (trainer will launch its own vLLM) +CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \ + --model-name $MODEL \ --weight-bridge-mode none \ --vllm-port 9001 \ - --training-steps 100 \ - --vllm-restart-interval 10 \ + --atropos-url http://localhost:8000 \ + --training-steps 20 \ --batch-size 2 \ - --lr 1e-5 \ --save-path $LOGDIR/checkpoints_legacy \ - --benchmark \ - 2>&1 | tee $LOGDIR/trainer_legacy.log + --benchmark ``` -**What happens:** -1. Trainer starts its own vLLM process on port 9001 -2. Training proceeds, accumulating weight updates -3. Every `--vllm-restart-interval` steps, trainer: - - Saves a checkpoint to disk - - Kills the current vLLM process - - Starts a new vLLM process with the updated checkpoint -4. This continues until training completes +### Mode 2: Shared vLLM (Single-Copy CUDA IPC) -### LoRA Mode (Adapter Training) - -Trains only adapter weights. Small checkpoints, lower memory. Requires vLLM to be started separately with `--enable-lora`. 
+Zero model duplication - trainer and vLLM share the exact same GPU memory! ```bash -# Step 1: Set environment +# Terminal 1: Start the central API server +run-api --port 8000 + +# Terminal 2: Start vLLM with shared weights enabled +VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \ +CUDA_VISIBLE_DEVICES=0 python example_trainer/vllm_api_server.py \ + --model $MODEL \ + --port 9001 \ + --gpu-memory-utilization 0.45 + +# Terminal 3: Start the environment server +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ + --openai.base_url http://localhost:9001/v1 + +# Terminal 4: Run training (attaches to vLLM's tensors) +CUDA_VISIBLE_DEVICES=0 python -m example_trainer.grpo \ + --model-name $MODEL \ + --weight-bridge-mode shared_vllm \ + --vllm-port 9001 \ + --vllm-config-path $LOGDIR/vllm_bridge_config.json \ + --atropos-url http://localhost:8000 \ + --training-steps 20 \ + --batch-size 2 \ + --save-path $LOGDIR/checkpoints_shared \ + --benchmark +``` + +### Mode 3: LoRA (Adapter Training) + +Fast training with hot-swappable adapters. + +```bash +# Terminal 1: Start the central API server +run-api --port 8000 + +# Terminal 2: Start vLLM with LoRA support +CUDA_VISIBLE_DEVICES=0 python example_trainer/vllm_api_server.py \ + --model $MODEL \ + --port 9001 \ + --gpu-memory-utilization 0.45 \ + --enable-lora \ + --max-lora-rank 32 \ + --enforce-eager + +# Terminal 3: Start the environment server +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ + --openai.base_url http://localhost:9001/v1 + +# Terminal 4: Run LoRA training +CUDA_VISIBLE_DEVICES=1 python -m example_trainer.grpo \ + --model-name $MODEL \ + --weight-bridge-mode lora_only \ + --vllm-port 9001 \ + --atropos-url http://localhost:8000 \ + --lora-r 16 \ + --lora-alpha 32 \ + --training-steps 20 \ + --batch-size 2 \ + --save-path $LOGDIR/checkpoints_lora \ + --benchmark +``` + +--- + +## πŸ”¬ Run All 3 Modes in Parallel (8-GPU Comparison) + +Use this setup to compare training efficiency across all three modes on a single 8-GPU node. 
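+
+Before launching, it can help to clear stale processes and, once the stacks are up, confirm that every server responds. A minimal pre-flight sketch, reusing the cleanup and health-check commands from the Troubleshooting section and the ports from the allocation table below:
+
+```bash
+# Stop leftovers from previous runs (same commands as Troubleshooting)
+pkill -f "run-api"
+pkill -f "vllm_api_server.py"
+pkill -f "gsm8k_server.py"
+
+# After starting each stack, verify that every endpoint responds
+for port in 9001 9002 9003; do
+  curl -s http://localhost:$port/health && echo "vLLM on :$port is up"
+done
+for port in 8001 8002 8003; do
+  curl -s http://localhost:$port/info > /dev/null && echo "run-api on :$port is up"
+done
+```
+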
+ +### GPU & Port Allocation + +| Mode | GPUs | vLLM Port | API Port | Env Port | +|------|------|-----------|----------|----------| +| Legacy | 0-1 | 9001 | 8001 | (internal) | +| Shared vLLM | 2-3 | 9002 | 8002 | (internal) | +| LoRA | 4-5 | 9003 | 8003 | (internal) | + +### Quick Start: Use the Comparison Script + +```bash +cd /path/to/atropos + +# Run comparison with default 50 steps +./example_trainer/scripts/run_comparison.sh + +# Or specify steps +./example_trainer/scripts/run_comparison.sh 100 +``` + +### Manual Parallel Execution + +If you prefer to run each mode manually in separate terminal sessions: + +```bash +# Setup +export MODEL="Qwen/Qwen2.5-3B-Instruct" export LOGDIR=/tmp/atropos_test mkdir -p $LOGDIR -# Step 2: Kill any existing processes -pkill -f "vllm_api_server" || true -pkill -f "gsm8k_server" || true -sleep 2 +# ============================================= +# LEGACY MODE (Terminals 1-3) +# ============================================= -# Step 3: Start vLLM with LoRA support (use --enforce-eager to avoid Triton issues) -LOGDIR=$LOGDIR python -u example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-3B-Instruct \ - --port 9001 \ - --dtype bfloat16 \ - --gpu-memory-utilization 0.4 \ - --enable-lora \ - --max-lora-rank 32 \ - --enforce-eager \ - > $LOGDIR/vllm_lora.log 2>&1 & -echo "Waiting 60s for vLLM..."; sleep 60 +# Terminal 1: API server for legacy +run-api --port 8001 -# Verify vLLM is ready -curl -s http://localhost:9001/health && echo " vLLM is ready!" - -# Step 4: Start GSM8k environment -LOGDIR=$LOGDIR python -u environments/gsm8k_server.py serve \ - --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \ - --env.use_wandb false \ - --openai.model_name Qwen/Qwen2.5-3B-Instruct \ +# Terminal 2: Environment server +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm \ - > $LOGDIR/gsm8k_lora.log 2>&1 & -sleep 10 + --server.port 8001 -# Step 5: Start trainer with LoRA (can use different GPU) -CUDA_VISIBLE_DEVICES=1 python -u example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ - --weight-bridge-mode lora_only \ +# Terminal 3: Trainer (manages its own vLLM) +CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \ + --model-name $MODEL \ + --weight-bridge-mode none \ --vllm-port 9001 \ - --lora-r 16 \ - --lora-alpha 32 \ - --training-steps 100 \ - --batch-size 2 \ - --lr 1e-4 \ - --save-path $LOGDIR/checkpoints_lora \ - --benchmark \ - 2>&1 | tee $LOGDIR/trainer_lora.log -``` + --atropos-url http://localhost:8001 \ + --training-steps 50 \ + --save-path $LOGDIR/checkpoints_legacy \ + --benchmark -**What happens:** -1. vLLM runs with LoRA support enabled -2. Trainer loads base model + creates LoRA adapters -3. After each sync interval, trainer: - - Saves small LoRA adapter (~50MB) - - Hot-swaps adapter to vLLM via `/lora/load` endpoint -4. 
vLLM uses new adapter for next inference batch +# ============================================= +# SHARED VLLM MODE (Terminals 4-7) +# ============================================= + +# Terminal 4: API server for shared +run-api --port 8002 + +# Terminal 5: vLLM server with shared weights +VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \ +CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \ + --model $MODEL --port 9002 --gpu-memory-utilization 0.45 + +# Terminal 6: Environment server +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ + --openai.base_url http://localhost:9002/v1 \ + --server.port 8002 + +# Terminal 7: Trainer (attaches to vLLM) +CUDA_VISIBLE_DEVICES=2 python -m example_trainer.grpo \ + --model-name $MODEL \ + --weight-bridge-mode shared_vllm \ + --vllm-port 9002 \ + --vllm-config-path $LOGDIR/vllm_bridge_config.json \ + --atropos-url http://localhost:8002 \ + --training-steps 50 \ + --save-path $LOGDIR/checkpoints_shared \ + --benchmark + +# ============================================= +# LORA MODE (Terminals 8-11) +# ============================================= + +# Terminal 8: API server for LoRA +run-api --port 8003 + +# Terminal 9: vLLM server with LoRA +CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \ + --model $MODEL --port 9003 --gpu-memory-utilization 0.45 \ + --enable-lora --max-lora-rank 32 --enforce-eager + +# Terminal 10: Environment server +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ + --openai.base_url http://localhost:9003/v1 \ + --server.port 8003 + +# Terminal 11: Trainer +CUDA_VISIBLE_DEVICES=5 python -m example_trainer.grpo \ + --model-name $MODEL \ + --weight-bridge-mode lora_only \ + --vllm-port 9003 \ + --atropos-url http://localhost:8003 \ + --lora-r 16 --lora-alpha 32 \ + --training-steps 50 \ + --save-path $LOGDIR/checkpoints_lora \ + --benchmark +``` --- -## Configuration Reference +## πŸ“Š Understanding the Benchmark Output -### Environment Variables +Each trainer outputs a benchmark summary at the end: -| Variable | Required | Description | Example | -|----------|----------|-------------|---------| -| `VLLM_ENABLE_SHARED_WEIGHTS` | Yes (single-copy) | Enable vLLM patching for IPC | `1` | -| `NUM_INFERENCE_NODES` | Yes | Number of vLLM nodes (0 = local) | `0` | -| `LOGDIR` | Recommended | Directory for vllm_bridge_config.json | `.` | -| `CUDA_VISIBLE_DEVICES` | Recommended | GPU allocation | `0` | +``` +====================================================================== +BENCHMARK SUMMARY (shared_vllm) +====================================================================== + Total training time: 168.65s (2.81 min) + Total steps: 50 + + TIMING BREAKDOWN: + Avg step time: 11.95s + Total step time: 59.76s + Avg sync time: 0.00s (x0 syncs) <-- No syncs in shared mode! 
+ Total sync time: 0.00s + Avg data fetch time: 10.90s + Total data fetch time: 54.52s + + MEMORY: + Peak GPU memory: 31.44 GB + Avg GPU memory: 18.88 GB +====================================================================== +``` -### Trainer CLI Options +**Key metrics to compare:** -| Option | Default | Description | -|--------|---------|-------------| +| Metric | Legacy | Shared vLLM | LoRA | +|--------|--------|-------------|------| +| **Sync time** | High (restart vLLM) | 0 (in-place update) | Low (adapter swap) | +| **GPU memory** | 2x model | 1x model | 1x + adapter | +| **Step time** | ~10-15s | ~10-15s | ~5-10s | +| **Checkpoint size** | ~6GB | ~6GB | ~50MB | + +--- + +## πŸ›  CLI Reference + +```bash +python -m example_trainer.grpo --help +``` + +### Key Arguments + +| Argument | Default | Description | +|----------|---------|-------------| | `--model-name` | (required) | HuggingFace model ID | | `--weight-bridge-mode` | `none` | `none`, `shared_vllm`, or `lora_only` | -| `--single-copy` | `false` | Enable TRUE single-copy mode via CUDA IPC | -| `--vllm-config-path` | (auto-detect) | Explicit path to `vllm_bridge_config.json` | -| `--vllm-port` | `9001` | vLLM server port | -| `--training-steps` | `10` | Total optimization steps | -| `--batch-size` | `2` | Micro-batch size | -| `--lr` | `1e-5` | Learning rate | +| `--training-steps` | 10 | Number of training steps | +| `--batch-size` | 2 | Batch size | +| `--vllm-port` | 9001 | vLLM server port | +| `--atropos-url` | `http://localhost:8000` | Atropos API server URL | | `--save-path` | `trained_model_checkpoints` | Checkpoint directory | +| `--benchmark` | false | Show timing stats | +| `--debug-loading` | false | Verbose model loading output | -### vLLM Server Options +### LoRA-specific Arguments -| Option | Description | -|--------|-------------| -| `--model` | HuggingFace model ID | -| `--tensor-parallel-size` | Number of GPUs (use 1 for single-copy) | -| `--port` | Server port (default: 9001) | -| `--dtype` | Model dtype (`bfloat16`, `float16`, `auto`) | -| `--gpu-memory-utilization` | Fraction of GPU memory for KV cache (default: 0.9) | +| Argument | Default | Description | +|----------|---------|-------------| +| `--lora-r` | 16 | LoRA rank | +| `--lora-alpha` | 32 | LoRA alpha (scaling) | +| `--lora-dropout` | 0.05 | LoRA dropout | +| `--lora-target-modules` | `q_proj v_proj` | Modules to apply LoRA | + +### Single-Copy Mode Arguments + +| Argument | Default | Description | +|----------|---------|-------------| +| `--single-copy` | false | Enable CUDA IPC mode | +| `--vllm-config-path` | auto-detect | Path to `vllm_bridge_config.json` | --- -## The vLLM Bridge Config (vllm_bridge_config.json) - -The `vllm_bridge_config.json` file is the critical communication mechanism between the vLLM inference server and the GRPO trainer in single-copy mode. Understanding this file is essential for debugging and advanced configurations. - -### What It Is - -When you start vLLM with `VLLM_ENABLE_SHARED_WEIGHTS=1`, the patched `GPUModelRunner` exports CUDA IPC (Inter-Process Communication) handles for all model tensors. These handles allow another process (the trainer) to access the exact same GPU memoryβ€”no copying required. - -### Why It's Important - -1. **True Single-Copy Architecture**: Instead of loading the model twice (once for training, once for inference), both processes share the same tensors in GPU memory. - -2. 
**Zero-Latency Weight Updates**: When `optimizer.step()` modifies the weights, vLLM immediately sees the changesβ€”no serialization, no network transfer, no disk I/O. - -3. **Memory Efficiency**: For a 7B model (~14GB in bf16), you save ~14GB of GPU memory compared to having two separate copies. - -### File Location - -The trainer searches for `vllm_bridge_config.json` in this order: - -1. **Explicit path** (if `--vllm-config-path` is provided) -2. **`$LOGDIR/vllm_bridge_config.json`** (if `LOGDIR` env var is set) -3. **`./vllm_bridge_config.json`** (current directory) -4. **`/tmp/atropos_bridge/vllm_bridge_config.json`** (default fallback) - -**Tip**: To avoid "Config not found" errors, always set `LOGDIR`: +## πŸ› Troubleshooting +### "Atropos API not reachable" ```bash -export LOGDIR=. +# Make sure run-api is running +run-api --port 8000 ``` -### File Contents - -The JSON file contains everything needed to reconstruct tensor references in another process: - -```json -{ - "model": "Qwen/Qwen2.5-3B-Instruct", - "tp_degree": 1, - "dp_shard_degree": 1, - - "param_names": [ - "model.embed_tokens.weight", - "model.layers.0.self_attn.qkv_proj.weight", - ... - ], - - "param_mappings": { - "model.embed_tokens.weight": { - "vllm_name": "model.embed_tokens.weight", - "shape": [152064, 2048], - "dtype": "torch.bfloat16", - "device": "cuda:0" - }, - ... - }, - - "ipc_handles": { - "model.embed_tokens.weight": { - "device_index": 0, - "ipc_handle_b64": "AmPA0pN...", - "storage_size": 623902720, - "storage_offset": 0, - "ref_counter_handle_b64": "Y2JY...", - "ref_counter_offset": 0, - "event_handle_b64": "wRIs...", - "event_sync_required": true, - "shape": [152064, 2048], - "dtype": "torch.bfloat16" - }, - ... - }, - - "shared_weights_enabled": true, - "single_copy_enabled": true, - "num_params": 255 -} -``` - -#### Field Descriptions - -| Field | Description | -|-------|-------------| -| `model` | HuggingFace model identifier | -| `tp_degree` | Tensor parallel degree (must be 1 for single-copy) | -| `param_names` | List of all parameter names in the model | -| `param_mappings` | Shape, dtype, and device info for each parameter | -| `ipc_handles` | CUDA IPC handles for reconstructing shared tensors | -| `ipc_handle_b64` | The actual CUDA IPC handle (base64-encoded bytes) | -| `ref_counter_handle_b64` | Reference counter for CUDA memory (base64) | -| `event_handle_b64` | CUDA event handle for synchronization (base64) | -| `storage_size` | Size of the underlying storage in bytes | - -### How the Trainer Uses It - -1. **Load Config**: Trainer reads `vllm_bridge_config.json` -2. **Create Shell Model**: Uses `AutoModelForCausalLM.from_config()` with meta tensors (no memory allocation) -3. **Attach IPC Handles**: For each parameter, reconstructs the tensor using `torch.UntypedStorage._new_shared_cuda()` with the IPC handles -4. **Verify Shapes**: Ensures trainer's model architecture matches vLLM's sharding - -```python -# Simplified version of what happens internally: -for name, ipc_info in config["ipc_handles"].items(): - # Decode IPC handle from base64 - ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"]) - - # Reconstruct storage from IPC handle - storage = torch.UntypedStorage._new_shared_cuda( - device_index, ipc_handle, storage_size, ... 
- ) - - # Create tensor from shared storage - tensor = torch.tensor(storage).view(shape).to(dtype) - - # Replace model parameter with shared tensor - model.get_parameter(name).data = tensor -``` - -### Specifying the Config Path Explicitly - -If auto-detection isn't working (e.g., in complex cluster setups), you can specify the path explicitly: - +### "vLLM server not running" (LoRA mode) ```bash -# If vLLM writes config to a non-standard location: -python -u example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ - --weight-bridge-mode shared_vllm \ - --single-copy \ - --vllm-config-path /shared/nfs/vllm_bridge_config.json \ - --training-steps 50 +# LoRA mode requires external vLLM with --enable-lora +python example_trainer/vllm_api_server.py \ + --model $MODEL --port 9001 --enable-lora --enforce-eager ``` -### Common Issues - -| Symptom | Cause | Fix | -|---------|-------|-----| -| "Could not find vllm_bridge_config.json" | vLLM didn't export config | Check `VLLM_ENABLE_SHARED_WEIGHTS=1` was set BEFORE starting vLLM | -| Config exists but has empty `ipc_handles` | Patch didn't run | Ensure vLLM is using our custom `vllm_api_server.py` | -| "tuple of 8 items expected" | IPC handle format mismatch | Update to latest code (handles all 8 CUDA IPC tuple components) | -| "size mismatch" errors | Tensor parallel mismatch | Use `tensor-parallel-size 1` for single-copy mode | - ---- - -## FAQ & Troubleshooting - -### Q: I get "Could not find vllm_bridge_config.json" - -**A:** vLLM didn't export the IPC handles. Check: - -1. `VLLM_ENABLE_SHARED_WEIGHTS=1` was set **before** starting vLLM -2. `LOGDIR` is set to a valid, writable directory -3. Look for export messages in vllm.log: +### "Could not find vllm_bridge_config.json" (Shared mode) ```bash -grep "Exported" vllm.log +# Make sure vLLM was started with VLLM_ENABLE_SHARED_WEIGHTS=1 and LOGDIR set +VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=/tmp/atropos python example_trainer/vllm_api_server.py ... ``` -If the file exists but in a different location, specify it explicitly: +### "Triton compilation error" on B200/Blackwell GPUs +The patched vLLM server (`vllm_api_server.py`) automatically applies B200 fixes. If using standard vLLM, add `--enforce-eager`. + +### Port already in use ```bash -python grpo.py ... --vllm-config-path /path/to/vllm_bridge_config.json +# Kill existing processes +pkill -f "run-api" +pkill -f "vllm_api_server.py" +pkill -f "gsm8k_server.py" ``` ---- - -### Q: I get "CUDA out of memory" when starting the trainer - -**A:** For single-copy mode, trainer and vLLM MUST be on the same GPU(s). Check: - +### No batches available / trainer hangs ```bash -# Both should use the same CUDA_VISIBLE_DEVICES -CUDA_VISIBLE_DEVICES=0 python ... vllm_api_server.py ... -CUDA_VISIBLE_DEVICES=0 python ... grpo.py ... +# Ensure the environment server is connected to the correct API and vLLM +# Check that vLLM is running and environment can reach it +curl http://localhost:9001/health +curl http://localhost:8000/info ``` --- -### Q: Trainer crashes with "Cannot copy out of meta tensor" +## πŸ“š Module Documentation -**A:** Some model buffers (like rotary embeddings) weren't initialized. This is a known issue being fixed. Update to the latest code. +### `config.py` +Contains `TrainingConfig` - all training parameters as a Pydantic model. 
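+
+As a rough sketch of how these modules fit together: `grpo.py` builds a `TrainingConfig` from the CLI and dispatches to one trainer per mode. The function names come from the module listings below; the exact signatures and the `weight_bridge_mode` field name are assumptions:
+
+```python
+# Hypothetical wiring of the modules documented below (signatures assumed)
+from example_trainer.cli import parse_args, config_from_args
+from example_trainer import trainers
+
+def main() -> None:
+    args = parse_args()              # argparse setup (cli.py)
+    config = config_from_args(args)  # args -> TrainingConfig (config.py)
+
+    # grpo.py dispatches on the weight-bridge mode (trainers.py)
+    if config.weight_bridge_mode == "shared_vllm":
+        trainers.train_shared_vllm(config)  # single-copy CUDA IPC
+    elif config.weight_bridge_mode == "lora_only":
+        trainers.train_lora(config)         # adapter training + hot-swap
+    else:
+        trainers.train_legacy(config)       # checkpoint + vLLM restart
+
+if __name__ == "__main__":
+    main()
+```
+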
+ +### `api.py` +- `check_atropos_api()` - Wait for run-api server +- `register_trainer()` - Register with Atropos +- `get_batch()` - Fetch training batch from run-api + +### `data.py` +- `pad_data_to_good_offset()` - Pad sequences to GPU-friendly lengths +- `get_data()` - Fetch and preprocess batches + +### `model.py` +- `load_model_and_tokenizer()` - Load model based on mode +- `_attach_to_vllm_shared_tensors()` - CUDA IPC attachment +- `_create_vllm_to_hf_mapping()` - Handle QKV/Gate-Up fusion + +### `training.py` +- `compute_grpo_loss()` - GRPO loss computation +- `run_training_step()` - Single step with gradient accumulation +- `log_metrics()` - Console and WandB logging +- `finalize_training()` - Cleanup and summary + +### `checkpointing.py` +- `save_checkpoint()` - Save full model +- `save_lora_checkpoint()` - Save LoRA adapter only + +### `vllm_manager.py` +- `launch_vllm_server()` - Start vLLM process +- `terminate_vllm_process()` - Stop vLLM +- `hotswap_lora_adapter()` - Hot-swap LoRA in vLLM + +### `trainers.py` +- `train_legacy()` - Checkpoint + restart mode +- `train_shared_vllm()` - Single-copy CUDA IPC mode +- `train_lora()` - Adapter training mode + +### `cli.py` +- `parse_args()` - Argparse setup +- `config_from_args()` - Convert args to TrainingConfig --- -### Q: Single-copy mode doesn't work with tensor-parallel > 1 +## πŸ“ License -**A:** Currently, single-copy mode only works with `tensor-parallel-size 1`. For larger models that need tensor parallelism, use a single GPU with a smaller model, or wait for multi-GPU single-copy support. - ---- - -### Q: How do I check GPU memory usage? - -**A:** -```bash -nvidia-smi -``` - -For single-copy mode with Qwen2.5-14B: -- GPU 0: ~28GB (shared between vLLM and trainer) - ---- - -### Q: How do I stop all processes? - -**A:** -```bash -pkill -9 -u $USER -f "vllm|grpo|python|run-api" -``` - ---- - -## Files in This Directory - -| File | Description | -|------|-------------| -| `grpo.py` | Main trainer script with all modes | -| `vllm_api_server.py` | Custom vLLM server with shared memory patches | -| `vllm_patching/` | vLLM patches for CUDA IPC support | -| `requirements.txt` | Python dependencies | -| `README.md` | This documentation | - -### vllm_patching/ Directory - -| File | Description | -|------|-------------| -| `__init__.py` | Module exports and patch application | -| `patched_gpu_runner.py` | Patches GPUModelRunner to export CUDA IPC handles | - ---- - -## Performance Comparison - -| Mode | Sync Latency | Memory (14B model) | Best For | -|------|--------------|-------------------|----------| -| **Legacy** | 30-60s | 2x model | Debugging | -| **Single-Copy** | 0ms | 1x model (shared!) 
| Production | -| **LoRA** | 5-10s | 1x model + adapters | Memory-constrained | - ---- - -## Checkpoint Locations - -| Mode | Location | Size | -|------|----------|------| -| Legacy | `trained_model_checkpoints/step_N/` | ~28GB (14B model) | -| Single-Copy | `trained_model_checkpoints/step_N/` | ~28GB | -| LoRA | `trained_model_checkpoints/adapter_step_N/` | ~50MB | - ---- - -## Feature Availability Matrix - -### What's Available - -| Feature | Status | Notes | -|---------|--------|-------| -| **Single-Copy Mode** | Working | True shared memory via CUDA IPC | -| **LoRA Mode** | Working | Hot-swap adapters without restart | -| **Legacy Mode** | Working | Checkpoint-based, restart vLLM | -| **Qwen Models** | Working | Qwen2, Qwen2.5 (0.5B - 72B) | -| **Llama Models** | Working | Llama-2, Llama-3, Llama-3.1 | -| **Mistral Models** | Working | Mistral-7B, Mixtral | -| **Single GPU** | Working | All modes supported | -| **bfloat16/float16** | Working | Configurable via `--dtype` | -| **Gradient Checkpointing** | Available | Reduces memory usage | -| **Wandb Logging** | Working | Via `--use-wandb` flag | -| **Custom Environments** | Working | Extend `BaseEnv` class | - -### What's NOT Available - -| Feature | Mode | Status | Reason / Workaround | -|---------|------|--------|---------------------| -| **Multi-GPU (TP > 1)** | Single-Copy | Not Supported | CUDA IPC handles are per-device; sharding complicates sharing | -| **Multi-GPU (TP > 1)** | LoRA | Supported | vLLM handles TP, trainer only swaps adapters | -| **Multi-GPU (TP > 1)** | Legacy | Supported | Standard vLLM with TP supported | -| **Pipeline Parallel** | Single-Copy | Not Supported | Would need cross-device IPC | -| **Pipeline Parallel** | LoRA/Legacy | Via vLLM | Use `--pipeline-parallel-size` flag | -| **Data Parallel** | Single-Copy | Not Supported | Shared tensors can't be safely updated by multiple trainers | -| **Data Parallel** | LoRA/Legacy | Manual | Run multiple trainer instances (see docs below) | -| **Multi-Node** | Single-Copy | Not Supported | CUDA IPC is single-node only | -| **Multi-Node** | LoRA/Legacy | Via vLLM | vLLM supports distributed inference | -| **DeepSpeed/FSDP** | All | Not Integrated | Would require custom integration with trainer | -| **Quantized Models** | Single-Copy | Not Supported | IPC handles may not work with quantized tensors | -| **Quantized Models** | LoRA/Legacy | Supported | Standard vLLM quantization (GPTQ, AWQ, etc.) | -| **Encoder-Decoder** | All | Not Supported | Architecture not supported by vLLM | - -### Multi-GPU Support Summary - -| Mode | Tensor Parallel | Pipeline Parallel | Data Parallel | -|------|-----------------|-------------------|---------------| -| **Single-Copy** | TP=1 only | Not Supported | Not Supported | -| **LoRA** | Supported | Via vLLM | Multiple Trainers | -| **Legacy** | Supported | Via vLLM | Multiple Trainers | - -> **Key Point**: The multi-GPU limitation is **ONLY for single-copy mode** due to CUDA IPC constraints. -> LoRA and Legacy modes work with standard vLLM which fully supports tensor parallelism. - -#### Pipeline Parallel (PP) - -vLLM supports pipeline parallelism via `--pipeline-parallel-size`. For LoRA/Legacy modes: - -```bash -# LoRA/Legacy with Pipeline Parallel (2 GPUs for PP) -python -u example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-14B-Instruct \ - --tensor-parallel-size 1 \ - --pipeline-parallel-size 2 \ - --port 9001 -``` - -**Note**: PP requires the model to be split across GPUs by layers. Performance may vary. 
- -#### Data Parallel (DP) - -Data parallelism means running **multiple trainer instances** against the same vLLM server. Each trainer processes different batches: - -```bash -# Terminal 1: First trainer instance -CUDA_VISIBLE_DEVICES=4 python -u example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-7B-Instruct \ - --weight-bridge-mode lora_only \ - --trainer-rank 0 \ - --world-size 2 \ - > trainer_0.log 2>&1 & - -# Terminal 2: Second trainer instance -CUDA_VISIBLE_DEVICES=5 python -u example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-7B-Instruct \ - --weight-bridge-mode lora_only \ - --trainer-rank 1 \ - --world-size 2 \ - > trainer_1.log 2>&1 & -``` - -**Note**: DP requires gradient synchronization between trainers. Currently, trainers operate independently - true distributed DP would need additional coordination. - -### GPU Support - -| GPU Type | Single-Copy | LoRA | Legacy | Notes | -|----------|-------------|------|--------|-------| -| **NVIDIA A100** | YES | YES | YES | Recommended | -| **NVIDIA H100** | YES | YES | YES | Recommended | -| **NVIDIA B200** | YES | YES | YES | Recommended | -| **NVIDIA RTX 4090** | YES | YES | YES | Consumer, works well | -| **NVIDIA RTX 3090** | YES | YES | YES | Consumer, works well | -| **NVIDIA V100** | ? | YES | YES | Old, may have IPC issues | - -### Memory Requirements (Approximate) - -| Model Size | Single-Copy | LoRA | Legacy | -|------------|-------------|------|--------| -| 0.5B - 1B | 4-6 GB | 4-6 GB | 8-12 GB | -| 3B | 8-12 GB | 8-12 GB | 16-24 GB | -| 7B | 16-20 GB | 16-20 GB | 32-40 GB | -| 14B | 32-40 GB | 32-40 GB | 64-80 GB | -| 32B | 70-80 GB | 70-80 GB | 140+ GB | -| 70B+ | Single GPU impossible | 80+ GB | 160+ GB | - -> **Note**: Single-copy mode uses ~50% less memory than legacy because there's only ONE model copy. 
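-
-A quick back-of-envelope check of these numbers (a sketch; it assumes 2 bytes per parameter for bf16/fp16 weights and ignores activations, optimizer state, and KV cache, which is why the table's figures run higher):
-
-```python
-def weight_gb(params_billion: float, copies: int = 1) -> float:
-    """Approximate weight-only GPU memory in GB at 2 bytes/param."""
-    return params_billion * 2 * copies
-
-print(weight_gb(14, copies=1))  # single-copy: ~28 GB, shared by vLLM and trainer
-print(weight_gb(14, copies=2))  # legacy: ~56 GB across two separate copies
-```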
- ---- - -## Known Limitations - -### Single-Copy Mode Specific - -| Limitation | Reason | Workaround | -|------------|--------|------------| -| **Same GPU Required** | CUDA IPC only works within same physical device | Use same `CUDA_VISIBLE_DEVICES` for trainer and vLLM | -| **TP=1 Only** | Trainer expects unsharded model; IPC per-device | Use LoRA mode for TP > 1 | -| **Custom Server Required** | Standard `vllm serve` doesn't export IPC handles | Use `vllm_api_server.py` | -| **Single Node Only** | CUDA IPC is node-local | Use LoRA/Legacy for multi-node | - -### LoRA Mode Specific - -| Limitation | Reason | Workaround | -|------------|--------|------------| -| **~5s Swap Latency** | Adapter weights need to be loaded | Acceptable for most use cases | -| **vLLM LoRA Support Required** | Model must support LoRA in vLLM | Check vLLM documentation | - -### General Limitations - -| Limitation | Reason | Workaround | -|------------|--------|------------| -| **GSM8k Needs `server_type=vllm`** | Default `openai` type lacks state tracking | Use `--openai.server_type vllm` | -| **Decoder-Only Models Only** | vLLM architecture constraint | Use different framework for encoder-decoder | -| **Custom vLLM Server Required** | Standard `vllm serve` lacks IPC patches | Use `vllm_api_server.py` for all modes | - ---- - -## Future Work - -### High Priority - -| Feature | Description | -|---------|-------------| -| **Multi-GPU Single-Copy** | Support `tensor-parallel-size > 1` with sharded IPC | -| **Automatic Server Type Detection** | Auto-detect correct `server_type` for environments | -| **Checkpoint Resume** | Resume training from checkpoints seamlessly | - -### Medium Priority - -| Feature | Description | Difficulty | -|---------|-------------|------------| -| **DeepSpeed Integration** | ZeRO optimization for larger models | Hard | -| **Quantization Support** | Test and document GPTQ/AWQ in single-copy | Medium | -| **Multi-Node Training** | Distributed training across nodes | Hard | -| **Streaming Weights** | Stream weight updates instead of full sync | Medium | -| **Mixed Precision Training** | Support fp8/int8 training | Medium | - - -## Contributing - -We welcome contributions! Priority areas: - -1. **Multi-GPU single-copy support** - The biggest missing feature -2. **Better documentation** - More examples, tutorials -3. **Environment implementations** - New RL environments -4. **Bug fixes** - Especially edge cases in IPC handling - -See the main repository CONTRIBUTING.md for guidelines. +MIT License diff --git a/example_trainer/__init__.py b/example_trainer/__init__.py index f0ebdb72..01334052 100644 --- a/example_trainer/__init__.py +++ b/example_trainer/__init__.py @@ -1,7 +1,34 @@ """ -Example trainer implementations of how to implement a trainer for the Atropos library. +GRPO (Group Relative Policy Optimization) Trainer + +A training framework for fine-tuning language models with reinforcement learning, +designed to work with the Atropos environment system. + +Supports three training modes: +- Legacy: Checkpoint-based training with vLLM restarts +- Shared vLLM: Single-copy mode with CUDA IPC (no model duplication!) 
+- LoRA: Adapter-only training with hot-swap capability + +Usage: + # As CLI + python -m example_trainer.grpo --model-name Qwen/Qwen2.5-3B-Instruct --training-steps 100 + + # As library + from example_trainer import TrainingConfig, train_legacy, train_shared_vllm, train_lora + + config = TrainingConfig(model_name="Qwen/Qwen2.5-3B-Instruct", training_steps=100) + train_legacy(config) """ -from example_trainer.grpo import TrainingConfig, train +from .config import TrainingConfig +from .trainers import train_legacy, train_shared_vllm, train_lora +from .cli import parse_args, config_from_args -__all__ = ["TrainingConfig", "train"] +__all__ = [ + "TrainingConfig", + "train_legacy", + "train_shared_vllm", + "train_lora", + "parse_args", + "config_from_args", +] diff --git a/example_trainer/api.py b/example_trainer/api.py new file mode 100644 index 00000000..cfa4ac48 --- /dev/null +++ b/example_trainer/api.py @@ -0,0 +1,102 @@ +""" +Atropos API communication utilities. + +Handles communication with the Atropos API server for: +- Server health checks +- Trainer registration +- Batch retrieval +""" + +import time as _time + +import requests +from tenacity import retry, stop_after_attempt, wait_exponential + +from .config import TrainingConfig + + +def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool: + """ + Check if the Atropos API server is reachable. + + Args: + url: Base URL of the Atropos API server + timeout: Maximum time to wait for the server + + Returns: + True if server is reachable + """ + start = _time.time() + while _time.time() - start < timeout: + try: + response = requests.get(f"{url}/info", timeout=2) + if response.status_code == 200: + print(f"[Trainer] βœ“ Atropos API server is reachable at {url}") + return True + except requests.exceptions.ConnectionError: + pass + except Exception as e: + print(f"[Trainer] Waiting for Atropos API at {url}... ({e})") + _time.sleep(1) + + print(f"[Trainer] ⚠ Warning: Atropos API server not reachable at {url}") + return False + + +@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30)) +def register_trainer(config: TrainingConfig): + """ + Register the trainer with the Atropos API. + + Verifies registration succeeded before returning. + """ + url = config.atropos_url + response = requests.post( + f"{url}/register", + json={ + # wandb fields are required strings - use empty string if None + "wandb_group": config.wandb_group or "", + "wandb_project": config.wandb_project or "", + "batch_size": config.batch_size * config.gradient_accumulation_steps, + "max_token_len": config.seq_len, + "starting_step": 0, + "checkpoint_dir": config.save_path, + "save_checkpoint_interval": config.training_steps, + "num_steps": config.training_steps, + }, + timeout=10, + ) + + # Check for HTTP errors + response.raise_for_status() + + # Verify we got a valid response with UUID + data = response.json() + if "uuid" not in data: + raise RuntimeError(f"Registration failed: {data}") + + print(f"[Trainer] βœ“ Registered with Atropos API at {url} (uuid: {data['uuid']})") + + +@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30)) +def get_batch(url: str = "http://localhost:8000"): + """ + Get a batch of training data from the Atropos API. + + Args: + url: Base URL of the Atropos API server + + Returns: + Batch data dictionary containing tokens, masks, scores, etc. 
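+
+    Example (illustrative sketch; assumes a run-api server on localhost):
+        data = get_batch("http://localhost:8000")
+        if data["batch"] is not None:
+            group = data["batch"][0]  # dict with "tokens", "masks", "scores", ...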
+ + Raises: + RuntimeError: If trainer is not registered or other API error + """ + data = requests.get(f"{url}/batch", timeout=10).json() + + # Check if there was an error (trainer not registered) + if data.get("status") == "error": + raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}") + + return data + diff --git a/example_trainer/checkpointing.py b/example_trainer/checkpointing.py new file mode 100644 index 00000000..dec746a5 --- /dev/null +++ b/example_trainer/checkpointing.py @@ -0,0 +1,90 @@ +""" +Checkpoint saving utilities for GRPO trainer. + +Handles saving model checkpoints for different training modes: +- Full model checkpoints (legacy and shared_vllm modes) +- LoRA adapter checkpoints +""" + +import os +import shutil + +import torch + + +def save_checkpoint( + model: torch.nn.Module, + tokenizer, + save_path: str, + step: int, + is_final: bool = False, +) -> str: + """ + Save full model checkpoint. + + Args: + model: Model to save + tokenizer: Tokenizer to save + save_path: Base directory for checkpoints + step: Current training step + is_final: Whether this is the final checkpoint + + Returns: + Path where checkpoint was saved + """ + if is_final: + checkpoint_path = os.path.join(save_path, "final_model") + else: + checkpoint_path = os.path.join(save_path, f"step_{step}") + + print(f" Saving checkpoint to {checkpoint_path}...") + + if os.path.exists(checkpoint_path): + shutil.rmtree(checkpoint_path) + os.makedirs(checkpoint_path, exist_ok=True) + + model.save_pretrained(checkpoint_path) + tokenizer.save_pretrained(checkpoint_path) + + print(" Checkpoint saved.") + return checkpoint_path + + +def save_lora_checkpoint( + model: torch.nn.Module, + save_path: str, + step: int, + is_final: bool = False, +) -> str: + """ + Save LoRA adapter checkpoint. + + Only saves the LoRA adapter weights, not the full model. + This results in much smaller checkpoint files. + + Args: + model: PEFT model with LoRA adapters + save_path: Base directory for checkpoints + step: Current training step + is_final: Whether this is the final checkpoint + + Returns: + Path where adapter was saved + """ + if is_final: + adapter_path = os.path.join(save_path, "final_adapter") + else: + adapter_path = os.path.join(save_path, f"adapter_step_{step}") + + print(f" Saving LoRA adapter to {adapter_path}...") + + if os.path.exists(adapter_path): + shutil.rmtree(adapter_path) + os.makedirs(adapter_path, exist_ok=True) + + # Save only the adapter weights (much smaller than full model) + model.save_pretrained(adapter_path) + + print(" Adapter saved.") + return adapter_path + diff --git a/example_trainer/cli.py b/example_trainer/cli.py new file mode 100644 index 00000000..ad3fc99a --- /dev/null +++ b/example_trainer/cli.py @@ -0,0 +1,271 @@ +""" +Command-line interface for GRPO trainer. + +Provides argument parsing and configuration building. +""" + +import argparse + +import torch + +from .config import TrainingConfig + + +def parse_args() -> argparse.Namespace: + """ + Parse command-line arguments for the GRPO trainer. 
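+
+    Every flag maps onto a TrainingConfig field; config_from_args() below
+    performs the conversion.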
+ + Returns: + Parsed arguments namespace + """ + parser = argparse.ArgumentParser( + description="GRPO Trainer with optional shared-weight vLLM integration", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # === Core Training Arguments === + parser.add_argument( + "--model-name", + type=str, + required=True, + help="HuggingFace model identifier (e.g., 'Qwen/Qwen2.5-1.5B-Instruct')", + ) + parser.add_argument( + "--lr", + type=float, + default=1e-5, + help="Learning rate for the optimizer", + ) + parser.add_argument( + "--training-steps", + type=int, + default=10, + help="Number of training steps to run", + ) + parser.add_argument( + "--batch-size", + type=int, + default=2, + help="Batch size for training", + ) + parser.add_argument( + "--seq-len", + type=int, + default=2048, + help="Maximum sequence length", + ) + parser.add_argument( + "--gradient-accumulation-steps", + type=int, + default=32, + help="Number of gradient accumulation steps", + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to train on (cuda/cpu)", + ) + parser.add_argument( + "--save-path", + type=str, + default="trained_model_checkpoints", + help="Directory to save model checkpoints", + ) + + # === vLLM Arguments === + parser.add_argument( + "--vllm-restart-interval", + type=int, + default=3, + help="Restart vLLM every N training steps (legacy mode only)", + ) + parser.add_argument( + "--vllm-port", + type=int, + default=9001, + help="Port for the vLLM server", + ) + parser.add_argument( + "--atropos-url", + type=str, + default="http://localhost:8000", + help="URL of the Atropos API/environment server (e.g., gsm8k_server)", + ) + parser.add_argument( + "--vllm-gpu-memory-utilization", + type=float, + default=0.45, + help="GPU memory utilization for vLLM server (0.0-1.0)", + ) + + # === Wandb Arguments === + parser.add_argument( + "--use-wandb", + action="store_true", + help="Enable Weights & Biases logging", + ) + parser.add_argument( + "--wandb-project", + type=str, + default=None, + help="Wandb project name", + ) + parser.add_argument( + "--wandb-group", + type=str, + default=None, + help="Wandb group name", + ) + + # === Training Mode Arguments === + parser.add_argument( + "--weight-bridge-mode", + type=str, + choices=["shared_vllm", "lora_only", "none"], + default="none", + help=( + "Weight sync mode: " + "'shared_vllm' = attach to vLLM shared memory, " + "'lora_only' = train LoRA adapters only, " + "'none' = legacy restart-based sync" + ), + ) + parser.add_argument( + "--trainer-rank", + type=int, + default=0, + help="Rank of this trainer in the distributed group", + ) + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Total processes in the distributed group", + ) + parser.add_argument( + "--init-method", + type=str, + default="env://", + help="PyTorch distributed init method (e.g., 'env://', 'tcp://host:port')", + ) + parser.add_argument( + "--num-inference-nodes", + type=int, + default=0, + help="Number of inference nodes to coordinate with (0 = single-node local)", + ) + + # === LoRA Arguments === + parser.add_argument( + "--lora-r", + type=int, + default=16, + help="LoRA rank (dimension of low-rank matrices)", + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=32, + help="LoRA alpha (scaling factor, typically 2x rank)", + ) + parser.add_argument( + "--lora-dropout", + type=float, + default=0.05, + help="Dropout probability for LoRA layers", + ) + parser.add_argument( + 
"--lora-target-modules", + type=str, + nargs="+", + default=None, + help="Module names to apply LoRA to (default: q_proj v_proj)", + ) + + # === Single-Copy Mode Arguments === + parser.add_argument( + "--single-copy", + action="store_true", + help=( + "Enable TRUE single-copy mode (shared_vllm mode only). " + "Trainer attaches to vLLM's model tensors via CUDA IPC. " + "Only ONE copy of the model exists in GPU memory! " + "Requires trainer and vLLM to be on the SAME GPU(s). " + "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1." + ), + ) + parser.add_argument( + "--vllm-config-path", + type=str, + default=None, + help=( + "Explicit path to vllm_bridge_config.json. " + "If not provided, auto-detects from LOGDIR, current directory, " + "or /tmp/atropos_bridge. " + "This file contains CUDA IPC handles created by vLLM." + ), + ) + + # === Debug Flags === + parser.add_argument( + "--debug-loading", + action="store_true", + help=( + "Enable verbose debug output during model loading and IPC attachment. " + "Useful for diagnosing single-copy mode issues." + ), + ) + parser.add_argument( + "--benchmark", + action="store_true", + help=( + "Enable benchmark timing output showing step time, sync time, " + "data fetch time, and GPU memory usage per step." + ), + ) + + return parser.parse_args() + + +def config_from_args(args: argparse.Namespace) -> TrainingConfig: + """ + Build a TrainingConfig from parsed CLI arguments. + + Args: + args: Parsed argparse namespace + + Returns: + TrainingConfig instance + """ + return TrainingConfig( + model_name=args.model_name, + lr=args.lr, + training_steps=args.training_steps, + batch_size=args.batch_size, + seq_len=args.seq_len, + gradient_accumulation_steps=args.gradient_accumulation_steps, + device=args.device, + save_path=args.save_path, + vllm_restart_interval=args.vllm_restart_interval, + vllm_port=args.vllm_port, + vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization, + use_wandb=args.use_wandb, + wandb_project=args.wandb_project, + wandb_group=args.wandb_group, + weight_bridge_mode=args.weight_bridge_mode, + trainer_rank=args.trainer_rank, + world_size=args.world_size, + init_method=args.init_method, + num_inference_nodes=args.num_inference_nodes, + lora_r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + lora_target_modules=args.lora_target_modules, + single_copy=getattr(args, "single_copy", False), + vllm_config_path=getattr(args, "vllm_config_path", None), + debug_loading=getattr(args, "debug_loading", False), + benchmark=getattr(args, "benchmark", False), + atropos_url=getattr(args, "atropos_url", "http://localhost:8000"), + ) + diff --git a/example_trainer/config.py b/example_trainer/config.py new file mode 100644 index 00000000..87b92397 --- /dev/null +++ b/example_trainer/config.py @@ -0,0 +1,154 @@ +""" +Training configuration for GRPO trainer. + +This module contains the TrainingConfig class which defines all training +parameters, model settings, and operational modes. +""" + +from typing import List, Literal, Optional + +import torch +from pydantic import BaseModel, Field + + +class TrainingConfig(BaseModel): + """ + Training configuration for GRPO trainer. 
+
+    Supports three training modes:
+    - 'none' (legacy): Periodic checkpoint saves + vLLM restarts
+    - 'shared_vllm': Attach to vLLM's shared memory tensors, update in-place
+    - 'lora_only': Freeze base model, train LoRA adapters only
+    """
+
+    # === Model Configuration ===
+    model_name: str = Field(..., description="Name of the base model to train")
+
+    # === Training Hyperparameters ===
+    lr: float = Field(1e-5, description="Learning rate for the optimizer")
+    training_steps: int = Field(10, description="Number of training steps")
+    batch_size: int = Field(2, description="Batch size for training")
+    seq_len: int = Field(2048, description="Sequence length for training")
+    gradient_accumulation_steps: int = Field(
+        32, description="Number of gradient accumulation steps"
+    )
+
+    # === Device & Storage ===
+    device: str = Field(
+        "cuda" if torch.cuda.is_available() else "cpu",
+        description="Device to train on"
+    )
+    save_path: str = Field(
+        "trained_model_checkpoints",
+        description="Base path to save model checkpoints"
+    )
+
+    # === vLLM Server Configuration ===
+    vllm_restart_interval: int = Field(
+        3, description="Restart vLLM every N training steps (legacy mode)"
+    )
+    vllm_port: int = Field(9001, description="Port for the vLLM server")
+    vllm_gpu_memory_utilization: float = Field(
+        0.45, description="GPU memory utilization for vLLM server (0.0-1.0)"
+    )
+
+    # === Weights & Biases Configuration ===
+    use_wandb: bool = Field(
+        False, description="Whether to use Weights & Biases for logging"
+    )
+    wandb_project: Optional[str] = Field(None, description="Wandb project name")
+    wandb_group: Optional[str] = Field(None, description="Wandb group name")
+
+    # === Training Mode Configuration ===
+    weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field(
+        "none",
+        description=(
+            "How to synchronize weights with inference server. "
+            "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. "
+            "'lora_only': keep base model frozen, train/swap LoRA adapters. "
+            "'none': legacy mode, restart vLLM with new checkpoint files."
+        ),
+    )
+
+    # === Distributed Training Configuration ===
+    trainer_rank: int = Field(
+        0, description="Rank of this trainer in the distributed group"
+    )
+    world_size: int = Field(
+        1, description="Total processes in the distributed group"
+    )
+    init_method: str = Field(
+        "env://",
+        description=(
+            "PyTorch distributed init method URL. "
+            "Use 'env://' to read MASTER_ADDR/MASTER_PORT from environment, "
+            "or 'tcp://host:port' for explicit rendezvous."
+        ),
+    )
+    num_inference_nodes: int = Field(
+        0,
+        description=(
+            "Number of inference nodes (vLLM servers) to coordinate with. "
+            "0 means single-node local mode."
+        ),
+    )
+
+    # === LoRA Configuration ===
+    lora_r: int = Field(16, description="LoRA rank (dimension of low-rank matrices)")
+    lora_alpha: int = Field(32, description="LoRA alpha (scaling factor)")
+    lora_dropout: float = Field(0.05, description="Dropout probability for LoRA layers")
+    lora_target_modules: Optional[List[str]] = Field(
+        None,
+        description=(
+            "List of module names to apply LoRA to. "
+            "If None, defaults to ['q_proj', 'v_proj'] for most models."
+        ),
+    )
+
+    # === Single-Copy Mode Configuration ===
+    single_copy: bool = Field(
+        False,
+        description=(
+            "Enable TRUE single-copy mode via CUDA IPC. "
+            "The trainer attaches to vLLM's model tensors directly, "
+            "meaning only ONE copy of the model exists in GPU memory. "
+            "Requires trainer and vLLM to be on the SAME GPU(s). "
+            "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1."
+        ),
+    )
+    vllm_config_path: Optional[str] = Field(
+        None,
+        description=(
+            "Explicit path to vllm_bridge_config.json. "
+            "If not provided, auto-detects from LOGDIR environment variable, "
+            "current directory, or /tmp/atropos_bridge. "
+            "This file is created by vLLM when VLLM_ENABLE_SHARED_WEIGHTS=1 "
+            "and contains CUDA IPC handles for single-copy mode."
+        ),
+    )
+
+    # === Debug & Benchmark Flags ===
+    debug_loading: bool = Field(
+        False,
+        description=(
+            "Enable verbose debug output during model loading and IPC attachment. "
+            "Useful for diagnosing single-copy mode issues."
+        ),
+    )
+    benchmark: bool = Field(
+        False,
+        description=(
+            "Enable benchmark timing output showing step time, sync time, "
+            "data fetch time, and GPU memory usage per step."
+        ),
+    )
+
+    # === Atropos API Configuration ===
+    atropos_url: str = Field(
+        "http://localhost:8000",
+        description=(
+            "URL of the Atropos API server (environment server). "
+            "Default is http://localhost:8000. Change for concurrent tests."
+        ),
+    )
+
diff --git a/example_trainer/data.py b/example_trainer/data.py
new file mode 100644
index 00000000..3d9b4cfd
--- /dev/null
+++ b/example_trainer/data.py
@@ -0,0 +1,182 @@
+"""
+Data processing utilities for GRPO trainer.
+
+Handles data retrieval from Atropos API, padding, batching,
+and advantage normalization.
+"""
+
+import json
+import math
+import time
+from typing import List, Tuple
+
+import numpy as np
+import torch
+
+from .api import get_batch
+
+
+def pad_data_to_good_offset(data: dict, batch_size: int) -> Tuple[
+    List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
+]:
+    """
+    Pad and batch data from the Atropos API.
+
+    Processes raw batch data into properly padded tensors suitable for training:
+    - Pads token sequences to nearest multiple of 64
+    - Normalizes advantage scores
+    - Extracts temperature values
+
+    Args:
+        data: Raw batch data from Atropos API
+        batch_size: Size of each training batch
+
+    Returns:
+        Tuple of (token_batches, label_batches, advantage_batches, temperature_batches)
+    """
+    max_token_len = max(
+        [max([len(x) for x in item["tokens"]]) for item in data["batch"]]
+    )
+
+    # Pad to nearest multiple of 64 for GPU efficiency
+    good_multiple = 64
+    if (max_token_len - 1) % (good_multiple) != 0:
+        max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple
+        token_setup_len = max_token_len + 1  # +1 for causal shift
+    else:
+        token_setup_len = max_token_len
+        max_token_len = max_token_len - 1  # -1 for causal shift
+
+    # Process all items
+    input_ids = []
+    labels = []
+    advantages = []
+    lengths = []
+    temperatures = []
+
+    for item in data["batch"]:
+        # Normalize advantage scores
+        scores = np.array(item["scores"])
+        if len(scores) > 1:
+            scores = scores - scores.mean()
+            scores = scores / max(scores.std(), 1e-8)
+        item["scores"] = scores
+
+        # Handle score overrides
+        if item["overrides"] is not None:
+            for i in range(len(item["overrides"])):
+                if item["overrides"][i].get("set_advantage_to_zero", False):
+                    item["scores"][i] = 0
+
+        # Process each sample in the item
+        for i in range(len(item["tokens"])):
+            lengths.append(
+                math.ceil((len(item["tokens"][i]) - 1) / good_multiple) * good_multiple
+            )
+
+            # Create labels with padding
+            label_item = np.concatenate([
+                np.array(item["masks"][i]),
+                np.full(
+                    max(0, token_setup_len - len(item["tokens"][i])),
+                    -100,
+                    dtype=np.int32,
+                ),
+            ])
+
+            # Pad tokens
+            item["tokens"][i] = np.concatenate([
+                np.array(item["tokens"][i]),
+                np.zeros(
+                    max(0, token_setup_len - len(item["tokens"][i])),
+                    dtype=np.int32,
+                ),
+            ])
+
+            input_ids.append(item["tokens"][i][:-1])  # Remove last for causal
+            labels.append(label_item[1:])  # Shift by 1 for causal
+            advantages.append(item["scores"][i])
+
+            # Extract temperature (priority: override > generation_params > group_overrides > 1.0)
+            t = 1.0
+            if (
+                item.get("overrides")
+                and i < len(item["overrides"])
+                and isinstance(item["overrides"][i], dict)
+                and ("temperature" in item["overrides"][i])
+            ):
+                t = float(item["overrides"][i]["temperature"])
+            elif item.get("generation_params") and ("temperature" in item["generation_params"]):
+                t = float(item["generation_params"]["temperature"])
+            elif item.get("group_overrides") and ("temperature" in item["group_overrides"]):
+                t = float(item["group_overrides"]["temperature"])
+            temperatures.append(t)
+
+    # Batch the data
+    token_batches = []
+    label_batches = []
+    advantage_batches = []
+    temperature_batches = []
+
+    for i in range(len(input_ids) // batch_size):
+        start = i * batch_size
+        end = (i + 1) * batch_size
+
+        token_batches.append(
+            torch.tensor(np.stack(input_ids[start:end], axis=0))
+        )
+        label_batches.append(
+            torch.tensor(np.stack(labels[start:end], axis=0))
+        )
+        advantage_batches.append(
+            torch.tensor(np.stack(advantages[start:end], axis=0)).view(-1, 1)
+        )
+        temperature_batches.append(
+            torch.tensor(
+                np.array(temperatures[start:end], dtype=np.float32)
+            ).view(-1, 1, 1)
+        )
+
+    return token_batches, label_batches, advantage_batches, temperature_batches
+
+
+def get_data(
+    batch_size: int,
+    seq_len: int,
+    atropos_url: str = "http://localhost:8000",
+) -> List[Tuple[
+    List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]
+]]:
+    """
+    Fetch and process training data from the Atropos API.
+
+    Continuously polls the API until data is available, then processes
+    all available batches.
+
+    Args:
+        batch_size: Size of each training batch
+        seq_len: Maximum sequence length (for reference, not used directly)
+        atropos_url: URL of the Atropos API server
+
+    Returns:
+        List of processed batch tuples
+    """
+    batches = []
+
+    while True:
+        data = get_batch(url=atropos_url)
+
+        if data["batch"] is not None:
+            # Save batch for debugging
+            with open("temp.json", "w", encoding="utf-8") as f:
+                json.dump(data, f)
+
+            # Process and accumulate batches
+            batches.append(pad_data_to_good_offset(data, batch_size))
+        elif len(batches) > 0:
+            # Return accumulated batches when no more data
+            return batches
+        else:
+            # Wait for data
+            time.sleep(1)
+
diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py
index 9cd935ad..1dc9bdc4 100644
--- a/example_trainer/grpo.py
+++ b/example_trainer/grpo.py
@@ -1,2450 +1,55 @@
-import argparse
-import atexit
-import json
-import math
-import os
-import random
-import shutil
-import string
-import subprocess
-import time
-from typing import Dict, List, Literal, Optional, Tuple
+#!/usr/bin/env python3
+"""
+GRPO Trainer - Main Entry Point

-import numpy as np
-import requests
-import torch
-import torch.nn.functional as F
-import wandb # Added for logging
-from pydantic import BaseModel, Field
-from tenacity import retry, stop_after_attempt, wait_exponential
-from torch.optim import AdamW
-from transformers import AutoModelForCausalLM, AutoTokenizer
+This is the command-line entry point for the GRPO trainer.
+For the actual implementation, see the modular files: -# Weight bridge removed - single-copy mode uses direct CUDA IPC instead +- config.py - TrainingConfig class +- api.py - Atropos API communication +- data.py - Data processing and batching +- model.py - Model loading and shared memory +- training.py - Loss computation and training step +- checkpointing.py - Checkpoint saving +- vllm_manager.py - vLLM process management +- trainers.py - Training mode implementations +- cli.py - CLI argument parsing -# Import PEFT for LoRA training -try: - from peft import LoraConfig, TaskType, get_peft_model +Usage: + # Legacy mode (checkpoint + restart) + python grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --training-steps 100 - PEFT_AVAILABLE = True -except ImportError: - PEFT_AVAILABLE = False + # Single-copy mode (shared memory) + python grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode shared_vllm -# Global variable to keep track of the vLLM process -vllm_process = None + # LoRA mode (adapter training) + python grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode lora_only +""" +from .cli import parse_args, config_from_args +from .trainers import train_legacy, train_shared_vllm, train_lora -def cleanup_vllm(): - global vllm_process - if vllm_process: - print("\nTerminating vLLM process...") - vllm_process.terminate() - try: - vllm_process.wait(timeout=5) # Wait a bit for graceful shutdown - print("vLLM process terminated.") - except subprocess.TimeoutExpired: - print("vLLM process did not terminate gracefully, killing.") - vllm_process.kill() - vllm_process.wait() - print("vLLM process killed.") - vllm_process = None +def main(): + """Main entry point for GRPO trainer.""" + args = parse_args() + config = config_from_args(args) -# Register the cleanup function to be called on script exit -atexit.register(cleanup_vllm) + print(f"Weight bridge mode: {config.weight_bridge_mode}") - -class TrainingConfig(BaseModel): - """ - Training details, model, etc - """ - - model_name: str = Field(..., description="Name of the base model to train") - lr: float = Field(1e-5, description="Learning rate for the optimizer") - training_steps: int = Field( - 10, description="Number of training steps" - ) # Renamed from epochs - batch_size: int = Field( - 2, description="Batch size for training (will be handled by get_data)" - ) - seq_len: int = Field(2048, description="Sequence length for training") - gradient_accumulation_steps: int = Field( - 32, description="Number of gradient accumulation steps" - ) - device: str = Field( - "cuda" if torch.cuda.is_available() else "cpu", description="Device to train on" - ) - save_path: str = Field( - "trained_model_checkpoints", description="Base path to save model checkpoints" - ) - vllm_restart_interval: int = Field( - 3, description="Restart vLLM every N training steps" - ) - vllm_port: int = Field(9001, description="Port for the vLLM server") - vllm_gpu_memory_utilization: float = Field( - 0.45, description="GPU memory utilization for vLLM server (0.0-1.0)" - ) - - # Wandb configuration - use_wandb: bool = Field( - False, description="Whether to use Weights & Biases for logging" - ) - wandb_project: Optional[str] = Field(None, description="Wandb project name") - wandb_group: Optional[str] = Field(None, description="Wandb group name") - - # Pipeline / weight bridge configuration - weight_bridge_mode: Literal["shared_vllm", "lora_only", "none"] = Field( - "none", - description=( - "How to synchronize weights with inference server. 
" - "'shared_vllm': attach to vLLM's shared memory tensors and update in-place. " - "'lora_only': keep base model frozen, train/swap LoRA adapters. " - "'none': legacy mode, restart vLLM with new checkpoint files." - ), - ) - trainer_rank: int = Field( - 0, - description="Rank of this trainer in the distributed group (for shared_vllm mode)", - ) - world_size: int = Field( - 1, - description="Total processes in the distributed group (for shared_vllm mode)", - ) - init_method: str = Field( - "env://", - description=( - "PyTorch distributed init method URL. " - "Use 'env://' to read MASTER_ADDR/MASTER_PORT from environment, " - "or 'tcp://host:port' for explicit rendezvous." - ), - ) - num_inference_nodes: int = Field( - 0, - description=( - "Number of inference nodes (vLLM servers) to coordinate with. " - "0 means single-node local mode." - ), - ) - - # LoRA configuration (for lora_only mode) - lora_r: int = Field(16, description="LoRA rank (dimension of low-rank matrices)") - lora_alpha: int = Field(32, description="LoRA alpha (scaling factor)") - lora_dropout: float = Field(0.05, description="Dropout probability for LoRA layers") - lora_target_modules: Optional[List[str]] = Field( - None, - description=( - "List of module names to apply LoRA to. " - "If None, defaults to ['q_proj', 'v_proj'] for most models." - ), - ) - - # Single-copy mode (TRUE shared memory - no extra model copy) - single_copy: bool = Field( - False, - description=( - "Enable TRUE single-copy mode via CUDA IPC. " - "The trainer attaches to vLLM's model tensors directly, " - "meaning only ONE copy of the model exists in GPU memory. " - "Requires trainer and vLLM to be on the SAME GPU(s). " - "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1." - ), - ) - vllm_config_path: Optional[str] = Field( - None, - description=( - "Explicit path to vllm_bridge_config.json. " - "If not provided, auto-detects from LOGDIR environment variable, " - "current directory, or /tmp/atropos_bridge. " - "This file is created by vLLM when VLLM_ENABLE_SHARED_WEIGHTS=1 " - "and contains CUDA IPC handles for single-copy mode." - ), - ) - - # Debug flags - debug_loading: bool = Field( - False, - description=( - "Enable verbose debug output during model loading and IPC attachment. " - "Useful for diagnosing single-copy mode issues." - ), - ) - benchmark: bool = Field( - False, - description=( - "Enable benchmark timing output showing step time, sync time, " - "data fetch time, and GPU memory usage per step." - ), - ) - atropos_url: str = Field( - "http://localhost:8000", - description=( - "URL of the Atropos API server (environment server). " - "Default is http://localhost:8000. Change for concurrent tests." - ), - ) - - -def check_atropos_api(url: str = "http://localhost:8000", timeout: float = 30.0) -> bool: - """ - Check if the Atropos API server is reachable. - - Args: - url: Base URL of the Atropos API server - timeout: Maximum time to wait for the server - - Returns: - True if server is reachable - """ - import time as _time - - start = _time.time() - while _time.time() - start < timeout: - try: - response = requests.get(f"{url}/info", timeout=2) - if response.status_code == 200: - print(f"[Trainer] βœ“ Atropos API server is reachable at {url}") - return True - except requests.exceptions.ConnectionError: - pass - except Exception as e: - print(f"[Trainer] Waiting for Atropos API at {url}... 
({e})") - _time.sleep(1) - - print(f"[Trainer] ⚠ Warning: Atropos API server not reachable at {url}") - return False - - -@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30)) -def register_trainer(config: TrainingConfig): - """ - Register the trainer with the Atropos API. - - Verifies registration succeeded before returning. - """ - url = config.atropos_url - response = requests.post( - f"{url}/register", - json={ - # wandb fields are required strings - use empty string if None - "wandb_group": config.wandb_group or "", - "wandb_project": config.wandb_project or "", - "batch_size": config.batch_size * config.gradient_accumulation_steps, - "max_token_len": config.seq_len, - "starting_step": 0, - "checkpoint_dir": config.save_path, - "save_checkpoint_interval": config.training_steps, - "num_steps": config.training_steps, - }, - timeout=10, - ) - - # Check for HTTP errors - response.raise_for_status() - - # Verify we got a valid response with UUID - data = response.json() - if "uuid" not in data: - raise RuntimeError(f"Registration failed: {data}") - - print(f"[Trainer] βœ“ Registered with Atropos API at {url} (uuid: {data['uuid']})") - - -@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=2, max=30)) -def get_batch(url: str = "http://localhost:8000"): - data = requests.get(f"{url}/batch", timeout=10).json() - - # Check if there was an error (trainer not registered) - if data.get("status") == "error": - raise RuntimeError(f"Atropos API error: {data.get('message', 'Unknown error')}") - - return data - - -def pad_data_to_good_offset(data, batch_size: int): - max_token_len = max( - [max([len(x) for x in item["tokens"]]) for item in data["batch"]] - ) - # usually 64 is a good choice to ensure nonweird scaling behavior on GPUS - # so we pad to the nearest multiple of 64 - good_multiple = 64 - if (max_token_len - 1) % (good_multiple) != 0: - max_token_len = math.ceil((max_token_len - 1) / (good_multiple)) * good_multiple - token_setup_len = ( - max_token_len + 1 - ) # add 1 so we can make it causal at the proper length - else: - token_setup_len = max_token_len - max_token_len = ( - max_token_len - 1 - ) # since it's causal we need to remove the last bit... - # pad all tokens to max_token_len and add to lists - input_ids = list() - labels = list() - advantages = list() - lengths = list() - temperatures = list() - for item in data["batch"]: - scores = item["scores"] - scores = np.array(scores) - # check if we have more than 1 score... 
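-        # z-score within the group: advantage = (score - mean) / max(std, 1e-8)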
- if len(scores) > 1: - scores = scores - scores.mean() - scores = scores / max(scores.std(), 1e-8) - item["scores"] = scores - if item["overrides"] is not None: - for i in range(len(item["overrides"])): - if item["overrides"][i].get("set_advantage_to_zero", False): - item["scores"][i] = 0 - for i in range(len(item["tokens"])): - lengths.append( - math.ceil((len(item["tokens"][i]) - 1) / (good_multiple)) - * good_multiple - ) - label_item = np.concatenate( - [ - np.array(item["masks"][i]), - np.full( - max(0, token_setup_len - len(item["tokens"][i])), - -100, - dtype=np.int32, - ), - ] - ) - item["tokens"][i] = np.concatenate( - [ - np.array(item["tokens"][i]), - np.zeros( - max(0, token_setup_len - len(item["tokens"][i])), dtype=np.int32 - ), - ] - ) - input_ids.append(item["tokens"][i][:-1]) - labels.append(label_item[1:]) - advantages.append(item["scores"][i]) - # per-sample override -> group generation_params -> group_overrides - > 1.0 - # need to update docs since this lets you set the temperature for each sample from the override - t = 1.0 - if ( - item.get("overrides") - and i < len(item["overrides"]) - and isinstance(item["overrides"][i], dict) - and ("temperature" in item["overrides"][i]) - ): - t = float(item["overrides"][i]["temperature"]) - elif item.get("generation_params") and ( - "temperature" in item["generation_params"] - ): - t = float(item["generation_params"]["temperature"]) - elif item.get("group_overrides") and ( - "temperature" in item["group_overrides"] - ): - t = float(item["group_overrides"]["temperature"]) - temperatures.append(t) - # combine all lists into tensors - token_batches = [] - label_batches = [] - advantage_batches = [] - temperature_batches = [] - for i in range(len(input_ids) // batch_size): - token_batches.append( - torch.tensor( - np.stack(input_ids[i * batch_size : (i + 1) * batch_size], axis=0) - ) - ) - label_batches.append( - torch.tensor( - np.stack(labels[i * batch_size : (i + 1) * batch_size], axis=0) - ) - ) - advantage_batches.append( - torch.tensor( - np.stack(advantages[i * batch_size : (i + 1) * batch_size], axis=0) - ).view(-1, 1) - ) - # Temperatures: one per sample, shaped for broadcasting to [B, 1, 1] - temperature_batches.append( - torch.tensor( - np.array( - temperatures[i * batch_size : (i + 1) * batch_size], - dtype=np.float32, - ) - ).view(-1, 1, 1) - ) - - return token_batches, label_batches, advantage_batches, temperature_batches - - -def get_data( - batch_size: int, seq_len: int, atropos_url: str = "http://localhost:8000" -) -> List[ - Tuple[ - List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor] - ] -]: - """ - getting data from the api - """ - batches = [] - while True: - data = get_batch(url=atropos_url) - if data["batch"] is not None: - # Save the batch - with open("temp.json", "w", encoding="utf-8") as f: - json.dump(data, f) - # In case the inference runs ahead of the training, we loop until we don't have any more data - batches.append(pad_data_to_good_offset(data, batch_size)) - elif len(batches) > 0: - # Return the batches - return batches - else: - time.sleep(1) - - -# ============================================================================= -# Common Training Helpers (shared across all modes) -# ============================================================================= - - -def setup_wandb(config: TrainingConfig) -> bool: - """ - Initialize Weights & Biases logging if enabled. 
- - Args: - config: Training configuration - - Returns: - True if wandb is active, False otherwise - """ - if not config.use_wandb: - return False - - if not config.wandb_project: - print("Warning: wandb_project not set, disabling wandb.") - return False - - # Generate random group name if not provided - if not config.wandb_group: - config.wandb_group = "".join( - random.choices(string.ascii_letters + string.digits, k=8) - ) - - try: - wandb.init( - project=config.wandb_project, - group=config.wandb_group, - config=config.dict(), - ) - print( - f"Wandb logging enabled. Run: {wandb.run.name} " - f"(Project: {config.wandb_project})" - ) - return True - except Exception as e: - print(f"Error initializing wandb: {e}. Disabling wandb.") - return False - - -def _attach_to_vllm_shared_tensors( - config: TrainingConfig, - bridge_config_path: str, -) -> Optional[torch.nn.Module]: - """ - Attach to vLLM's shared tensors via CUDA IPC (true single-copy mode). - - This creates a model whose parameters point to the SAME GPU memory as vLLM, - meaning only ONE copy of the model exists in GPU memory. - - Args: - config: Training configuration - bridge_config_path: Path to vllm_bridge_config.json - - Returns: - Model with parameters pointing to vLLM's tensors, or None if not possible - """ - print(f"[Setup] Reading bridge config from: {bridge_config_path}") - try: - with open(bridge_config_path, "r") as f: - bridge_config = json.load(f) - print(f"[Setup] Bridge config keys: {list(bridge_config.keys())}") - except Exception as e: - print(f"[Setup] Could not read bridge config: {e}") - return None - - single_copy_enabled = bridge_config.get("single_copy_enabled", False) - print(f"[Setup] single_copy_enabled in config: {single_copy_enabled}") - - if not single_copy_enabled: - print("[Setup] Single-copy mode not available (single_copy_enabled=False)") - print("[Setup] Make sure vLLM was started with VLLM_ENABLE_SHARED_WEIGHTS=1") - return None - - ipc_handles_raw = bridge_config.get("ipc_handles", {}) - print(f"[Setup] IPC handles count: {len(ipc_handles_raw)}") - if not ipc_handles_raw: - print("[Setup] No IPC handles found in bridge config") - return None - - # Deserialize base64-encoded bytes back to bytes - import base64 - - def deserialize_ipc_handles(handles): - result = {} - for k, v in handles.items(): - if isinstance(v, dict): - if "_bytes_b64_" in v: - result[k] = base64.b64decode(v["_bytes_b64_"]) - else: - result[k] = deserialize_ipc_handles(v) - else: - result[k] = v - return result - - ipc_handles = deserialize_ipc_handles(ipc_handles_raw) - - print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...") - print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!") - - # Load model config (not weights) to get architecture - from transformers import AutoConfig - - model_config = AutoConfig.from_pretrained(config.model_name) - - # Create empty model on meta device (no memory allocation) - with torch.device("meta"): - model = AutoModelForCausalLM.from_config( - model_config, - torch_dtype=torch.bfloat16, - ) - - # Get parameter names from the empty model - param_names = list(model.state_dict().keys()) - print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True) - - # Initialize CUDA before IPC operations - # Get the device indices we'll be using - device_indices = set() - for name, info in ipc_handles.items(): - if "device_index" in info: - device_indices.add(info["device_index"]) - - print(f"[Setup] IPC handles span devices: 
{sorted(device_indices)}", flush=True) - - # Initialize CUDA context on each device - for dev_idx in sorted(device_indices): - print(f"[Setup] Initializing CUDA on device {dev_idx}...", flush=True) - torch.cuda.set_device(dev_idx) - torch.cuda.synchronize(dev_idx) - print(f"[Setup] βœ“ Device {dev_idx} ready", flush=True) - - # Map vLLM tensor names to HuggingFace model parameter names - hf_state_dict = {} - vllm_to_hf_mapping = _create_vllm_to_hf_mapping( - model, ipc_handles, debug=config.debug_loading - ) - - # Cache for reconstructed vLLM tensors (to avoid reconstructing fused tensors multiple times) - vllm_tensor_cache: Dict[str, torch.Tensor] = {} - - def reconstruct_vllm_tensor(vllm_name: str) -> Optional[torch.Tensor]: - """Reconstruct a vLLM tensor from IPC handle, with caching.""" - if vllm_name in vllm_tensor_cache: - return vllm_tensor_cache[vllm_name] - - if vllm_name not in ipc_handles: - return None - - ipc_info = ipc_handles[vllm_name] - - if "ipc_handle_b64" not in ipc_info: - return None - - try: - # Decode all the bytes fields from base64 - device_index = ipc_info["device_index"] - ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"]) - storage_size = ipc_info["storage_size"] - storage_offset_orig = ipc_info["storage_offset_orig"] - ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"]) - ref_counter_offset = ipc_info["ref_counter_offset"] - event_handle = base64.b64decode(ipc_info["event_handle_b64"]) - event_sync_required = ipc_info["event_sync_required"] - - # Reconstruct the 8-tuple that _new_shared_cuda expects - share_tuple = ( - device_index, - ipc_handle, - storage_size, - storage_offset_orig, - ref_counter_handle, - ref_counter_offset, - event_handle, - event_sync_required, - ) - - # Create storage from IPC handle - storage = torch.UntypedStorage._new_shared_cuda(*share_tuple) - - # Reconstruct tensor - dtype = getattr(torch, ipc_info["dtype"].replace("torch.", "")) - tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}") - tensor.set_( - storage, - storage_offset=ipc_info["tensor_storage_offset"], - size=ipc_info["shape"], - stride=ipc_info["stride"], - ) - - vllm_tensor_cache[vllm_name] = tensor - return tensor - - except Exception as e: - print(f"[Setup] Failed to reconstruct {vllm_name}: {e}", flush=True) - return None - - attached_count = 0 - fused_count = 0 - - for hf_name, mapping_info in vllm_to_hf_mapping.items(): - try: - # Check if this is a fused mapping or direct mapping - if isinstance(mapping_info, dict): - # Fused mapping - need to slice the source tensor - vllm_name = mapping_info["source"] - slice_start, slice_end = mapping_info["slice"] - slice_dim = mapping_info["dim"] - - full_tensor = reconstruct_vllm_tensor(vllm_name) - if full_tensor is None: - if config.debug_loading: - print(f"[Setup] Could not get source tensor for {hf_name}") - continue - - # Create a VIEW (not copy!) 
into the fused tensor - # This maintains shared memory - gradients flow back to vLLM's tensor - if slice_dim == 0: - tensor = full_tensor[slice_start:slice_end] - else: - # For other dimensions (rare, but handle it) - tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start) - - # Verify it's a view, not a copy - if tensor.storage().data_ptr() != full_tensor.storage().data_ptr(): - print(f"[Setup] WARNING: {hf_name} is a COPY, not a view!") - - if attached_count == 0 and config.debug_loading: - print(f"[Setup DEBUG] Fused tensor slice: {hf_name}") - print(f"[Setup DEBUG] Source: {vllm_name} shape={full_tensor.shape}") - print(f"[Setup DEBUG] Slice: [{slice_start}:{slice_end}] -> {tensor.shape}") - - tensor.requires_grad_(True) - hf_state_dict[hf_name] = tensor - fused_count += 1 - attached_count += 1 - - else: - # Direct mapping - reconstruct tensor directly - vllm_name = mapping_info - - tensor = reconstruct_vllm_tensor(vllm_name) - if tensor is None: - continue - - if attached_count == 0 and config.debug_loading: - print(f"[Setup DEBUG] Attempting first tensor: {hf_name}", flush=True) - ipc_info = ipc_handles[vllm_name] - print(f"[Setup DEBUG] device_index: {ipc_info['device_index']}", flush=True) - print(f"[Setup DEBUG] storage_size: {ipc_info['storage_size']}", flush=True) - print(f"[Setup DEBUG] shape: {ipc_info['shape']}", flush=True) - - tensor.requires_grad_(True) - hf_state_dict[hf_name] = tensor - attached_count += 1 - - if attached_count == 1 and config.debug_loading: - print("[Setup DEBUG] βœ“ First tensor attached successfully!", flush=True) - - except Exception as e: - print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True) - import traceback - traceback.print_exc() - - print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)") - - if attached_count == 0: - print("[Setup] Could not attach any tensors, falling back to regular loading") - return None - - # ========================================================================= - # EARLY VALIDATION: Check that we mapped a reasonable number of parameters - # This catches obvious mapping failures before we try to load - # ========================================================================= - hf_param_count = len(list(model.named_parameters())) - # Note: attached_count may include fused tensors that map to multiple HF params - # So coverage can exceed 100% - that's OK - mapping_coverage = attached_count / hf_param_count if hf_param_count > 0 else 0 - - print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters") - - # Expect at least 90% coverage for a valid mapping - # Note: with fused tensors, we may have MORE mappings than params - # So we check if we have at least 90% of params covered - if mapping_coverage < 0.90: - unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys()) - warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n" - warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n" - for name in list(unmapped_params)[:20]: - warning_msg += f" - {name}\n" - print(warning_msg) - - if mapping_coverage < 0.50: - raise RuntimeError( - f"[Setup] CRITICAL: Only {mapping_coverage:.1%} of parameters mapped!\n" - "This indicates a serious mapping failure. Check:\n" - " 1. vLLM and HuggingFace use the same model architecture\n" - " 2. tensor-parallel-size=1 for single-copy mode\n" - " 3. 
vllm_bridge_config.json contains valid ipc_handles" - ) - else: - print(f"[Setup] βœ“ Good mapping coverage ({mapping_coverage:.1%})") - - print(f"[Setup] βœ“ Attached {attached_count} tensors to vLLM's shared memory") - - # Load state dict into model - # NOTE: We use strict=False because some buffers (like inv_freq) won't be in vLLM - # but we VALIDATE after loading to ensure nothing critical is left on meta - model.load_state_dict(hf_state_dict, strict=False, assign=True) - - # Initialize any remaining meta tensors (buffers like rotary embeddings) - # These are not in vLLM's state_dict but need to be initialized - device = f"cuda:{list(device_indices)[0]}" if device_indices else "cuda:0" - - # ========================================================================= - # DIAGNOSTIC: Count what's on meta vs cuda after load_state_dict - # ========================================================================= - meta_params = [] - cuda_params = [] - for name, param in model.named_parameters(): - if param.device.type == "meta": - meta_params.append(name) - elif param.device.type == "cuda": - cuda_params.append(name) - - meta_buffers = [] - cuda_buffers = [] - for name, buffer in model.named_buffers(): - if buffer.device.type == "meta": - meta_buffers.append(name) - elif buffer.device.type == "cuda": - cuda_buffers.append(name) - - if config.debug_loading: - print("\n[DIAGNOSTIC] After load_state_dict:") - print(f" - Parameters on CUDA: {len(cuda_params)}") - print(f" - Parameters on META: {len(meta_params)}") - print(f" - Buffers on CUDA: {len(cuda_buffers)}") - print(f" - Buffers on META: {len(meta_buffers)}") - - if meta_params: - print("\n[DIAGNOSTIC] First 10 META parameters:") - for name in meta_params[:10]: - param = dict(model.named_parameters())[name] - print( - f" {name}: shape={param.shape}, dtype={param.dtype}, device={param.device}" - ) - - if meta_buffers: - print("\n[DIAGNOSTIC] META buffers:") - for name in meta_buffers[:10]: - buffer = dict(model.named_buffers())[name] - print( - f" {name}: shape={buffer.shape}, dtype={buffer.dtype}, device={buffer.device}" - ) - - # ========================================================================= - # Helper function to navigate module hierarchy - # ========================================================================= - def get_parent_and_name(model, full_name): - """Get parent module and attribute name from full parameter name.""" - parts = full_name.split(".") - parent = model - for part in parts[:-1]: - parent = getattr(parent, part) - return parent, parts[-1] - - # ========================================================================= - # Initialize remaining meta parameters - # NOTE: Can't use param.data = ... on meta tensors! 
- # Must use setattr() to replace the entire parameter in the parent module - # ========================================================================= - meta_count = 0 - - for name in meta_params: - param = dict(model.named_parameters()).get(name) - if param is None: - continue - - try: - if config.debug_loading: - print(f"[DIAGNOSTIC] Initializing meta param: {name}") - print( - f" - Old: device={param.device}, dtype={param.dtype}, shape={param.shape}" - ) - - # Create new parameter with actual data on CUDA - new_data = torch.zeros(param.shape, dtype=param.dtype, device=device) - new_param = torch.nn.Parameter(new_data, requires_grad=param.requires_grad) - - if config.debug_loading: - print( - f" - New: device={new_param.device}, dtype={new_param.dtype}, shape={new_param.shape}" - ) - - # Replace in parent module using setattr (NOT param.data = ...) - parent, attr_name = get_parent_and_name(model, name) - if config.debug_loading: - print(f" - Parent module: {type(parent).__name__}, attr: {attr_name}") - - setattr(parent, attr_name, new_param) - meta_count += 1 - if config.debug_loading: - print(" - βœ“ Replaced successfully!") - - except Exception as e: - if config.debug_loading: - print(f"[DIAGNOSTIC] FAILED to initialize {name}: {e}") - import traceback - - traceback.print_exc() - - # ========================================================================= - # Initialize remaining meta buffers - # ========================================================================= - for name in meta_buffers: - buffer = dict(model.named_buffers()).get(name) - if buffer is None: - continue - - try: - if config.debug_loading: - print(f"[DIAGNOSTIC] Initializing meta buffer: {name}") - print( - f" - Old: device={buffer.device}, dtype={buffer.dtype}, shape={buffer.shape}" - ) - - # For buffers like inv_freq, we need proper initialization - if "inv_freq" in name: - # Rotary embedding inverse frequencies - dim = buffer.shape[0] * 2 # inv_freq has shape [dim/2] - base = 10000.0 # Default RoPE base - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) - ) - new_buffer = inv_freq.to(dtype=buffer.dtype, device=device) - if config.debug_loading: - print(f" - Computed inv_freq with dim={dim}, base={base}") - else: - # Other buffers - initialize with zeros - new_buffer = torch.zeros( - buffer.shape, dtype=buffer.dtype, device=device - ) - - if config.debug_loading: - print( - f" - New: device={new_buffer.device}, dtype={new_buffer.dtype}, shape={new_buffer.shape}" - ) - - # Replace in parent module - parent, attr_name = get_parent_and_name(model, name) - if config.debug_loading: - print(f" - Parent module: {type(parent).__name__}, attr: {attr_name}") - - parent.register_buffer(attr_name, new_buffer) - meta_count += 1 - if config.debug_loading: - print(" - βœ“ Replaced successfully!") - - except Exception as e: - if config.debug_loading: - print(f"[DIAGNOSTIC] FAILED to initialize buffer {name}: {e}") - import traceback - - traceback.print_exc() - - print(f"\n[Setup] Initialized {meta_count} remaining meta tensors") - - # ========================================================================= - # CRITICAL VALIDATION: Ensure no parameters/buffers are still on meta device - # This catches mapping bugs that would otherwise cause garbage output - # ========================================================================= - final_meta_params = [] - final_meta_buffers = [] - - for name, param in model.named_parameters(): - if param.device.type == "meta": - 
final_meta_params.append(name) - - for name, buffer in model.named_buffers(): - if buffer.device.type == "meta": - final_meta_buffers.append(name) - - if final_meta_params or final_meta_buffers: - error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n" - error_msg += "This means they were NOT properly mapped from vLLM or initialized.\n" - error_msg += "The model would produce GARBAGE output.\n\n" - - if final_meta_params: - error_msg += f"Meta parameters ({len(final_meta_params)}):\n" - for name in final_meta_params[:20]: - error_msg += f" - {name}\n" - if len(final_meta_params) > 20: - error_msg += f" ... and {len(final_meta_params) - 20} more\n" - - if final_meta_buffers: - error_msg += f"\nMeta buffers ({len(final_meta_buffers)}):\n" - for name in final_meta_buffers[:20]: - error_msg += f" - {name}\n" - if len(final_meta_buffers) > 20: - error_msg += f" ... and {len(final_meta_buffers) - 20} more\n" - - error_msg += "\nPossible causes:\n" - error_msg += " 1. vLLM parameter names don't match HuggingFace names\n" - error_msg += " 2. QKV/Gate-Up fusion mapping failed\n" - error_msg += " 3. vLLM running with tensor-parallel-size > 1 (not supported)\n" - - raise RuntimeError(error_msg) - - print("[Setup] βœ“ All tensors successfully initialized on CUDA") - - return model - - -def _create_vllm_to_hf_mapping( - model: torch.nn.Module, ipc_handles: dict, debug: bool = False -) -> dict: - """ - Create mapping from HuggingFace parameter names to vLLM tensor names. - - vLLM uses different naming conventions and fuses certain layers: - - qkv_proj (vLLM) = q_proj + k_proj + v_proj (HF) - - gate_up_proj (vLLM) = gate_proj + up_proj (HF) - - Returns a dict where: - - Simple mappings: {"hf_name": "vllm_name"} - - Fused mappings: {"hf_name": {"source": "vllm_name", "slice": (start, end), "dim": 0}} - """ - hf_params = set(model.state_dict().keys()) - vllm_params = set(ipc_handles.keys()) - - # Get model config for dimension calculations - model_config = model.config - hidden_size = getattr(model_config, "hidden_size", 4096) - num_attention_heads = getattr(model_config, "num_attention_heads", 32) - num_key_value_heads = getattr( - model_config, "num_key_value_heads", num_attention_heads - ) - intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4) - head_dim = hidden_size // num_attention_heads - - # Calculate sizes for QKV split - q_size = hidden_size # num_heads * head_dim - k_size = num_key_value_heads * head_dim - v_size = num_key_value_heads * head_dim - - if debug: - print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, " - f"kv_heads={num_key_value_heads}, intermediate={intermediate_size}") - print(f"[Mapping] QKV sizes: q={q_size}, k={k_size}, v={v_size}") - - mapping = {} - - def find_vllm_name(hf_name: str) -> Optional[str]: - """Try to find the corresponding vLLM parameter name.""" - # Direct match - if hf_name in vllm_params: - return hf_name - - # Add 'model.' prefix - if not hf_name.startswith("model."): - candidate = f"model.{hf_name}" - if candidate in vllm_params: - return candidate - - # Remove 'model.' 
prefix - if hf_name.startswith("model."): - candidate = hf_name[6:] - if candidate in vllm_params: - return candidate - - return None - - def find_fused_source(hf_name: str, fused_suffix: str) -> Optional[str]: - """Try to find the fused layer that contains this parameter.""" - # e.g., "model.layers.0.self_attn.q_proj.weight" -> "model.layers.0.self_attn.qkv_proj.weight" - for unfused in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]: - if unfused in hf_name: - fused_name = hf_name.replace(unfused, fused_suffix) - found = find_vllm_name(fused_name) - if found: - return found - return None - - for hf_name in hf_params: - # Try direct match first - vllm_name = find_vllm_name(hf_name) - if vllm_name: - mapping[hf_name] = vllm_name - continue - - # Check for QKV fusion: q_proj, k_proj, v_proj -> qkv_proj - if any(x in hf_name for x in ["q_proj", "k_proj", "v_proj"]): - fused_name = find_fused_source(hf_name, "qkv_proj") - if fused_name: - # Determine which part of the fused tensor this is - if "q_proj" in hf_name: - start, end = 0, q_size - elif "k_proj" in hf_name: - start, end = q_size, q_size + k_size - else: # v_proj - start, end = q_size + k_size, q_size + k_size + v_size - - mapping[hf_name] = { - "source": fused_name, - "slice": (start, end), - "dim": 0, # Split along output dimension - "type": "qkv_fusion", - } - if debug: - print(f"[Mapping] QKV fusion: {hf_name} -> {fused_name}[{start}:{end}]") - continue - - # Check for Gate/Up fusion: gate_proj, up_proj -> gate_up_proj - if any(x in hf_name for x in ["gate_proj", "up_proj"]): - fused_name = find_fused_source(hf_name, "gate_up_proj") - if fused_name: - # Determine which part of the fused tensor this is - if "gate_proj" in hf_name: - start, end = 0, intermediate_size - else: # up_proj - start, end = intermediate_size, intermediate_size * 2 - - mapping[hf_name] = { - "source": fused_name, - "slice": (start, end), - "dim": 0, # Split along output dimension - "type": "gate_up_fusion", - } - if debug: - print(f"[Mapping] Gate/Up fusion: {hf_name} -> {fused_name}[{start}:{end}]") - continue - - # No mapping found - this parameter will need to be handled specially - if debug and "inv_freq" not in hf_name: # inv_freq is expected to be missing - print(f"[Mapping] No mapping for: {hf_name}") - - if debug: - direct = sum(1 for v in mapping.values() if isinstance(v, str)) - fused = sum(1 for v in mapping.values() if isinstance(v, dict)) - print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)") - print(f"[Mapping] Unmapped: {len(hf_params) - len(mapping)}") - - return mapping - - -def load_model_and_tokenizer( - config: TrainingConfig, - single_copy: bool = False, -) -> Tuple[torch.nn.Module, "AutoTokenizer"]: - """ - Load or attach to model based on weight_bridge_mode. 
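-
-    Dispatches on config.weight_bridge_mode: "shared_vllm" (or single_copy=True)
-    attaches to vLLM's live tensors over CUDA IPC, "lora_only" wraps the base
-    model with PEFT adapters, and anything else loads a plain bf16 copy for
-    legacy checkpoint-and-restart training.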
- - Args: - config: Training configuration - single_copy: If True, attach to vLLM's shared tensors via CUDA IPC - - Returns: - Tuple of (model, tokenizer) - """ - tokenizer = AutoTokenizer.from_pretrained(config.model_name) - - # Single-copy mode: attach to vLLM's shared tensors via CUDA IPC - if single_copy or config.weight_bridge_mode == "shared_vllm": - # Check for explicit path first - if config.vllm_config_path and os.path.exists(config.vllm_config_path): - config_path = config.vllm_config_path - print(f"[Setup] Using explicit vLLM config path: {config_path}") - else: - # Auto-detect from common locations - possible_paths = [ - os.environ.get("LOGDIR", "."), - ".", - "/tmp/atropos_bridge", - os.path.dirname(os.path.abspath(__file__)), - ] - - config_path = None - for log_dir in possible_paths: - candidate = os.path.join(log_dir, "vllm_bridge_config.json") - if os.path.exists(candidate): - config_path = candidate - print(f"[Setup] Found vLLM config at: {candidate}") - break - - if config_path is None: - checked = [ - os.path.join(p, "vllm_bridge_config.json") for p in possible_paths - ] - raise RuntimeError( - f"[Setup] Could not find vllm_bridge_config.json\n" - f"Checked: {checked}\n" - f"Tip: Use --vllm-config-path to specify the path explicitly\n" - f"Make sure vLLM is running with VLLM_ENABLE_SHARED_WEIGHTS=1 and LOGDIR set" - ) - - model = _attach_to_vllm_shared_tensors(config, config_path) - if model is not None: - print("[Setup] βœ“ Single-copy mode active - using vLLM's tensors directly!") - model.train() - return model, tokenizer - else: - raise RuntimeError( - "[Setup] Single-copy mode FAILED to attach to vLLM's tensors.\n" - "Check:\n" - " 1. vLLM running with VLLM_ENABLE_SHARED_WEIGHTS=1\n" - " 2. vllm_bridge_config.json exists with ipc_handles\n" - " 3. Trainer is on SAME GPUs as vLLM" - ) + if config.weight_bridge_mode == "shared_vllm": + # Single-copy mode: attach to vLLM's weights, update in-place + train_shared_vllm(config) elif config.weight_bridge_mode == "lora_only": - model = _load_model_with_lora(config) - - else: - print("[Setup] Loading model for legacy mode...") - model = AutoModelForCausalLM.from_pretrained( - config.model_name, torch_dtype=torch.bfloat16 - ) - model.to(config.device) - - # Enable gradient checkpointing (saves memory) - # For LoRA, use PEFT's method; for others, use standard method - # Disable KV cache - incompatible with gradient checkpointing - # Setting explicitly avoids the warning message - model.config.use_cache = False - - if config.weight_bridge_mode == "lora_only": - # PEFT models need gradient_checkpointing enabled on base model - # and require use_reentrant=False for proper gradient flow - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - model.gradient_checkpointing_enable( - gradient_checkpointing_kwargs={"use_reentrant": False} - ) - else: - # Standard gradient checkpointing - model.gradient_checkpointing_enable() - - model.train() - - return model, tokenizer - - -def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module: - """ - Load base model and wrap with LoRA adapters. - - Args: - config: Training configuration with LoRA settings - - Returns: - PEFT model with LoRA adapters applied - """ - if not PEFT_AVAILABLE: - raise RuntimeError("PEFT library not available. 
Install with: pip install peft") - - print("[Setup] Loading base model for LoRA mode...") - base_model = AutoModelForCausalLM.from_pretrained( - config.model_name, torch_dtype=torch.bfloat16 - ) - base_model.to(config.device) - - # Determine target modules - target_modules = config.lora_target_modules - if target_modules is None: - # Default modules for most transformer models - target_modules = ["q_proj", "v_proj"] - - print(f"Applying LoRA: r={config.lora_r}, alpha={config.lora_alpha}") - print(f"Target modules: {target_modules}") - - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - r=config.lora_r, - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, - target_modules=target_modules, - bias="none", - ) - - model = get_peft_model(base_model, lora_config) - model.print_trainable_parameters() - - return model - - -def save_lora_checkpoint( - model: torch.nn.Module, - save_path: str, - step: int, - is_final: bool = False, -) -> str: - """ - Save LoRA adapter checkpoint. - - Args: - model: PEFT model with LoRA adapters - save_path: Base directory for checkpoints - step: Current training step - is_final: Whether this is the final checkpoint - - Returns: - Path where adapter was saved - """ - if is_final: - adapter_path = os.path.join(save_path, "final_adapter") - else: - adapter_path = os.path.join(save_path, f"adapter_step_{step}") - - print(f" Saving LoRA adapter to {adapter_path}...") - - if os.path.exists(adapter_path): - shutil.rmtree(adapter_path) - os.makedirs(adapter_path, exist_ok=True) - - # Save only the adapter weights (much smaller than full model) - model.save_pretrained(adapter_path) - - print(" Adapter saved.") - return adapter_path - - -def compute_grpo_loss( - model: torch.nn.Module, - tokens: torch.Tensor, - labels: torch.Tensor, - advantages: torch.Tensor, - temperatures: torch.Tensor, - gradient_accumulation_steps: int, -) -> Tuple[torch.Tensor, dict]: - """ - Compute GRPO loss for a single micro-batch. 
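-
-    Note on the surrogate used below: exp(logp - logp.detach()) is identically
-    1 in the forward pass but contributes d(logp)/d(theta) in the backward
-    pass, so the loss gradient is the advantage-weighted policy gradient.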
- - Args: - model: The model to compute loss for - tokens: Input token IDs [batch, seq_len] - labels: Target labels [batch, seq_len] - advantages: Advantage values [batch, 1] - temperatures: Temperature values [batch, 1, 1] - gradient_accumulation_steps: Number of accumulation steps - - Returns: - Tuple of (loss tensor, metrics dict) - """ - # Forward pass - outputs = model(tokens) - logits = outputs.logits - - # Temperature scaling - t = temperatures.to(logits.device, logits.dtype) - t = torch.where(t <= 0, torch.ones_like(t), t) - logits = logits / t - - # Log probabilities per token - logp_per_token = -F.cross_entropy( - logits.view(-1, logits.size(-1)), - labels.view(-1), - reduction="none", - ignore_index=-100, - ).view(labels.shape) - - # Masking based on labels != -100 - mask = (labels != -100).float() - - # Compute metrics (no grad needed) - with torch.no_grad(): - pos = (advantages > 0).float() - neg = (advantages <= 0).float() - mask_float = mask.to(logp_per_token.dtype) - mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8) - avg_logp = (logp_per_token * mask_float).sum(dim=-1) / mask_sum - pos_logp = (logp_per_token * pos).mean().item() - neg_logp = (logp_per_token * neg).mean().item() - - # GRPO loss - grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach()) - grpo_loss = ( - ((-grpo_loss_term * mask).sum(-1) / mask.sum(-1)) - * advantages.to(logp_per_token.device) - ).mean() / gradient_accumulation_steps - - metrics = { - "pos_logp": pos_logp, - "neg_logp": neg_logp, - "avg_logp": avg_logp, - "pos_count": pos.sum().item(), - "neg_count": neg.sum().item(), - } - - return grpo_loss, metrics - - -def run_training_step( - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - token_batches: List[torch.Tensor], - label_batches: List[torch.Tensor], - advantage_batches: List[torch.Tensor], - temperature_batches: List[torch.Tensor], - config: TrainingConfig, -) -> dict: - """ - Run a single training step (forward, backward, optimizer step). 
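-
-    Each micro-batch loss arrives pre-divided by gradient_accumulation_steps
-    (see compute_grpo_loss), so gradients are summed across micro-batches,
-    clipped to max-norm 1.0, and applied in a single optimizer step.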
- - Args: - model: The model to train - optimizer: The optimizer - token_batches: List of token tensors - label_batches: List of label tensors - advantage_batches: List of advantage tensors - temperature_batches: List of temperature tensors - config: Training configuration - - Returns: - Dict of training metrics for this step - """ - total_loss = 0.0 - total_pos_logp = 0.0 - total_neg_logp = 0.0 - total_pos = 0.0 - total_neg = 0.0 - - # Accumulate gradients over micro-batches - for tokens, labels, advantages, temperatures in zip( - token_batches, label_batches, advantage_batches, temperature_batches - ): - tokens = tokens.to(config.device) - labels = labels.to(config.device) - advantages = advantages.to(config.device) - - loss, metrics = compute_grpo_loss( - model, - tokens, - labels, - advantages, - temperatures, - config.gradient_accumulation_steps, - ) - - loss.backward() - total_loss += loss.item() - total_pos_logp += metrics["pos_logp"] - total_neg_logp += metrics["neg_logp"] - total_pos += metrics["pos_count"] - total_neg += metrics["neg_count"] - - # Gradient clipping and optimizer step - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - optimizer.step() - optimizer.zero_grad() - - # Normalize metrics - if total_pos > 0: - total_pos_logp /= total_pos - if total_neg > 0: - total_neg_logp /= total_neg - - return { - "loss": total_loss, - "grad_norm": grad_norm.item(), - "pos_logp": total_pos_logp, - "neg_logp": total_neg_logp, - } - - -def save_checkpoint( - model: torch.nn.Module, - tokenizer: "AutoTokenizer", - save_path: str, - step: int, - is_final: bool = False, -) -> str: - """ - Save model checkpoint. - - Args: - model: Model to save - tokenizer: Tokenizer to save - save_path: Base directory for checkpoints - step: Current training step - is_final: Whether this is the final checkpoint - - Returns: - Path where checkpoint was saved - """ - if is_final: - checkpoint_path = os.path.join(save_path, "final_model") - else: - checkpoint_path = os.path.join(save_path, f"step_{step}") - - print(f" Saving checkpoint to {checkpoint_path}...") - - if os.path.exists(checkpoint_path): - shutil.rmtree(checkpoint_path) - os.makedirs(checkpoint_path, exist_ok=True) - - model.save_pretrained(checkpoint_path) - tokenizer.save_pretrained(checkpoint_path) - - print(" Checkpoint saved.") - return checkpoint_path - - -def log_metrics( - metrics: dict, - step: int, - use_wandb: bool, - extra_metrics: Optional[dict] = None, - benchmark: bool = False, -) -> None: - """ - Log training metrics to console and optionally wandb. 
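-
-    Example call (hypothetical values):
-        log_metrics({"loss": 0.0012, "grad_norm": 0.83, "pos_logp": -1.2,
-                     "neg_logp": -2.4, "step_time": 3.1}, step=5, use_wandb=False)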
- - Args: - metrics: Dict of metrics from training step - step: Current step number - use_wandb: Whether to log to wandb - extra_metrics: Optional additional metrics to log - benchmark: Whether to show timing/benchmark info - """ - # Console output with timing info (only if benchmark enabled) - timing_str = "" - if benchmark: - if "step_time" in metrics: - timing_str += f", Step time: {metrics['step_time']:.2f}s" - if "sync_time" in metrics and metrics["sync_time"] > 0: - timing_str += f", Sync time: {metrics['sync_time']:.2f}s" - if "data_fetch_time" in metrics: - timing_str += f", Data fetch: {metrics['data_fetch_time']:.2f}s" - if "gpu_memory_gb" in metrics: - timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB" - - # Show loss with more precision since GRPO loss is often very small - loss_str = ( - f"{metrics['loss']:.6f}" - if abs(metrics["loss"]) < 0.01 - else f"{metrics['loss']:.4f}" - ) - print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}") - - # Show GRPO-specific metrics if available - if "pos_count" in metrics or "neg_count" in metrics: - pos_count = metrics.get("pos_count", 0) - neg_count = metrics.get("neg_count", 0) - pos_logp = metrics.get("pos_logp", 0) - neg_logp = metrics.get("neg_logp", 0) - print( - f" Advantages: +{int(pos_count)} / -{int(neg_count)}, LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}" - ) - - if use_wandb: - log_dict = { - "train/loss": metrics["loss"], - "train/grad_norm": metrics["grad_norm"], - "train/pos_logp": metrics["pos_logp"], - "train/neg_logp": metrics["neg_logp"], - } - # Add timing metrics if present - if "step_time" in metrics: - log_dict["train/step_time"] = metrics["step_time"] - if "sync_time" in metrics: - log_dict["train/sync_time"] = metrics["sync_time"] - if "data_fetch_time" in metrics: - log_dict["train/data_fetch_time"] = metrics["data_fetch_time"] - if "gpu_memory_gb" in metrics: - log_dict["train/gpu_memory_gb"] = metrics["gpu_memory_gb"] - if "gpu_memory_reserved_gb" in metrics: - log_dict["train/gpu_memory_reserved_gb"] = metrics["gpu_memory_reserved_gb"] - if extra_metrics: - log_dict.update(extra_metrics) - wandb.log(log_dict, step=step) - - -def finalize_training( - use_wandb: bool, - training_start_time: Optional[float] = None, - mode: str = "unknown", - total_steps: int = 0, - benchmark_stats: Optional[dict] = None, - benchmark: bool = False, -) -> None: - """Clean up after training and log benchmark summary. 
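-
-    With benchmark=True a console timing/memory summary is printed; with
-    use_wandb=True the same aggregates are written to wandb.summary before
-    wandb.finish() is called.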
- - Args: - use_wandb: Whether wandb is enabled - training_start_time: Start time of training - mode: Training mode name - total_steps: Total steps completed - benchmark_stats: Dict with lists of per-step metrics: - - step_times: List of step durations - - sync_times: List of sync durations - - data_fetch_times: List of data fetch durations - - gpu_memories: List of GPU memory readings (GB) - benchmark: Whether to print benchmark summary to console - """ - print("\nTraining finished.") - - # Default empty stats - if benchmark_stats is None: - benchmark_stats = {} - - # Log benchmark summary - if training_start_time is not None: - total_time = time.time() - training_start_time - peak_gpu_mem_gb = ( - torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0 - ) - - # Calculate averages from collected stats - step_times = benchmark_stats.get("step_times", []) - sync_times = benchmark_stats.get("sync_times", []) - data_fetch_times = benchmark_stats.get("data_fetch_times", []) - gpu_memories = benchmark_stats.get("gpu_memories", []) - - avg_step_time = sum(step_times) / len(step_times) if step_times else 0 - total_step_time = sum(step_times) - avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0 - total_sync_time = sum(sync_times) - avg_data_fetch = ( - sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0 - ) - total_data_fetch = sum(data_fetch_times) - avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0 - - # Print benchmark summary only if benchmark flag is enabled - if benchmark: - print(f"\n{'='*70}") - print(f"BENCHMARK SUMMARY ({mode})") - print(f"{'='*70}") - print( - f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)" - ) - print(f" Total steps: {total_steps}") - print(" ") - print(" TIMING BREAKDOWN:") - print(f" Avg step time: {avg_step_time:.2f}s") - print(f" Total step time: {total_step_time:.2f}s") - print( - f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)" - ) - print(f" Total sync time: {total_sync_time:.2f}s") - print(f" Avg data fetch time: {avg_data_fetch:.2f}s") - print(f" Total data fetch time: {total_data_fetch:.2f}s") - print(" ") - print(" MEMORY:") - print(f" Peak GPU memory: {peak_gpu_mem_gb:.2f} GB") - print(f" Avg GPU memory: {avg_gpu_mem:.2f} GB") - print(f"{'='*70}\n") - - if use_wandb: - # Total time metrics - wandb.summary["benchmark/total_time_seconds"] = total_time - wandb.summary["benchmark/total_time_minutes"] = total_time / 60 - wandb.summary["benchmark/mode"] = mode - wandb.summary["benchmark/total_steps"] = total_steps - - # Step timing metrics - wandb.summary["benchmark/avg_step_time_seconds"] = avg_step_time - wandb.summary["benchmark/total_step_time_seconds"] = total_step_time - - # Sync timing metrics - wandb.summary["benchmark/avg_sync_time_seconds"] = avg_sync_time - wandb.summary["benchmark/total_sync_time_seconds"] = total_sync_time - wandb.summary["benchmark/num_syncs"] = len(sync_times) - - # Data fetch timing metrics - wandb.summary["benchmark/avg_data_fetch_time_seconds"] = avg_data_fetch - wandb.summary["benchmark/total_data_fetch_time_seconds"] = total_data_fetch - - # Memory metrics - wandb.summary["benchmark/peak_gpu_memory_gb"] = peak_gpu_mem_gb - wandb.summary["benchmark/avg_gpu_memory_gb"] = avg_gpu_mem - - if use_wandb: - wandb.finish() - - -def train(config: TrainingConfig): - """ - Legacy GRPO training with periodic vLLM restarts. - - This mode saves checkpoints to disk and restarts vLLM to pick up new weights. 
- Use weight_bridge_mode='shared_vllm' for in-place weight updates without restarts. - """ - global vllm_process - training_start_time = time.time() - - # === Setup === - use_wandb = setup_wandb(config) - model, tokenizer = load_model_and_tokenizer(config) - optimizer = AdamW(model.parameters(), lr=config.lr) - - print(f"\n{'='*60}") - print("LEGACY MODE (checkpoint + vLLM restart)") - print(f"{'='*60}") - print(f"Training for {config.training_steps} steps on {config.device}") - print(f"vLLM restart interval: every {config.vllm_restart_interval} steps") - print(f"Save path: {config.save_path}") - print(f"{'='*60}\n") - - os.makedirs(config.save_path, exist_ok=True) - register_trainer(config) - - # Launch initial vLLM server - vllm_process = _launch_vllm_server(config, config.model_name) - - # === Benchmark tracking === - benchmark_stats = { - "step_times": [], - "sync_times": [], - "data_fetch_times": [], - "gpu_memories": [], - } - - # === Training Loop === - batches = [] - for step in range(config.training_steps): - print(f"\nStep {step+1}/{config.training_steps}") - - # Track data fetch time - data_fetch_start = time.time() - if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len, config.atropos_url) - token_batches, label_batches, advantage_batches, temperature_batches = ( - batches.pop(0) - ) - data_fetch_time = time.time() - data_fetch_start - benchmark_stats["data_fetch_times"].append(data_fetch_time) - - # Terminate vLLM before training step (to free GPU memory) - should_sync = ( - step + 1 - ) % config.vllm_restart_interval == 0 or step == config.training_steps - 1 - if should_sync: - _terminate_vllm_process() - - # Track step time - step_start = time.time() - - # Run training step using common helper - metrics = run_training_step( - model, - optimizer, - token_batches, - label_batches, - advantage_batches, - temperature_batches, - config, - ) - - step_time = time.time() - step_start - benchmark_stats["step_times"].append(step_time) - - # Track GPU memory - if torch.cuda.is_available(): - gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 - gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 - benchmark_stats["gpu_memories"].append(gpu_mem_gb) - else: - gpu_mem_gb = 0 - gpu_mem_reserved_gb = 0 - - # Track sync time - sync_time = 0 - if should_sync: - sync_start = time.time() - checkpoint_path = save_checkpoint( - model, tokenizer, config.save_path, step + 1 - ) - torch.cuda.empty_cache() - vllm_process = _launch_vllm_server(config, checkpoint_path) - sync_time = time.time() - sync_start - benchmark_stats["sync_times"].append(sync_time) - - # Add timing metrics - metrics["step_time"] = step_time - metrics["sync_time"] = sync_time - metrics["data_fetch_time"] = data_fetch_time - metrics["gpu_memory_gb"] = gpu_mem_gb - metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb - - # Log metrics - log_metrics( - metrics, - step + 1, - use_wandb, - { - "train/learning_rate": optimizer.param_groups[0]["lr"], - }, - benchmark=config.benchmark, - ) - - # Check for unexpected vLLM termination - _check_vllm_process_health() - - # === Cleanup === - save_checkpoint( - model, tokenizer, config.save_path, config.training_steps, is_final=True - ) - finalize_training( - use_wandb, - training_start_time, - "legacy", - config.training_steps, - benchmark_stats, - benchmark=config.benchmark, - ) - - -# ============================================================================= -# vLLM Process Management (Legacy Mode Only) -# 
=============================================================================
-
-
-def _launch_vllm_server(
-    config: TrainingConfig, model_path: str
-) -> Optional[subprocess.Popen]:
-    """Launch a vLLM server process using our custom vllm_api_server.py.
-
-    Uses the custom server instead of standard vLLM because:
-    - Standard vLLM only exposes /v1/completions (OpenAI-compatible)
-    - Our custom server provides the /generate endpoint needed by the VLLMServer class
-    - This enables proper tokens_and_logprobs_completion support
-    """
-    global vllm_process
-
-    # Use our custom vllm_api_server.py instead of standard vLLM
-    # This provides the /generate endpoint that VLLMServer needs
-    script_dir = os.path.dirname(os.path.abspath(__file__))
-    custom_server_path = os.path.join(script_dir, "vllm_api_server.py")
-
-    vllm_command = [
-        "python",
-        custom_server_path,
-        "--model",
-        model_path,
-        "--port",
-        str(config.vllm_port),
-        "--gpu-memory-utilization",
-        str(config.vllm_gpu_memory_utilization),
-    ]
-    # Add served-model-name if using a checkpoint path
-    if model_path != config.model_name:
-        vllm_command.extend(["--served-model-name", config.model_name])
-
-    print(f"  Launching vLLM: {' '.join(vllm_command)}")
-
-    try:
-        proc = subprocess.Popen(vllm_command)
-        print(f"  vLLM launched with PID: {proc.pid}")
-
-        # Check for immediate startup errors (a 2s timeout means it is still running)
-        try:
-            proc.communicate(timeout=2)
-            if proc.returncode is not None and proc.returncode != 0:
-                print("  WARNING: vLLM failed to start")
-                return None
-        except subprocess.TimeoutExpired:
-            print("  vLLM process started (check logs for details)")
-
-        return proc
-
-    except FileNotFoundError:
-        print("  ERROR: vLLM not found. Is it installed?")
-        return None
-    except Exception as e:
-        print(f"  ERROR launching vLLM: {e}")
-        return None
-
-
-def _terminate_vllm_process() -> None:
-    """Terminate the running vLLM process, if any."""
-    global vllm_process
-
-    if vllm_process is None:
-        return
-
-    print("  Terminating vLLM process...")
-    vllm_process.terminate()
-    try:
-        vllm_process.wait(timeout=5)
-    except subprocess.TimeoutExpired:
-        print("  vLLM did not terminate gracefully, killing...")
-        vllm_process.kill()
-        vllm_process.wait()
-    vllm_process = None
-
-
-def _check_vllm_process_health() -> None:
-    """Check whether the vLLM process terminated unexpectedly (legacy mode)."""
-    global vllm_process
-
-    if vllm_process is not None and vllm_process.poll() is not None:
-        print(
-            f"  WARNING: vLLM terminated unexpectedly (code: {vllm_process.returncode})"
-        )
-        vllm_process = None
-
-
-def train_shared_vllm(config: TrainingConfig):
-    """
-    GRPO training with shared vLLM weights.
-
-    Instead of saving checkpoints and restarting vLLM, this mode:
-    1. Locates vLLM's CUDA IPC handles via vllm_bridge_config.json
-    2. Attaches to vLLM's weight tensors directly via CUDA IPC
-    3. optimizer.step() modifies vLLM's weights in-place
-    4. vLLM immediately uses updated weights (no restart!)
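-
-    Because optimizer.step() writes into the same GPU storage vLLM reads from,
-    there is no separate weight-sync phase; sync_time is logged as 0.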
- - Requirements: - - vLLM running with VLLM_ENABLE_SHARED_WEIGHTS=1 - - Trainer on same GPU(s) as vLLM (for IPC to work) - """ - training_start_time = time.time() - - # === Setup === - use_wandb = setup_wandb(config) - - print(f"\n{'='*60}") - print("SINGLE-COPY MODE (CUDA IPC)") - print(">>> TRUE shared memory - only ONE model copy!") - print(">>> Trainer uses vLLM's tensors directly!") - print(f"{'='*60}") - print(f"Model: {config.model_name}") - print(f"Distributed: rank={config.trainer_rank}/{config.world_size}") - print(f"Init method: {config.init_method}") - print(f"Save path: {config.save_path}") - print(f"{'='*60}\n") - - # Single-copy mode: attach directly to vLLM's tensors via CUDA IPC - print("[1/2] Attaching to vLLM's shared tensors...") - model, tokenizer = load_model_and_tokenizer(config, single_copy=True) - - if model is None: - raise RuntimeError( - "Single-copy mode failed. Make sure:\n" - "1. vLLM is running with VLLM_ENABLE_SHARED_WEIGHTS=1\n" - "2. Trainer is on the SAME GPUs as vLLM\n" - "3. vllm_bridge_config.json exists with IPC handles" - ) - - optimizer = AdamW( - model.parameters(), lr=config.lr - ) # maybe we need to make this configurable in the future - - print(f"[2/2] Starting training for {config.training_steps} steps") - print("NOTE: vLLM sees weight updates immediately after each step!") - print("-" * 60) - - os.makedirs(config.save_path, exist_ok=True) - - # Check Atropos API and register BEFORE training loop - print(f"\n[Setup] Connecting to Atropos API at {config.atropos_url}...") - if not check_atropos_api(url=config.atropos_url, timeout=30): - raise RuntimeError( - f"Atropos API server not reachable at {config.atropos_url}. " - "Please start the environment server (e.g., gsm8k_server.py serve)" - ) - register_trainer(config) - - # === Benchmark tracking === - benchmark_stats = { - "step_times": [], - "sync_times": [], # For shared mode, this is the notify_update time - "data_fetch_times": [], - "gpu_memories": [], - } - - # === Training Loop === - batches = [] - for step in range(config.training_steps): - print(f"\nStep {step+1}/{config.training_steps}") - - # Track data fetch time - data_fetch_start = time.time() - if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len, config.atropos_url) - token_batches, label_batches, advantage_batches, temperature_batches = ( - batches.pop(0) - ) - data_fetch_time = time.time() - data_fetch_start - benchmark_stats["data_fetch_times"].append(data_fetch_time) - - # Track step time - step_start = time.time() - - # Run training step using common helper - metrics = run_training_step( - model, - optimizer, - token_batches, - label_batches, - advantage_batches, - temperature_batches, - config, - ) - - step_time = time.time() - step_start - benchmark_stats["step_times"].append(step_time) - - # Track GPU memory - if torch.cuda.is_available(): - gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 - gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 - benchmark_stats["gpu_memories"].append(gpu_mem_gb) - else: - gpu_mem_gb = 0 - gpu_mem_reserved_gb = 0 - - # In single-copy mode, weights are already updated in-place (same GPU memory!) 
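-        # (optimizer.step() inside run_training_step wrote directly into the
-        # IPC-shared storage, so there is nothing to copy or broadcast)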
- # No synchronization needed - vLLM sees changes immediately - sync_time = 0.0 - print(f" [SINGLE-COPY] Weights updated in-place - step {step+1}") - benchmark_stats["sync_times"].append(sync_time) - - # Add timing metrics - metrics["step_time"] = step_time - metrics["sync_time"] = sync_time - metrics["data_fetch_time"] = data_fetch_time - metrics["gpu_memory_gb"] = gpu_mem_gb - metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb - - # Log metrics - log_metrics( - metrics, - step + 1, - use_wandb, - { - "train/learning_rate": optimizer.param_groups[0]["lr"], - "train/update_count": step + 1, - }, - benchmark=config.benchmark, - ) - - # Periodic checkpoint save (for recovery, not for vLLM sync) - if (step + 1) % config.vllm_restart_interval == 0: - save_checkpoint(model, tokenizer, config.save_path, step + 1) - - # === Cleanup === - save_checkpoint( - model, tokenizer, config.save_path, config.training_steps, is_final=True - ) - finalize_training( - use_wandb, - training_start_time, - "shared_vllm", - config.training_steps, - benchmark_stats, - benchmark=config.benchmark, - ) - - -def _check_vllm_health(port: int) -> bool: - """Check if external vLLM server is running and healthy.""" - try: - response = requests.get(f"http://localhost:{port}/health", timeout=5) - return response.status_code == 200 - except Exception: - return False - - -def _hotswap_lora_adapter( - port: int, adapter_path: str, adapter_name: Optional[str] = None -) -> bool: - """ - Request vLLM to hot-swap to a new LoRA adapter. - - Tries both: - 1. Native vLLM endpoint: /v1/load_lora_adapter (standard vLLM serve) - 2. Custom endpoint: /lora/load (vllm_api_server.py) - - Args: - port: vLLM server port - adapter_path: Path to the saved adapter directory - adapter_name: Optional name for the adapter - - Returns: - True if successful, False otherwise - """ - base_url = f"http://localhost:{port}" - name = adapter_name or os.path.basename(adapter_path) - - # Try native vLLM endpoint first (standard vllm serve) - try: - response = requests.post( - f"{base_url}/v1/load_lora_adapter", - json={"lora_name": name, "lora_path": adapter_path}, - timeout=30, - ) - if response.status_code == 200: - print(f" [LORA] Hot-swapped adapter via native API: {name} ({adapter_path})") - return True - except Exception: - pass # Try custom endpoint - - # Try custom endpoint (vllm_api_server.py) - try: - response = requests.post( - f"{base_url}/lora/load", - json={"adapter_path": adapter_path, "adapter_name": name}, - timeout=30, - ) - if response.status_code == 200: - print(f" [LORA] Hot-swapped adapter via custom API: {name} ({adapter_path})") - return True - else: - print(f" [LORA] Hot-swap failed: {response.text}") - return False - except Exception as e: - print(f" [LORA] Hot-swap request failed: {e}") - return False - - -def train_lora(config: TrainingConfig): - """ - GRPO training with LoRA adapters. - - This mode keeps the base model frozen and only trains LoRA adapter weights. - - REQUIRES: External vLLM server running via vllm_api_server.py - - Benefits: - - Much faster training (fewer parameters) - - Smaller checkpoint sizes (adapter only, not full model) - - Adapters can be hot-swapped in vLLM via /lora/load endpoint - """ - if not PEFT_AVAILABLE: - raise RuntimeError( - "PEFT library required for LoRA mode. 
Install with: pip install peft" - ) - - training_start_time = time.time() - - # === Setup === - use_wandb = setup_wandb(config) - - print(f"\n{'='*60}") - print("LORA MODE (adapter-only training)") - print(f"{'='*60}") - print(f"Base model: {config.model_name}") - print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") - print(f"Save path: {config.save_path}") - print(f"vLLM port: {config.vllm_port}") - print(f"{'='*60}\n") - - # Check that external vLLM is running - print("[1/3] Checking external vLLM server...") - if not _check_vllm_health(config.vllm_port): - print(f"\nERROR: vLLM server not running on port {config.vllm_port}") - print("\nLoRA mode requires an external vLLM server. Start it first:") - print(" python example_trainer/vllm_api_server.py \\") - print(f" --model {config.model_name} \\") - print(f" --port {config.vllm_port} \\") - print(" --gpu-memory-utilization 0.45") - raise RuntimeError(f"External vLLM server required on port {config.vllm_port}") - print(f"vLLM server healthy on port {config.vllm_port}") - - # Load model with LoRA adapters - print("[2/3] Loading model with LoRA adapters...") - model, tokenizer = load_model_and_tokenizer(config) - - # Only optimize LoRA parameters (base model is frozen) - trainable_params = [p for p in model.parameters() if p.requires_grad] - optimizer = AdamW(trainable_params, lr=config.lr) - - print(f"[3/3] Starting training for {config.training_steps} steps") - print("-" * 60) - - os.makedirs(config.save_path, exist_ok=True) - register_trainer(config) - - # NOTE: No vLLM launch here - using external vLLM server - - # === Benchmark tracking === - benchmark_stats = { - "step_times": [], - "sync_times": [], # For LoRA mode, this is adapter save + hot-swap time - "data_fetch_times": [], - "gpu_memories": [], - } - - # === Training Loop === - batches = [] - for step in range(config.training_steps): - print(f"\nStep {step+1}/{config.training_steps}") - - # Track data fetch time - data_fetch_start = time.time() - if len(batches) == 0: - batches = get_data(config.batch_size, config.seq_len, config.atropos_url) - token_batches, label_batches, advantage_batches, temperature_batches = ( - batches.pop(0) - ) - data_fetch_time = time.time() - data_fetch_start - benchmark_stats["data_fetch_times"].append(data_fetch_time) - - # Track step time - step_start = time.time() - - # Run training step - metrics = run_training_step( - model, - optimizer, - token_batches, - label_batches, - advantage_batches, - temperature_batches, - config, - ) - - step_time = time.time() - step_start - benchmark_stats["step_times"].append(step_time) - - # Track GPU memory - if torch.cuda.is_available(): - gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 - gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 - benchmark_stats["gpu_memories"].append(gpu_mem_gb) - else: - gpu_mem_gb = 0 - gpu_mem_reserved_gb = 0 - - # Track sync time (adapter save + hot-swap) - sync_time = 0 - should_sync = (step + 1) % config.vllm_restart_interval == 0 - if should_sync: - sync_start = time.time() - adapter_path = save_lora_checkpoint(model, config.save_path, step + 1) - # Try to hot-swap the adapter in vLLM (non-blocking, best effort) - _hotswap_lora_adapter(config.vllm_port, adapter_path, f"step_{step + 1}") - sync_time = time.time() - sync_start - benchmark_stats["sync_times"].append(sync_time) - - # Add timing metrics - metrics["step_time"] = step_time - metrics["sync_time"] = sync_time - metrics["data_fetch_time"] = data_fetch_time - metrics["gpu_memory_gb"] = gpu_mem_gb - 
metrics["gpu_memory_reserved_gb"] = gpu_mem_reserved_gb - - # Log metrics - log_metrics( - metrics, - step + 1, - use_wandb, - { - "train/learning_rate": optimizer.param_groups[0]["lr"], - "lora/trainable_params": sum(p.numel() for p in trainable_params), - }, - benchmark=config.benchmark, - ) - - # === Cleanup === - # NOTE: No vLLM termination - external server keeps running - - # Save final adapter (track this sync time too) - final_sync_start = time.time() - final_adapter_path = save_lora_checkpoint( - model, config.save_path, config.training_steps, is_final=True - ) - - # Hot-swap to final adapter - _hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final") - final_sync_time = time.time() - final_sync_start - benchmark_stats["sync_times"].append(final_sync_time) - - finalize_training( - use_wandb, - training_start_time, - "lora_only", - config.training_steps, - benchmark_stats, - benchmark=config.benchmark, - ) - - # Also save tokenizer for convenience - tokenizer_path = os.path.join(config.save_path, "tokenizer") - tokenizer.save_pretrained(tokenizer_path) - print(f"Tokenizer saved to {tokenizer_path}") - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments for the GRPO trainer.""" - parser = argparse.ArgumentParser( - description="GRPO Trainer with optional shared-weight vLLM integration", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # --- Core training arguments --- - parser.add_argument( - "--model-name", - type=str, - required=True, - help="HuggingFace model identifier (e.g., 'Qwen/Qwen2.5-1.5B-Instruct')", - ) - parser.add_argument( - "--lr", - type=float, - default=1e-5, - help="Learning rate for the optimizer", - ) - parser.add_argument( - "--training-steps", - type=int, - default=10, - help="Number of training steps to run", - ) - parser.add_argument( - "--batch-size", - type=int, - default=2, - help="Batch size for training", - ) - parser.add_argument( - "--seq-len", - type=int, - default=2048, - help="Maximum sequence length", - ) - parser.add_argument( - "--gradient-accumulation-steps", - type=int, - default=32, - help="Number of gradient accumulation steps", - ) - parser.add_argument( - "--device", - type=str, - default="cuda" if torch.cuda.is_available() else "cpu", - help="Device to train on (cuda/cpu)", - ) - parser.add_argument( - "--save-path", - type=str, - default="trained_model_checkpoints", - help="Directory to save model checkpoints", - ) - - # --- vLLM arguments --- - parser.add_argument( - "--vllm-restart-interval", - type=int, - default=3, - help="Restart vLLM every N training steps (legacy mode only)", - ) - parser.add_argument( - "--vllm-port", - type=int, - default=9001, - help="Port for the vLLM server", - ) - parser.add_argument( - "--atropos-url", - type=str, - default="http://localhost:8000", - help="URL of the Atropos API/environment server (e.g., gsm8k_server)", - ) - parser.add_argument( - "--vllm-gpu-memory-utilization", - type=float, - default=0.45, - help="GPU memory utilization for vLLM server (0.0-1.0)", - ) - - # --- Wandb arguments --- - parser.add_argument( - "--use-wandb", - action="store_true", - help="Enable Weights & Biases logging", - ) - parser.add_argument( - "--wandb-project", - type=str, - default=None, - help="Wandb project name", - ) - parser.add_argument( - "--wandb-group", - type=str, - default=None, - help="Wandb group name", - ) - - # --- Pipeline / weight bridge arguments --- - parser.add_argument( - "--weight-bridge-mode", - type=str, - choices=["shared_vllm", 
"lora_only", "none"], - default="none", - help=( - "Weight sync mode: " - "'shared_vllm' = attach to vLLM shared memory, " - "'lora_only' = train LoRA adapters only, " - "'none' = legacy restart-based sync" - ), - ) - parser.add_argument( - "--trainer-rank", - type=int, - default=0, - help="Rank of this trainer in the distributed group", - ) - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Total processes in the distributed group", - ) - parser.add_argument( - "--init-method", - type=str, - default="env://", - help="PyTorch distributed init method (e.g., 'env://', 'tcp://host:port')", - ) - parser.add_argument( - "--num-inference-nodes", - type=int, - default=0, - help="Number of inference nodes to coordinate with (0 = single-node local)", - ) - - # --- LoRA arguments --- - parser.add_argument( - "--lora-r", - type=int, - default=16, - help="LoRA rank (dimension of low-rank matrices)", - ) - parser.add_argument( - "--lora-alpha", - type=int, - default=32, - help="LoRA alpha (scaling factor, typically 2x rank)", - ) - parser.add_argument( - "--lora-dropout", - type=float, - default=0.05, - help="Dropout probability for LoRA layers", - ) - parser.add_argument( - "--lora-target-modules", - type=str, - nargs="+", - default=None, - help="Module names to apply LoRA to (default: q_proj v_proj)", - ) - - parser.add_argument( - "--single-copy", - action="store_true", - help=( - "Enable TRUE single-copy mode (shared_vllm mode only). " - "Trainer attaches to vLLM's model tensors via CUDA IPC. " - "Only ONE copy of the model exists in GPU memory! " - "Requires trainer and vLLM to be on the SAME GPU(s). " - "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1." - ), - ) - parser.add_argument( - "--vllm-config-path", - type=str, - default=None, - help=( - "Explicit path to vllm_bridge_config.json. " - "If not provided, auto-detects from LOGDIR, current directory, " - "or /tmp/atropos_bridge. " - "This file contains CUDA IPC handles created by vLLM." - ), - ) - - # --- Debug flags --- - parser.add_argument( - "--debug-loading", - action="store_true", - help=( - "Enable verbose debug output during model loading and IPC attachment. " - "Useful for diagnosing single-copy mode issues." - ), - ) - parser.add_argument( - "--benchmark", - action="store_true", - help=( - "Enable benchmark timing output showing step time, sync time, " - "data fetch time, and GPU memory usage per step." 
- ), - ) - - return parser.parse_args() - - -def config_from_args(args: argparse.Namespace) -> TrainingConfig: - """Build a TrainingConfig from parsed CLI arguments.""" - return TrainingConfig( - model_name=args.model_name, - lr=args.lr, - training_steps=args.training_steps, - batch_size=args.batch_size, - seq_len=args.seq_len, - gradient_accumulation_steps=args.gradient_accumulation_steps, - device=args.device, - save_path=args.save_path, - vllm_restart_interval=args.vllm_restart_interval, - vllm_port=args.vllm_port, - vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization, - use_wandb=args.use_wandb, - wandb_project=args.wandb_project, - wandb_group=args.wandb_group, - weight_bridge_mode=args.weight_bridge_mode, - trainer_rank=args.trainer_rank, - world_size=args.world_size, - init_method=args.init_method, - num_inference_nodes=args.num_inference_nodes, - lora_r=args.lora_r, - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - lora_target_modules=args.lora_target_modules, - single_copy=getattr(args, "single_copy", False), - vllm_config_path=getattr(args, "vllm_config_path", None), - debug_loading=getattr(args, "debug_loading", False), - benchmark=getattr(args, "benchmark", False), - atropos_url=getattr(args, "atropos_url", "http://localhost:8000"), - ) - - -# Example usage (optional, can be run from another script) -if __name__ == "__main__": - args = parse_args() - training_config = config_from_args(args) - - print(f"Weight bridge mode: {training_config.weight_bridge_mode}") - - if training_config.weight_bridge_mode == "shared_vllm": - # Shared vLLM mode: attach to vLLM's weights, update in-place - train_shared_vllm(training_config) - - elif training_config.weight_bridge_mode == "lora_only": # LoRA mode: freeze base model, train adapters only - train_lora(training_config) + train_lora(config) else: # Legacy mode: periodic checkpoint saves + vLLM restarts - train(training_config) + train_legacy(config) + + +if __name__ == "__main__": + main() + diff --git a/example_trainer/model.py b/example_trainer/model.py new file mode 100644 index 00000000..71d52c99 --- /dev/null +++ b/example_trainer/model.py @@ -0,0 +1,607 @@ +""" +Model loading and shared memory utilities for GRPO trainer. + +Handles: +- Standard model loading (legacy mode) +- LoRA model loading and wrapping +- Single-copy mode: Attaching to vLLM's shared tensors via CUDA IPC +""" + +import base64 +import json +import os +from typing import Dict, Optional, Tuple + +import torch +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from .config import TrainingConfig + +# Import PEFT for LoRA training +try: + from peft import LoraConfig, TaskType, get_peft_model + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + + +def load_model_and_tokenizer( + config: TrainingConfig, + single_copy: bool = False, +) -> Tuple[torch.nn.Module, AutoTokenizer]: + """ + Load or attach to model based on weight_bridge_mode. 
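+
+    Dispatch summary: "shared_vllm" (or single_copy=True) goes through
+    _find_vllm_config() and _attach_to_vllm_shared_tensors(); "lora_only"
+    defers to _load_model_with_lora(); any other mode loads a plain bf16
+    copy for legacy checkpoint-and-restart training.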
+ + Args: + config: Training configuration + single_copy: If True, attach to vLLM's shared tensors via CUDA IPC + + Returns: + Tuple of (model, tokenizer) + """ + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + + # Single-copy mode: attach to vLLM's shared tensors via CUDA IPC + if single_copy or config.weight_bridge_mode == "shared_vllm": + config_path = _find_vllm_config(config) + model = _attach_to_vllm_shared_tensors(config, config_path) + + if model is not None: + print("[Setup] βœ“ Single-copy mode active - using vLLM's tensors directly!") + model.train() + return model, tokenizer + else: + raise RuntimeError( + "[Setup] Single-copy mode FAILED to attach to vLLM's tensors.\n" + "Check:\n" + " 1. vLLM running with VLLM_ENABLE_SHARED_WEIGHTS=1\n" + " 2. vllm_bridge_config.json exists with ipc_handles\n" + " 3. Trainer is on SAME GPUs as vLLM" + ) + + elif config.weight_bridge_mode == "lora_only": + model = _load_model_with_lora(config) + + else: + # Legacy mode: load full model + print("[Setup] Loading model for legacy mode...") + model = AutoModelForCausalLM.from_pretrained( + config.model_name, torch_dtype=torch.bfloat16 + ) + model.to(config.device) + + # Enable gradient checkpointing + _setup_gradient_checkpointing(model, config) + model.train() + + return model, tokenizer + + +def _find_vllm_config(config: TrainingConfig) -> str: + """Find the vllm_bridge_config.json file.""" + # Check explicit path first + if config.vllm_config_path and os.path.exists(config.vllm_config_path): + print(f"[Setup] Using explicit vLLM config path: {config.vllm_config_path}") + return config.vllm_config_path + + # Auto-detect from common locations + possible_paths = [ + os.environ.get("LOGDIR", "."), + ".", + "/tmp/atropos_bridge", + os.path.dirname(os.path.abspath(__file__)), + ] + + for log_dir in possible_paths: + candidate = os.path.join(log_dir, "vllm_bridge_config.json") + if os.path.exists(candidate): + print(f"[Setup] Found vLLM config at: {candidate}") + return candidate + + checked = [os.path.join(p, "vllm_bridge_config.json") for p in possible_paths] + raise RuntimeError( + f"[Setup] Could not find vllm_bridge_config.json\n" + f"Checked: {checked}\n" + f"Tip: Use --vllm-config-path to specify the path explicitly\n" + f"Make sure vLLM is running with VLLM_ENABLE_SHARED_WEIGHTS=1 and LOGDIR set" + ) + + +def _load_model_with_lora(config: TrainingConfig) -> torch.nn.Module: + """ + Load base model and wrap with LoRA adapters. + + Args: + config: Training configuration with LoRA settings + + Returns: + PEFT model with LoRA adapters applied + """ + if not PEFT_AVAILABLE: + raise RuntimeError("PEFT library not available. 
Install with: pip install peft") + + print("[Setup] Loading base model for LoRA mode...") + base_model = AutoModelForCausalLM.from_pretrained( + config.model_name, torch_dtype=torch.bfloat16 + ) + base_model.to(config.device) + + # Determine target modules + target_modules = config.lora_target_modules + if target_modules is None: + target_modules = ["q_proj", "v_proj"] + + print(f"Applying LoRA: r={config.lora_r}, alpha={config.lora_alpha}") + print(f"Target modules: {target_modules}") + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + target_modules=target_modules, + bias="none", + ) + + model = get_peft_model(base_model, lora_config) + model.print_trainable_parameters() + + return model + + +def _setup_gradient_checkpointing(model: torch.nn.Module, config: TrainingConfig) -> None: + """Configure gradient checkpointing for the model.""" + # Disable KV cache - incompatible with gradient checkpointing + model.config.use_cache = False + + if config.weight_bridge_mode == "lora_only": + # PEFT models need special handling + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + else: + model.gradient_checkpointing_enable() + + +def _attach_to_vllm_shared_tensors( + config: TrainingConfig, + bridge_config_path: str, +) -> Optional[torch.nn.Module]: + """ + Attach to vLLM's shared tensors via CUDA IPC (true single-copy mode). + + This creates a model whose parameters point to the SAME GPU memory as vLLM, + meaning only ONE copy of the model exists in GPU memory. + + Args: + config: Training configuration + bridge_config_path: Path to vllm_bridge_config.json + + Returns: + Model with parameters pointing to vLLM's tensors, or None if not possible + """ + print(f"[Setup] Reading bridge config from: {bridge_config_path}") + try: + with open(bridge_config_path, "r") as f: + bridge_config = json.load(f) + print(f"[Setup] Bridge config keys: {list(bridge_config.keys())}") + except Exception as e: + print(f"[Setup] Could not read bridge config: {e}") + return None + + single_copy_enabled = bridge_config.get("single_copy_enabled", False) + print(f"[Setup] single_copy_enabled in config: {single_copy_enabled}") + + if not single_copy_enabled: + print("[Setup] Single-copy mode not available (single_copy_enabled=False)") + print("[Setup] Make sure vLLM was started with VLLM_ENABLE_SHARED_WEIGHTS=1") + return None + + ipc_handles_raw = bridge_config.get("ipc_handles", {}) + print(f"[Setup] IPC handles count: {len(ipc_handles_raw)}") + if not ipc_handles_raw: + print("[Setup] No IPC handles found in bridge config") + return None + + # Deserialize base64-encoded bytes + ipc_handles = _deserialize_ipc_handles(ipc_handles_raw) + + print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...") + print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!") + + # Load model config (not weights) to get architecture + model_config = AutoConfig.from_pretrained(config.model_name) + + # Create empty model on meta device (no memory allocation) + with torch.device("meta"): + model = AutoModelForCausalLM.from_config( + model_config, + torch_dtype=torch.bfloat16, + ) + + param_names = list(model.state_dict().keys()) + print(f"[Setup] Model architecture has {len(param_names)} parameters", flush=True) + + # Initialize CUDA on devices used by vLLM + device_indices = 
_initialize_cuda_devices(ipc_handles) + + # Create mapping from HF names to vLLM tensors + vllm_to_hf_mapping = _create_vllm_to_hf_mapping( + model, ipc_handles, debug=config.debug_loading + ) + + # Reconstruct tensors and build state dict + hf_state_dict, attached_count, fused_count = _reconstruct_shared_tensors( + ipc_handles, vllm_to_hf_mapping, config + ) + + print(f"[Setup] Attached {attached_count} tensors ({fused_count} from fused layers)") + + if attached_count == 0: + print("[Setup] Could not attach any tensors, falling back to regular loading") + return None + + # Validate mapping coverage + _validate_mapping_coverage(model, hf_state_dict, attached_count) + + # Load state dict into model + model.load_state_dict(hf_state_dict, strict=False, assign=True) + + # Initialize remaining meta tensors + device = f"cuda:{list(device_indices)[0]}" if device_indices else "cuda:0" + _initialize_meta_tensors(model, device, config) + + # Final validation - ensure nothing is left on meta device + _validate_no_meta_tensors(model) + + print("[Setup] βœ“ All tensors successfully initialized on CUDA") + return model + + +def _deserialize_ipc_handles(handles_raw: dict) -> dict: + """Deserialize base64-encoded bytes in IPC handles.""" + def deserialize(handles): + result = {} + for k, v in handles.items(): + if isinstance(v, dict): + if "_bytes_b64_" in v: + result[k] = base64.b64decode(v["_bytes_b64_"]) + else: + result[k] = deserialize(v) + else: + result[k] = v + return result + return deserialize(handles_raw) + + +def _initialize_cuda_devices(ipc_handles: dict) -> set: + """Initialize CUDA context on devices used by IPC handles.""" + device_indices = set() + for name, info in ipc_handles.items(): + if "device_index" in info: + device_indices.add(info["device_index"]) + + print(f"[Setup] IPC handles span devices: {sorted(device_indices)}", flush=True) + + for dev_idx in sorted(device_indices): + print(f"[Setup] Initializing CUDA on device {dev_idx}...", flush=True) + torch.cuda.set_device(dev_idx) + torch.cuda.synchronize(dev_idx) + print(f"[Setup] βœ“ Device {dev_idx} ready", flush=True) + + return device_indices + + +def _reconstruct_shared_tensors( + ipc_handles: dict, + vllm_to_hf_mapping: dict, + config: TrainingConfig, +) -> Tuple[dict, int, int]: + """Reconstruct tensors from IPC handles and build state dict.""" + hf_state_dict = {} + vllm_tensor_cache: Dict[str, torch.Tensor] = {} + attached_count = 0 + fused_count = 0 + + def reconstruct_vllm_tensor(vllm_name: str) -> Optional[torch.Tensor]: + if vllm_name in vllm_tensor_cache: + return vllm_tensor_cache[vllm_name] + + if vllm_name not in ipc_handles: + return None + + ipc_info = ipc_handles[vllm_name] + if "ipc_handle_b64" not in ipc_info: + return None + + try: + device_index = ipc_info["device_index"] + ipc_handle = base64.b64decode(ipc_info["ipc_handle_b64"]) + storage_size = ipc_info["storage_size"] + storage_offset_orig = ipc_info["storage_offset_orig"] + ref_counter_handle = base64.b64decode(ipc_info["ref_counter_handle_b64"]) + ref_counter_offset = ipc_info["ref_counter_offset"] + event_handle = base64.b64decode(ipc_info["event_handle_b64"]) + event_sync_required = ipc_info["event_sync_required"] + + share_tuple = ( + device_index, ipc_handle, storage_size, storage_offset_orig, + ref_counter_handle, ref_counter_offset, event_handle, event_sync_required, + ) + + storage = torch.UntypedStorage._new_shared_cuda(*share_tuple) + dtype = getattr(torch, ipc_info["dtype"].replace("torch.", "")) + tensor = torch.tensor([], dtype=dtype, 
device=f"cuda:{device_index}") + tensor.set_( + storage, + storage_offset=ipc_info["tensor_storage_offset"], + size=ipc_info["shape"], + stride=ipc_info["stride"], + ) + + vllm_tensor_cache[vllm_name] = tensor + return tensor + + except Exception as e: + print(f"[Setup] Failed to reconstruct {vllm_name}: {e}", flush=True) + return None + + for hf_name, mapping_info in vllm_to_hf_mapping.items(): + try: + if isinstance(mapping_info, dict): + # Fused mapping - slice the source tensor + vllm_name = mapping_info["source"] + slice_start, slice_end = mapping_info["slice"] + slice_dim = mapping_info["dim"] + + full_tensor = reconstruct_vllm_tensor(vllm_name) + if full_tensor is None: + continue + + # Create VIEW (not copy) into the fused tensor + if slice_dim == 0: + tensor = full_tensor[slice_start:slice_end] + else: + tensor = full_tensor.narrow(slice_dim, slice_start, slice_end - slice_start) + + tensor.requires_grad_(True) + hf_state_dict[hf_name] = tensor + fused_count += 1 + attached_count += 1 + + else: + # Direct mapping + vllm_name = mapping_info + tensor = reconstruct_vllm_tensor(vllm_name) + if tensor is None: + continue + + tensor.requires_grad_(True) + hf_state_dict[hf_name] = tensor + attached_count += 1 + + except Exception as e: + print(f"[Setup] Failed to attach {hf_name}: {e}", flush=True) + + return hf_state_dict, attached_count, fused_count + + +def _validate_mapping_coverage( + model: torch.nn.Module, + hf_state_dict: dict, + attached_count: int, +) -> None: + """Validate that enough parameters were mapped.""" + hf_param_count = len(list(model.named_parameters())) + mapping_coverage = attached_count / hf_param_count if hf_param_count > 0 else 0 + + print(f"[Setup] Mapping coverage: {attached_count} tensors for {hf_param_count} parameters") + + if mapping_coverage < 0.90: + unmapped_params = set(model.state_dict().keys()) - set(hf_state_dict.keys()) + warning_msg = f"[Setup] WARNING: Low mapping coverage ({mapping_coverage:.1%})\n" + warning_msg += f"Unmapped parameters ({len(unmapped_params)}):\n" + for name in list(unmapped_params)[:20]: + warning_msg += f" - {name}\n" + print(warning_msg) + + if mapping_coverage < 0.50: + raise RuntimeError( + f"[Setup] CRITICAL: Only {mapping_coverage:.1%} of parameters mapped!" 
+ ) + else: + print(f"[Setup] βœ“ Good mapping coverage ({mapping_coverage:.1%})") + + +def _initialize_meta_tensors( + model: torch.nn.Module, + device: str, + config: TrainingConfig, +) -> None: + """Initialize any remaining meta tensors after loading.""" + meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"] + meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"] + + if config.debug_loading: + print(f"\n[DIAGNOSTIC] Meta params: {len(meta_params)}, Meta buffers: {len(meta_buffers)}") + + def get_parent_and_name(model, full_name): + parts = full_name.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + return parent, parts[-1] + + meta_count = 0 + + # Initialize meta parameters + for name in meta_params: + param = dict(model.named_parameters()).get(name) + if param is None: + continue + + try: + new_data = torch.zeros(param.shape, dtype=param.dtype, device=device) + new_param = torch.nn.Parameter(new_data, requires_grad=param.requires_grad) + parent, attr_name = get_parent_and_name(model, name) + setattr(parent, attr_name, new_param) + meta_count += 1 + except Exception as e: + if config.debug_loading: + print(f"[DIAGNOSTIC] FAILED to initialize {name}: {e}") + + # Initialize meta buffers + for name in meta_buffers: + buffer = dict(model.named_buffers()).get(name) + if buffer is None: + continue + + try: + if "inv_freq" in name: + dim = buffer.shape[0] * 2 + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + new_buffer = inv_freq.to(dtype=buffer.dtype, device=device) + else: + new_buffer = torch.zeros(buffer.shape, dtype=buffer.dtype, device=device) + + parent, attr_name = get_parent_and_name(model, name) + parent.register_buffer(attr_name, new_buffer) + meta_count += 1 + except Exception as e: + if config.debug_loading: + print(f"[DIAGNOSTIC] FAILED to initialize buffer {name}: {e}") + + print(f"\n[Setup] Initialized {meta_count} remaining meta tensors") + + +def _validate_no_meta_tensors(model: torch.nn.Module) -> None: + """Ensure no parameters or buffers are still on meta device.""" + final_meta_params = [name for name, p in model.named_parameters() if p.device.type == "meta"] + final_meta_buffers = [name for name, b in model.named_buffers() if b.device.type == "meta"] + + if final_meta_params or final_meta_buffers: + error_msg = "[Setup] CRITICAL ERROR: Some tensors are still on meta device!\n" + error_msg += "The model would produce GARBAGE output.\n\n" + + if final_meta_params: + error_msg += f"Meta parameters ({len(final_meta_params)}):\n" + for name in final_meta_params[:20]: + error_msg += f" - {name}\n" + + if final_meta_buffers: + error_msg += f"\nMeta buffers ({len(final_meta_buffers)}):\n" + for name in final_meta_buffers[:20]: + error_msg += f" - {name}\n" + + raise RuntimeError(error_msg) + + +def _create_vllm_to_hf_mapping( + model: torch.nn.Module, + ipc_handles: dict, + debug: bool = False, +) -> dict: + """ + Create mapping from HuggingFace parameter names to vLLM tensor names. 
+ + Handles fused layers: + - qkv_proj (vLLM) = q_proj + k_proj + v_proj (HF) + - gate_up_proj (vLLM) = gate_proj + up_proj (HF) + """ + hf_params = set(model.state_dict().keys()) + vllm_params = set(ipc_handles.keys()) + + # Get model config for dimension calculations + model_config = model.config + hidden_size = getattr(model_config, "hidden_size", 4096) + num_attention_heads = getattr(model_config, "num_attention_heads", 32) + num_key_value_heads = getattr(model_config, "num_key_value_heads", num_attention_heads) + intermediate_size = getattr(model_config, "intermediate_size", hidden_size * 4) + head_dim = hidden_size // num_attention_heads + + # QKV sizes + q_size = hidden_size + k_size = num_key_value_heads * head_dim + v_size = num_key_value_heads * head_dim + + if debug: + print(f"[Mapping] Model config: hidden={hidden_size}, heads={num_attention_heads}, " + f"kv_heads={num_key_value_heads}, intermediate={intermediate_size}") + + mapping = {} + + def find_vllm_name(hf_name: str) -> Optional[str]: + if hf_name in vllm_params: + return hf_name + if not hf_name.startswith("model."): + candidate = f"model.{hf_name}" + if candidate in vllm_params: + return candidate + if hf_name.startswith("model."): + candidate = hf_name[6:] + if candidate in vllm_params: + return candidate + return None + + def find_fused_source(hf_name: str, fused_suffix: str) -> Optional[str]: + for unfused in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]: + if unfused in hf_name: + fused_name = hf_name.replace(unfused, fused_suffix) + found = find_vllm_name(fused_name) + if found: + return found + return None + + for hf_name in hf_params: + # Try direct match first + vllm_name = find_vllm_name(hf_name) + if vllm_name: + mapping[hf_name] = vllm_name + continue + + # Check for QKV fusion + if any(x in hf_name for x in ["q_proj", "k_proj", "v_proj"]): + fused_name = find_fused_source(hf_name, "qkv_proj") + if fused_name: + if "q_proj" in hf_name: + start, end = 0, q_size + elif "k_proj" in hf_name: + start, end = q_size, q_size + k_size + else: + start, end = q_size + k_size, q_size + k_size + v_size + + mapping[hf_name] = { + "source": fused_name, + "slice": (start, end), + "dim": 0, + "type": "qkv_fusion", + } + continue + + # Check for Gate/Up fusion + if any(x in hf_name for x in ["gate_proj", "up_proj"]): + fused_name = find_fused_source(hf_name, "gate_up_proj") + if fused_name: + if "gate_proj" in hf_name: + start, end = 0, intermediate_size + else: + start, end = intermediate_size, intermediate_size * 2 + + mapping[hf_name] = { + "source": fused_name, + "slice": (start, end), + "dim": 0, + "type": "gate_up_fusion", + } + continue + + if debug: + direct = sum(1 for v in mapping.values() if isinstance(v, str)) + fused = sum(1 for v in mapping.values() if isinstance(v, dict)) + print(f"[Mapping] Total: {len(mapping)} mapped ({direct} direct, {fused} fused)") + + return mapping + diff --git a/example_trainer/scripts/run_comparison.sh b/example_trainer/scripts/run_comparison.sh new file mode 100755 index 00000000..754573f7 --- /dev/null +++ b/example_trainer/scripts/run_comparison.sh @@ -0,0 +1,366 @@ +#!/bin/bash +# ============================================================================== +# GRPO Training Mode Comparison Script +# ============================================================================== +# Runs all three training modes (Legacy, Shared vLLM, LoRA) in parallel +# on an 8-GPU node for comparison. 
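+#
+# Prerequisites (assumed): the `run-api` CLI on PATH, environments/gsm8k_server.py
+# runnable from the repo root, and the example_trainer package importable.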
+# +# GPU Allocation: +# - GPUs 0-1: Legacy mode (trainer manages vLLM) +# - GPUs 2-3: Shared vLLM mode (CUDA IPC single-copy) +# - GPUs 4-5: LoRA mode (adapter training) +# - GPUs 6-7: Reserved +# +# Port Allocation: +# - Legacy: API 8001, vLLM 9001 +# - Shared vLLM: API 8002, vLLM 9002 +# - LoRA: API 8003, vLLM 9003 +# +# Usage: +# ./run_comparison.sh # Default 50 steps, logs to ./comparison_ +# ./run_comparison.sh 100 # 100 steps +# LOGDIR=/my/path ./run_comparison.sh # Custom log directory +# +# ============================================================================== +# OUTPUT DIRECTORY STRUCTURE ($LOGDIR): +# ============================================================================== +# +# $LOGDIR/ +# β”œβ”€β”€ api_legacy.log # run-api server log (port 8001) +# β”œβ”€β”€ api_shared.log # run-api server log (port 8002) +# β”œβ”€β”€ api_lora.log # run-api server log (port 8003) +# β”œβ”€β”€ env_legacy.log # gsm8k environment log +# β”œβ”€β”€ env_shared.log # gsm8k environment log +# β”œβ”€β”€ env_lora.log # gsm8k environment log +# β”œβ”€β”€ vllm_shared.log # vLLM server log (shared mode) +# β”œβ”€β”€ vllm_lora.log # vLLM server log (lora mode) +# β”œβ”€β”€ trainer_legacy.log # GRPO trainer log (MAIN OUTPUT) +# β”œβ”€β”€ trainer_shared.log # GRPO trainer log (MAIN OUTPUT) +# β”œβ”€β”€ trainer_lora.log # GRPO trainer log (MAIN OUTPUT) +# β”œβ”€β”€ vllm_bridge_config.json # CUDA IPC config (shared mode) +# β”œβ”€β”€ pids.txt # Process IDs for cleanup +# β”œβ”€β”€ checkpoints_legacy/ # Model checkpoints +# β”‚ β”œβ”€β”€ step_3/ +# β”‚ β”œβ”€β”€ step_6/ +# β”‚ └── final_model/ +# β”œβ”€β”€ checkpoints_shared/ # Model checkpoints +# β”‚ └── final_model/ +# └── checkpoints_lora/ # LoRA adapter checkpoints +# β”œβ”€β”€ adapter_step_3/ +# β”œβ”€β”€ adapter_step_6/ +# β”œβ”€β”€ final_adapter/ +# └── tokenizer/ +# +# ============================================================================== + +set -e + +# Configuration +export MODEL="${MODEL:-Qwen/Qwen2.5-3B-Instruct}" +export TRAINING_STEPS="${1:-50}" +export BATCH_SIZE="${BATCH_SIZE:-2}" +export LOGDIR="${LOGDIR:-./comparison_$(date +%Y%m%d_%H%M%S)}" + +mkdir -p $LOGDIR + +echo "==============================================" +echo "GRPO Training Mode Comparison" +echo "==============================================" +echo "Model: $MODEL" +echo "Steps: $TRAINING_STEPS" +echo "Batch size: $BATCH_SIZE" +echo "Log dir: $LOGDIR" +echo "==============================================" +echo "" + +# Cleanup function +cleanup() { + echo "" + echo "Cleaning up processes..." + if [ -f "$LOGDIR/pids.txt" ]; then + source $LOGDIR/pids.txt + kill $LEGACY_TRAINER_PID $LEGACY_ENV_PID $LEGACY_API_PID 2>/dev/null || true + kill $SHARED_TRAINER_PID $SHARED_VLLM_PID $SHARED_ENV_PID $SHARED_API_PID 2>/dev/null || true + kill $LORA_TRAINER_PID $LORA_VLLM_PID $LORA_ENV_PID $LORA_API_PID 2>/dev/null || true + fi + pkill -f "gsm8k_server.py.*800[123]" 2>/dev/null || true + pkill -f "run_api.*800[123]" 2>/dev/null || true + echo "Cleanup complete." +} +trap cleanup EXIT + +# Kill any existing processes on our ports +echo "Killing any existing processes on ports 8001-8003, 9001-9003..." 
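+# Stale servers from an earlier run would otherwise hold these ports and GPU memory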
+pkill -f "vllm_api_server.py.*900[123]" 2>/dev/null || true +pkill -f "gsm8k_server.py.*800[123]" 2>/dev/null || true +pkill -f "run_api.*800[123]" 2>/dev/null || true +sleep 2 + +# ============================================================================== +# MODE 1: LEGACY (GPUs 0-1, API 8001, vLLM 9001) +# ============================================================================== +echo "" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "[1/3] LEGACY MODE (GPUs 0-1, API:8001, vLLM:9001)" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +# Start run-api server for Legacy +echo " Starting run-api server..." +run-api --port 8001 > $LOGDIR/api_legacy.log 2>&1 & +LEGACY_API_PID=$! +echo " βœ“ run-api started (PID: $LEGACY_API_PID, port 8001)" +sleep 3 + +# Start environment server for Legacy +echo " Starting environment server..." +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ + --openai.base_url http://localhost:9001/v1 \ + --server.port 8001 \ + > $LOGDIR/env_legacy.log 2>&1 & +LEGACY_ENV_PID=$! +echo " βœ“ Environment server started (PID: $LEGACY_ENV_PID)" +sleep 5 + +# Start Legacy trainer (it will manage its own vLLM) +echo " Starting trainer..." +CUDA_VISIBLE_DEVICES=0,1 python -m example_trainer.grpo \ + --model-name $MODEL \ + --weight-bridge-mode none \ + --vllm-port 9001 \ + --atropos-url http://localhost:8001 \ + --training-steps $TRAINING_STEPS \ + --batch-size $BATCH_SIZE \ + --save-path $LOGDIR/checkpoints_legacy \ + --benchmark \ + > $LOGDIR/trainer_legacy.log 2>&1 & +LEGACY_TRAINER_PID=$! +echo " βœ“ Trainer started (PID: $LEGACY_TRAINER_PID)" + +# ============================================================================== +# MODE 2: SHARED_VLLM (GPUs 2-3, API 8002, vLLM 9002) +# ============================================================================== +echo "" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "[2/3] SHARED_VLLM MODE (GPUs 2-3, API:8002, vLLM:9002)" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +# Start run-api server for Shared +echo " Starting run-api server..." +run-api --port 8002 > $LOGDIR/api_shared.log 2>&1 & +SHARED_API_PID=$! +echo " βœ“ run-api started (PID: $SHARED_API_PID, port 8002)" +sleep 3 + +# Start vLLM with shared weights +echo " Starting vLLM with shared weights..." +VLLM_ENABLE_SHARED_WEIGHTS=1 LOGDIR=$LOGDIR \ +CUDA_VISIBLE_DEVICES=2 python example_trainer/vllm_api_server.py \ + --model $MODEL \ + --port 9002 \ + --gpu-memory-utilization 0.45 \ + > $LOGDIR/vllm_shared.log 2>&1 & +SHARED_VLLM_PID=$! +echo " βœ“ vLLM started (PID: $SHARED_VLLM_PID, port 9002)" +echo " Waiting for vLLM to initialize (30s)..." +sleep 30 + +# Start environment server for Shared +echo " Starting environment server..." +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ + --openai.base_url http://localhost:9002/v1 \ + --server.port 8002 \ + > $LOGDIR/env_shared.log 2>&1 & +SHARED_ENV_PID=$! +echo " βœ“ Environment server started (PID: $SHARED_ENV_PID)" +sleep 5 + +# Start Shared vLLM trainer +echo " Starting trainer..." 
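+# NOTE: the trainer pins to GPU 2, the same device as the shared vLLM server,
+# because CUDA IPC handles are only valid on the GPU that exported them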
+CUDA_VISIBLE_DEVICES=2 python -m example_trainer.grpo \ + --model-name $MODEL \ + --weight-bridge-mode shared_vllm \ + --vllm-port 9002 \ + --vllm-config-path $LOGDIR/vllm_bridge_config.json \ + --atropos-url http://localhost:8002 \ + --training-steps $TRAINING_STEPS \ + --batch-size $BATCH_SIZE \ + --save-path $LOGDIR/checkpoints_shared \ + --benchmark \ + > $LOGDIR/trainer_shared.log 2>&1 & +SHARED_TRAINER_PID=$! +echo " βœ“ Trainer started (PID: $SHARED_TRAINER_PID)" + +# ============================================================================== +# MODE 3: LORA (GPUs 4-5, API 8003, vLLM 9003) +# ============================================================================== +echo "" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +echo "[3/3] LORA MODE (GPUs 4-5, API:8003, vLLM:9003)" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +# Start run-api server for LoRA +echo " Starting run-api server..." +run-api --port 8003 > $LOGDIR/api_lora.log 2>&1 & +LORA_API_PID=$! +echo " βœ“ run-api started (PID: $LORA_API_PID, port 8003)" +sleep 3 + +# Start vLLM with LoRA support +echo " Starting vLLM with LoRA support..." +CUDA_VISIBLE_DEVICES=4 python example_trainer/vllm_api_server.py \ + --model $MODEL \ + --port 9003 \ + --gpu-memory-utilization 0.45 \ + --enable-lora \ + --max-lora-rank 32 \ + --enforce-eager \ + > $LOGDIR/vllm_lora.log 2>&1 & +LORA_VLLM_PID=$! +echo " βœ“ vLLM started (PID: $LORA_VLLM_PID, port 9003)" +echo " Waiting for vLLM to initialize (30s)..." +sleep 30 + +# Start environment server for LoRA +echo " Starting environment server..." +python environments/gsm8k_server.py serve \ + --slurm.num_gpus 0 \ + --env.tokenizer_name $MODEL \ + --openai.base_url http://localhost:9003/v1 \ + --server.port 8003 \ + > $LOGDIR/env_lora.log 2>&1 & +LORA_ENV_PID=$! +echo " βœ“ Environment server started (PID: $LORA_ENV_PID)" +sleep 5 + +# Start LoRA trainer +echo " Starting trainer..." +CUDA_VISIBLE_DEVICES=5 python -m example_trainer.grpo \ + --model-name $MODEL \ + --weight-bridge-mode lora_only \ + --vllm-port 9003 \ + --atropos-url http://localhost:8003 \ + --lora-r 16 \ + --lora-alpha 32 \ + --training-steps $TRAINING_STEPS \ + --batch-size $BATCH_SIZE \ + --save-path $LOGDIR/checkpoints_lora \ + --benchmark \ + > $LOGDIR/trainer_lora.log 2>&1 & +LORA_TRAINER_PID=$! +echo " βœ“ Trainer started (PID: $LORA_TRAINER_PID)" + +# ============================================================================== +# Save PIDs and Monitor +# ============================================================================== +cat > $LOGDIR/pids.txt << EOF +LEGACY_TRAINER_PID=$LEGACY_TRAINER_PID +LEGACY_ENV_PID=$LEGACY_ENV_PID +LEGACY_API_PID=$LEGACY_API_PID +SHARED_TRAINER_PID=$SHARED_TRAINER_PID +SHARED_VLLM_PID=$SHARED_VLLM_PID +SHARED_ENV_PID=$SHARED_ENV_PID +SHARED_API_PID=$SHARED_API_PID +LORA_TRAINER_PID=$LORA_TRAINER_PID +LORA_VLLM_PID=$LORA_VLLM_PID +LORA_ENV_PID=$LORA_ENV_PID +LORA_API_PID=$LORA_API_PID +EOF + +echo "" +echo "==============================================" +echo "All components started!" 
+echo "==============================================" +echo "" +echo "πŸ“‚ Log directory: $LOGDIR" +echo "" +echo "πŸ“Š Monitor progress:" +echo " tail -f $LOGDIR/trainer_legacy.log" +echo " tail -f $LOGDIR/trainer_shared.log" +echo " tail -f $LOGDIR/trainer_lora.log" +echo "" +echo "πŸ” Or watch all at once:" +echo " tail -f $LOGDIR/trainer_*.log" +echo "" +echo "πŸ“‹ Check API servers:" +echo " curl http://localhost:8001/info" +echo " curl http://localhost:8002/info" +echo " curl http://localhost:8003/info" +echo "" +echo "πŸ“ Process IDs saved to: $LOGDIR/pids.txt" +echo "" +echo "Waiting for all trainers to complete..." +echo "(This may take a while depending on training steps)" +echo "" + +# Wait for trainers to complete +wait $LEGACY_TRAINER_PID 2>/dev/null && echo " βœ“ Legacy trainer completed" || echo " βœ— Legacy trainer failed" +wait $SHARED_TRAINER_PID 2>/dev/null && echo " βœ“ Shared vLLM trainer completed" || echo " βœ— Shared vLLM trainer failed" +wait $LORA_TRAINER_PID 2>/dev/null && echo " βœ“ LoRA trainer completed" || echo " βœ— LoRA trainer failed" + +# ============================================================================== +# Print Results +# ============================================================================== +echo "" +echo "==============================================" +echo "COMPARISON RESULTS" +echo "==============================================" +echo "" + +echo "πŸ“Š LEGACY MODE:" +echo "─────────────────────────────────────────────" +grep -A 15 "BENCHMARK SUMMARY" $LOGDIR/trainer_legacy.log 2>/dev/null || echo " (check $LOGDIR/trainer_legacy.log)" +echo "" + +echo "πŸ“Š SHARED VLLM MODE:" +echo "─────────────────────────────────────────────" +grep -A 15 "BENCHMARK SUMMARY" $LOGDIR/trainer_shared.log 2>/dev/null || echo " (check $LOGDIR/trainer_shared.log)" +echo "" + +echo "πŸ“Š LORA MODE:" +echo "─────────────────────────────────────────────" +grep -A 15 "BENCHMARK SUMMARY" $LOGDIR/trainer_lora.log 2>/dev/null || echo " (check $LOGDIR/trainer_lora.log)" +echo "" + +echo "==============================================" +echo "πŸ“ ALL OUTPUT SAVED TO: $LOGDIR" +echo "==============================================" +echo "" +echo "πŸ“‹ LOG FILES:" +echo " Trainers (main output):" +echo " $LOGDIR/trainer_legacy.log" +echo " $LOGDIR/trainer_shared.log" +echo " $LOGDIR/trainer_lora.log" +echo "" +echo " vLLM servers:" +echo " $LOGDIR/vllm_shared.log" +echo " $LOGDIR/vllm_lora.log" +echo "" +echo " Environment servers:" +echo " $LOGDIR/env_legacy.log" +echo " $LOGDIR/env_shared.log" +echo " $LOGDIR/env_lora.log" +echo "" +echo " API servers:" +echo " $LOGDIR/api_legacy.log" +echo " $LOGDIR/api_shared.log" +echo " $LOGDIR/api_lora.log" +echo "" +echo "πŸ’Ύ CHECKPOINTS:" +echo " Legacy: $LOGDIR/checkpoints_legacy/final_model/" +echo " Shared vLLM: $LOGDIR/checkpoints_shared/final_model/" +echo " LoRA: $LOGDIR/checkpoints_lora/final_adapter/" +echo "" +echo "πŸ”§ OTHER:" +echo " Process IDs: $LOGDIR/pids.txt" +echo " IPC Config: $LOGDIR/vllm_bridge_config.json" +echo "==============================================" +echo "" +echo "To re-run or inspect later:" +echo " export LOGDIR=$LOGDIR" +echo " tail -f \$LOGDIR/trainer_*.log" +echo "" +echo "Done!" diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py new file mode 100644 index 00000000..8cfc9251 --- /dev/null +++ b/example_trainer/trainers.py @@ -0,0 +1,438 @@ +""" +Training mode implementations for GRPO trainer. 
+
+Contains the three main training modes:
+- train_legacy: Checkpoint-based training with vLLM restarts
+- train_shared_vllm: Single-copy mode with CUDA IPC
+- train_lora: LoRA adapter training with hot-swap
+"""
+
+import os
+import time
+from typing import Optional
+
+import requests
+import torch
+from torch.optim import AdamW
+
+from .api import check_atropos_api, register_trainer
+from .checkpointing import save_checkpoint, save_lora_checkpoint
+from .config import TrainingConfig
+from .data import get_data
+from .model import PEFT_AVAILABLE, load_model_and_tokenizer
+from .training import (
+    finalize_training,
+    log_metrics,
+    run_training_step,
+    setup_wandb,
+)
+from .vllm_manager import (
+    check_vllm_health,
+    check_vllm_process_health,
+    launch_vllm_server,
+    set_vllm_process,
+    terminate_vllm_process,
+)
+
+
+def train_legacy(config: TrainingConfig):
+    """
+    Legacy GRPO training with periodic vLLM restarts.
+
+    This mode:
+    1. Trains the model on the trainer GPU
+    2. Saves checkpoints periodically
+    3. Restarts vLLM to load the new weights
+
+    Use for:
+    - Simple setups
+    - When the trainer and vLLM run on different GPUs
+    """
+    training_start_time = time.time()
+
+    # === Setup ===
+    use_wandb = setup_wandb(config)
+    model, tokenizer = load_model_and_tokenizer(config)
+    optimizer = AdamW(model.parameters(), lr=config.lr)
+
+    print(f"\n{'='*60}")
+    print("LEGACY MODE (checkpoint + vLLM restart)")
+    print(f"{'='*60}")
+    print(f"Training for {config.training_steps} steps on {config.device}")
+    print(f"vLLM restart interval: every {config.vllm_restart_interval} steps")
+    print(f"Save path: {config.save_path}")
+    print(f"{'='*60}\n")
+
+    os.makedirs(config.save_path, exist_ok=True)
+
+    # Check Atropos API
+    if not check_atropos_api(url=config.atropos_url, timeout=30):
+        raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}")
+    register_trainer(config)
+
+    # Launch initial vLLM server
+    vllm_proc = launch_vllm_server(config, config.model_name)
+    set_vllm_process(vllm_proc)
+
+    # === Benchmark tracking ===
+    benchmark_stats = {
+        "step_times": [],
+        "sync_times": [],
+        "data_fetch_times": [],
+        "gpu_memories": [],
+    }
+
+    # === Training Loop ===
+    batches = []
+    for step in range(config.training_steps):
+        print(f"\nStep {step+1}/{config.training_steps}")
+
+        # Fetch data
+        data_fetch_start = time.time()
+        if len(batches) == 0:
+            batches = get_data(config.batch_size, config.seq_len, config.atropos_url)
+        token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0)
+        data_fetch_time = time.time() - data_fetch_start
+        benchmark_stats["data_fetch_times"].append(data_fetch_time)
+
+        # Check if we should sync (save checkpoint + restart vLLM);
+        # terminate vLLM first so its GPU memory is free during the step
+        should_sync = (step + 1) % config.vllm_restart_interval == 0 or step == config.training_steps - 1
+        if should_sync:
+            terminate_vllm_process()
+
+        # Training step
+        step_start = time.time()
+        metrics = run_training_step(
+            model, optimizer,
+            token_batches, label_batches, advantage_batches, temperature_batches,
+            config,
+        )
+        step_time = time.time() - step_start
+        benchmark_stats["step_times"].append(step_time)
+
+        # GPU memory tracking
+        gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0
+        gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
+        benchmark_stats["gpu_memories"].append(gpu_mem_gb)
+
+        # Sync (checkpoint + restart)
+        sync_time = 0.0
+        if should_sync:
+            sync_start = time.time()
+            checkpoint_path = save_checkpoint(model, tokenizer, 
config.save_path, step + 1) + torch.cuda.empty_cache() + vllm_proc = launch_vllm_server(config, checkpoint_path) + set_vllm_process(vllm_proc) + sync_time = time.time() - sync_start + benchmark_stats["sync_times"].append(sync_time) + + # Update metrics + metrics.update({ + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + }) + + log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) + check_vllm_process_health() + + # === Cleanup === + save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True) + finalize_training(use_wandb, training_start_time, "legacy", config.training_steps, benchmark_stats, config.benchmark) + + +def train_shared_vllm(config: TrainingConfig): + """ + GRPO training with shared vLLM weights (single-copy mode). + + This mode: + 1. Attaches to vLLM's weight tensors via CUDA IPC + 2. optimizer.step() modifies vLLM's weights in-place + 3. vLLM immediately uses updated weights (no restart!) + + Requirements: + - vLLM running with VLLM_ENABLE_SHARED_WEIGHTS=1 + - Trainer on same GPU(s) as vLLM + """ + training_start_time = time.time() + + # === Setup === + use_wandb = setup_wandb(config) + + print(f"\n{'='*60}") + print("SINGLE-COPY MODE (CUDA IPC)") + print(">>> TRUE shared memory - only ONE model copy!") + print(">>> Trainer uses vLLM's tensors directly!") + print(f"{'='*60}") + print(f"Model: {config.model_name}") + print(f"Save path: {config.save_path}") + print(f"{'='*60}\n") + + # Attach to vLLM's shared tensors + print("[1/2] Attaching to vLLM's shared tensors...") + model, tokenizer = load_model_and_tokenizer(config, single_copy=True) + + if model is None: + raise RuntimeError( + "Single-copy mode failed. Make sure:\n" + "1. vLLM is running with VLLM_ENABLE_SHARED_WEIGHTS=1\n" + "2. Trainer is on the SAME GPUs as vLLM\n" + "3. 
vllm_bridge_config.json exists with IPC handles" + ) + + optimizer = AdamW(model.parameters(), lr=config.lr) + + print(f"[2/2] Starting training for {config.training_steps} steps") + print("NOTE: vLLM sees weight updates immediately after each step!") + print("-" * 60) + + os.makedirs(config.save_path, exist_ok=True) + + # Check Atropos API + print(f"\n[Setup] Connecting to Atropos API at {config.atropos_url}...") + if not check_atropos_api(url=config.atropos_url, timeout=30): + raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}") + register_trainer(config) + + # === Benchmark tracking === + benchmark_stats = { + "step_times": [], + "sync_times": [], + "data_fetch_times": [], + "gpu_memories": [], + } + + # === Training Loop === + batches = [] + for step in range(config.training_steps): + print(f"\nStep {step+1}/{config.training_steps}") + + # Fetch data + data_fetch_start = time.time() + if len(batches) == 0: + batches = get_data(config.batch_size, config.seq_len, config.atropos_url) + token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0) + data_fetch_time = time.time() - data_fetch_start + benchmark_stats["data_fetch_times"].append(data_fetch_time) + + # Training step + step_start = time.time() + metrics = run_training_step( + model, optimizer, + token_batches, label_batches, advantage_batches, temperature_batches, + config, + ) + step_time = time.time() - step_start + benchmark_stats["step_times"].append(step_time) + + # GPU memory tracking + gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + benchmark_stats["gpu_memories"].append(gpu_mem_gb) + + # In single-copy mode, weights are updated in-place (no sync needed!) + sync_time = 0.0 + print(f" [SINGLE-COPY] Weights updated in-place - step {step+1}") + benchmark_stats["sync_times"].append(sync_time) + + # Update metrics + metrics.update({ + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + }) + + log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) + + # Periodic checkpoint (for recovery, not for vLLM sync) + if (step + 1) % config.vllm_restart_interval == 0: + save_checkpoint(model, tokenizer, config.save_path, step + 1) + + # === Cleanup === + save_checkpoint(model, tokenizer, config.save_path, config.training_steps, is_final=True) + finalize_training(use_wandb, training_start_time, "shared_vllm", config.training_steps, benchmark_stats, config.benchmark) + + +def train_lora(config: TrainingConfig): + """ + GRPO training with LoRA adapters. + + This mode: + 1. Freezes base model, trains only LoRA adapter weights + 2. Saves lightweight adapter checkpoints + 3. Hot-swaps adapters in vLLM via API + + Benefits: + - Much faster training (fewer parameters) + - Smaller checkpoints + - Adapters can be hot-swapped without restart + + Requirements: + - External vLLM server running with --enable-lora + """ + if not PEFT_AVAILABLE: + raise RuntimeError("PEFT library required for LoRA mode. 
Install with: pip install peft") + + training_start_time = time.time() + + # === Setup === + use_wandb = setup_wandb(config) + + print(f"\n{'='*60}") + print("LORA MODE (adapter-only training)") + print(f"{'='*60}") + print(f"Base model: {config.model_name}") + print(f"LoRA config: r={config.lora_r}, alpha={config.lora_alpha}") + print(f"Save path: {config.save_path}") + print(f"vLLM port: {config.vllm_port}") + print(f"{'='*60}\n") + + # Check external vLLM server + print("[1/3] Checking external vLLM server...") + if not check_vllm_health(config.vllm_port): + print(f"\nERROR: vLLM server not running on port {config.vllm_port}") + print("\nLoRA mode requires an external vLLM server. Start it first:") + print(f" python example_trainer/vllm_api_server.py --model {config.model_name} " + f"--port {config.vllm_port} --enable-lora --enforce-eager") + raise RuntimeError(f"External vLLM server required on port {config.vllm_port}") + print(f"vLLM server healthy on port {config.vllm_port}") + + # Load model with LoRA adapters + print("[2/3] Loading model with LoRA adapters...") + model, tokenizer = load_model_and_tokenizer(config) + + # Only optimize LoRA parameters + trainable_params = [p for p in model.parameters() if p.requires_grad] + optimizer = AdamW(trainable_params, lr=config.lr) + + print(f"[3/3] Starting training for {config.training_steps} steps") + print("-" * 60) + + os.makedirs(config.save_path, exist_ok=True) + + # Check Atropos API + if not check_atropos_api(url=config.atropos_url, timeout=30): + raise RuntimeError(f"Atropos API not reachable at {config.atropos_url}") + register_trainer(config) + + # === Benchmark tracking === + benchmark_stats = { + "step_times": [], + "sync_times": [], + "data_fetch_times": [], + "gpu_memories": [], + } + + # === Training Loop === + batches = [] + for step in range(config.training_steps): + print(f"\nStep {step+1}/{config.training_steps}") + + # Fetch data + data_fetch_start = time.time() + if len(batches) == 0: + batches = get_data(config.batch_size, config.seq_len, config.atropos_url) + token_batches, label_batches, advantage_batches, temperature_batches = batches.pop(0) + data_fetch_time = time.time() - data_fetch_start + benchmark_stats["data_fetch_times"].append(data_fetch_time) + + # Training step + step_start = time.time() + metrics = run_training_step( + model, optimizer, + token_batches, label_batches, advantage_batches, temperature_batches, + config, + ) + step_time = time.time() - step_start + benchmark_stats["step_times"].append(step_time) + + # GPU memory tracking + gpu_mem_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + gpu_mem_reserved_gb = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 + benchmark_stats["gpu_memories"].append(gpu_mem_gb) + + # Periodic adapter save + hot-swap + sync_time = 0 + should_sync = (step + 1) % config.vllm_restart_interval == 0 + if should_sync: + sync_start = time.time() + adapter_path = save_lora_checkpoint(model, config.save_path, step + 1) + _hotswap_lora_adapter(config.vllm_port, adapter_path, f"step_{step + 1}") + sync_time = time.time() - sync_start + benchmark_stats["sync_times"].append(sync_time) + + # Update metrics + metrics.update({ + "step_time": step_time, + "sync_time": sync_time, + "data_fetch_time": data_fetch_time, + "gpu_memory_gb": gpu_mem_gb, + "gpu_memory_reserved_gb": gpu_mem_reserved_gb, + }) + + log_metrics(metrics, step + 1, use_wandb, benchmark=config.benchmark) + + # === Cleanup === + final_sync_start = time.time() + 
final_adapter_path = save_lora_checkpoint(model, config.save_path, config.training_steps, is_final=True) + _hotswap_lora_adapter(config.vllm_port, final_adapter_path, "final") + final_sync_time = time.time() - final_sync_start + benchmark_stats["sync_times"].append(final_sync_time) + + finalize_training(use_wandb, training_start_time, "lora_only", config.training_steps, benchmark_stats, config.benchmark) + + # Save tokenizer + tokenizer_path = os.path.join(config.save_path, "tokenizer") + tokenizer.save_pretrained(tokenizer_path) + print(f"Tokenizer saved to {tokenizer_path}") + + +def _hotswap_lora_adapter( + port: int, + adapter_path: str, + adapter_name: Optional[str] = None, +) -> bool: + """ + Request vLLM to hot-swap to a new LoRA adapter. + + Tries: + 1. Native vLLM endpoint: /v1/load_lora_adapter + 2. Custom endpoint: /lora/load + """ + base_url = f"http://localhost:{port}" + name = adapter_name or os.path.basename(adapter_path) + + # Try native vLLM endpoint first + try: + response = requests.post( + f"{base_url}/v1/load_lora_adapter", + json={"lora_name": name, "lora_path": adapter_path}, + timeout=30, + ) + if response.status_code == 200: + print(f" [LORA] βœ“ Hot-swapped adapter: {name}") + return True + except Exception: + pass + + # Try custom endpoint + try: + response = requests.post( + f"{base_url}/lora/load", + json={"adapter_path": adapter_path, "adapter_name": name}, + timeout=30, + ) + if response.status_code == 200: + print(f" [LORA] βœ“ Hot-swapped adapter via custom API: {name}") + return True + else: + print(f" [LORA] βœ— Hot-swap failed: {response.text}") + return False + except Exception as e: + print(f" [LORA] βœ— Hot-swap request failed: {e}") + return False + diff --git a/example_trainer/training.py b/example_trainer/training.py new file mode 100644 index 00000000..860c15db --- /dev/null +++ b/example_trainer/training.py @@ -0,0 +1,355 @@ +""" +Training utilities for GRPO trainer. + +Contains loss computation, training step logic, and metric logging. +""" + +import random +import string +import time +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +import wandb + +from .config import TrainingConfig + + +def setup_wandb(config: TrainingConfig) -> bool: + """ + Initialize Weights & Biases logging if enabled. + + Args: + config: Training configuration + + Returns: + True if wandb is active, False otherwise + """ + if not config.use_wandb: + return False + + if not config.wandb_project: + print("Warning: wandb_project not set, disabling wandb.") + return False + + # Generate random group name if not provided + if not config.wandb_group: + config.wandb_group = "".join( + random.choices(string.ascii_letters + string.digits, k=8) + ) + + try: + wandb.init( + project=config.wandb_project, + group=config.wandb_group, + config=config.dict(), + ) + print( + f"Wandb logging enabled. Run: {wandb.run.name} " + f"(Project: {config.wandb_project})" + ) + return True + except Exception as e: + print(f"Error initializing wandb: {e}. Disabling wandb.") + return False + + +def compute_grpo_loss( + model: torch.nn.Module, + tokens: torch.Tensor, + labels: torch.Tensor, + advantages: torch.Tensor, + temperatures: torch.Tensor, + gradient_accumulation_steps: int, +) -> Tuple[torch.Tensor, dict]: + """ + Compute GRPO (Group Relative Policy Optimization) loss for a single micro-batch. 
+
+    The GRPO loss encourages the model to:
+    - Increase probability for tokens with positive advantages
+    - Decrease probability for tokens with negative advantages
+
+    Args:
+        model: The model to compute loss for
+        tokens: Input token IDs [batch, seq_len]
+        labels: Target labels [batch, seq_len], -100 for masked positions
+        advantages: Advantage values [batch, 1]
+        temperatures: Temperature values [batch, 1, 1]
+        gradient_accumulation_steps: Number of accumulation steps (for scaling)
+
+    Returns:
+        Tuple of (loss tensor, metrics dict)
+    """
+    # Forward pass
+    outputs = model(tokens)
+    logits = outputs.logits
+
+    # Temperature scaling (non-positive temperatures are treated as 1.0)
+    t = temperatures.to(logits.device, logits.dtype)
+    t = torch.where(t <= 0, torch.ones_like(t), t)
+    logits = logits / t
+
+    # Log probabilities per token
+    logp_per_token = -F.cross_entropy(
+        logits.view(-1, logits.size(-1)),
+        labels.view(-1),
+        reduction="none",
+        ignore_index=-100,
+    ).view(labels.shape)
+
+    # Masking based on labels != -100
+    mask = (labels != -100).float()
+
+    # Compute metrics (no grad needed)
+    with torch.no_grad():
+        pos = (advantages > 0).float()
+        neg = (advantages <= 0).float()
+        mask_float = mask.to(logp_per_token.dtype)
+        mask_sum = mask_float.sum(dim=-1).clamp_min(1e-8)
+        avg_logp = ((logp_per_token * mask_float).sum(dim=-1) / mask_sum).mean().item()
+        pos_logp = (logp_per_token * pos).mean().item()
+        neg_logp = (logp_per_token * neg).mean().item()
+
+    # GRPO loss: advantage-weighted mean log probability per sequence.
+    # exp(logp - logp.detach()) is exactly 1.0 in the forward pass, but its
+    # gradient equals the gradient of logp (the importance-ratio trick).
+    grpo_loss_term = torch.exp(logp_per_token - logp_per_token.detach())
+    seq_loss = (-grpo_loss_term * mask).sum(-1) / mask.sum(-1).clamp_min(1e-8)
+    # advantages is [batch, 1]; flatten so the product stays [batch] rather
+    # than broadcasting to [batch, batch]
+    grpo_loss = (
+        seq_loss * advantages.to(logp_per_token.device).view(-1)
+    ).mean() / gradient_accumulation_steps
+
+    metrics = {
+        "pos_logp": pos_logp,
+        "neg_logp": neg_logp,
+        "avg_logp": avg_logp,
+        "pos_count": pos.sum().item(),
+        "neg_count": neg.sum().item(),
+    }
+
+    return grpo_loss, metrics
+
+
+def run_training_step(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    token_batches: List[torch.Tensor],
+    label_batches: List[torch.Tensor],
+    advantage_batches: List[torch.Tensor],
+    temperature_batches: List[torch.Tensor],
+    config: TrainingConfig,
+) -> dict:
+    """
+    Run a single training step with gradient accumulation.
+
+    Performs:
+    1. Forward pass through all micro-batches
+    2. Backward pass with gradient accumulation
+    3. Gradient clipping
+    4. 
Optimizer step + + Args: + model: The model to train + optimizer: The optimizer + token_batches: List of token tensors (micro-batches) + label_batches: List of label tensors + advantage_batches: List of advantage tensors + temperature_batches: List of temperature tensors + config: Training configuration + + Returns: + Dict of training metrics for this step + """ + total_loss = 0.0 + total_pos_logp = 0.0 + total_neg_logp = 0.0 + total_pos = 0.0 + total_neg = 0.0 + grad_norm = 0.0 + + # Accumulate gradients over micro-batches + for tokens, labels, advantages, temperatures in zip( + token_batches, label_batches, advantage_batches, temperature_batches + ): + tokens = tokens.to(config.device) + labels = labels.to(config.device) + advantages = advantages.to(config.device) + + loss, metrics = compute_grpo_loss( + model, + tokens, + labels, + advantages, + temperatures, + config.gradient_accumulation_steps, + ) + + loss.backward() + total_loss += loss.item() + total_pos_logp += metrics["pos_logp"] + total_neg_logp += metrics["neg_logp"] + total_pos += metrics["pos_count"] + total_neg += metrics["neg_count"] + + # Gradient clipping and optimizer step + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + optimizer.zero_grad() + + # Normalize metrics by count + num_batches = len(token_batches) if token_batches else 1 + if total_pos > 0: + total_pos_logp /= num_batches + if total_neg > 0: + total_neg_logp /= num_batches + + return { + "loss": total_loss, + "grad_norm": grad_norm.item() if hasattr(grad_norm, 'item') else grad_norm, + "pos_logp": total_pos_logp, + "neg_logp": total_neg_logp, + "pos_count": total_pos, + "neg_count": total_neg, + } + + +def log_metrics( + metrics: dict, + step: int, + use_wandb: bool, + extra_metrics: Optional[dict] = None, + benchmark: bool = False, +) -> None: + """ + Log training metrics to console and optionally wandb. 
+ + Args: + metrics: Dict of metrics from training step + step: Current step number + use_wandb: Whether to log to wandb + extra_metrics: Optional additional metrics to log + benchmark: Whether to show timing/benchmark info + """ + # Build timing string (only if benchmark enabled) + timing_str = "" + if benchmark: + if "step_time" in metrics: + timing_str += f", Step time: {metrics['step_time']:.2f}s" + if "sync_time" in metrics and metrics["sync_time"] > 0: + timing_str += f", Sync time: {metrics['sync_time']:.2f}s" + if "data_fetch_time" in metrics: + timing_str += f", Data fetch: {metrics['data_fetch_time']:.2f}s" + if "gpu_memory_gb" in metrics: + timing_str += f", GPU mem: {metrics['gpu_memory_gb']:.2f}GB" + + # Show loss with more precision since GRPO loss is often very small + loss_str = ( + f"{metrics['loss']:.6f}" + if abs(metrics["loss"]) < 0.01 + else f"{metrics['loss']:.4f}" + ) + print(f" Loss: {loss_str}, Grad norm: {metrics['grad_norm']:.4f}{timing_str}") + + # Show GRPO-specific metrics if available + if "pos_count" in metrics or "neg_count" in metrics: + pos_count = metrics.get("pos_count", 0) + neg_count = metrics.get("neg_count", 0) + pos_logp = metrics.get("pos_logp", 0) + neg_logp = metrics.get("neg_logp", 0) + print( + f" Advantages: +{int(pos_count)} / -{int(neg_count)}, " + f"LogP: pos={pos_logp:.3f}, neg={neg_logp:.3f}" + ) + + if use_wandb: + log_dict = { + "train/loss": metrics["loss"], + "train/grad_norm": metrics["grad_norm"], + "train/pos_logp": metrics.get("pos_logp", 0), + "train/neg_logp": metrics.get("neg_logp", 0), + } + # Add timing metrics if present + for key in ["step_time", "sync_time", "data_fetch_time", + "gpu_memory_gb", "gpu_memory_reserved_gb"]: + if key in metrics: + log_dict[f"train/{key}"] = metrics[key] + if extra_metrics: + log_dict.update(extra_metrics) + wandb.log(log_dict, step=step) + + +def finalize_training( + use_wandb: bool, + training_start_time: Optional[float] = None, + mode: str = "unknown", + total_steps: int = 0, + benchmark_stats: Optional[dict] = None, + benchmark: bool = False, +) -> None: + """ + Clean up after training and log benchmark summary. 
+ + Args: + use_wandb: Whether wandb is enabled + training_start_time: Start time of training + mode: Training mode name + total_steps: Total steps completed + benchmark_stats: Dict with lists of per-step metrics + benchmark: Whether to print benchmark summary to console + """ + print("\nTraining finished.") + + if benchmark_stats is None: + benchmark_stats = {} + + if training_start_time is not None: + total_time = time.time() - training_start_time + peak_gpu_mem_gb = ( + torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0 + ) + + # Calculate averages from collected stats + step_times = benchmark_stats.get("step_times", []) + sync_times = benchmark_stats.get("sync_times", []) + data_fetch_times = benchmark_stats.get("data_fetch_times", []) + gpu_memories = benchmark_stats.get("gpu_memories", []) + + avg_step_time = sum(step_times) / len(step_times) if step_times else 0 + total_step_time = sum(step_times) + avg_sync_time = sum(sync_times) / len(sync_times) if sync_times else 0 + total_sync_time = sum(sync_times) + avg_data_fetch = sum(data_fetch_times) / len(data_fetch_times) if data_fetch_times else 0 + total_data_fetch = sum(data_fetch_times) + avg_gpu_mem = sum(gpu_memories) / len(gpu_memories) if gpu_memories else 0 + + if benchmark: + print(f"\n{'='*70}") + print(f"BENCHMARK SUMMARY ({mode})") + print(f"{'='*70}") + print(f" Total training time: {total_time:.2f}s ({total_time/60:.2f} min)") + print(f" Total steps: {total_steps}") + print(" ") + print(" TIMING BREAKDOWN:") + print(f" Avg step time: {avg_step_time:.2f}s") + print(f" Total step time: {total_step_time:.2f}s") + print(f" Avg sync time: {avg_sync_time:.2f}s (x{len(sync_times)} syncs)") + print(f" Total sync time: {total_sync_time:.2f}s") + print(f" Avg data fetch time: {avg_data_fetch:.2f}s") + print(f" Total data fetch time: {total_data_fetch:.2f}s") + print(" ") + print(" MEMORY:") + print(f" Peak GPU memory: {peak_gpu_mem_gb:.2f} GB") + print(f" Avg GPU memory: {avg_gpu_mem:.2f} GB") + print(f"{'='*70}\n") + + if use_wandb: + wandb.summary["benchmark/total_time_seconds"] = total_time + wandb.summary["benchmark/total_time_minutes"] = total_time / 60 + wandb.summary["benchmark/mode"] = mode + wandb.summary["benchmark/total_steps"] = total_steps + wandb.summary["benchmark/avg_step_time_seconds"] = avg_step_time + wandb.summary["benchmark/peak_gpu_memory_gb"] = peak_gpu_mem_gb + wandb.summary["benchmark/avg_gpu_memory_gb"] = avg_gpu_mem + wandb.finish() + elif use_wandb: + wandb.finish() + diff --git a/example_trainer/vllm_manager.py b/example_trainer/vllm_manager.py new file mode 100644 index 00000000..1942977b --- /dev/null +++ b/example_trainer/vllm_manager.py @@ -0,0 +1,232 @@ +""" +vLLM process management for GRPO trainer. + +Handles launching, monitoring, and terminating vLLM server processes +for legacy mode training. 
+""" + +import atexit +import os +import subprocess +import time +from typing import Optional + +import requests + +from .config import TrainingConfig + + +# Global variable to keep track of the vLLM process +_vllm_process: Optional[subprocess.Popen] = None + + +def cleanup_vllm(): + """Cleanup function to terminate vLLM on exit.""" + global _vllm_process + if _vllm_process: + print("\nTerminating vLLM process...") + _vllm_process.terminate() + try: + _vllm_process.wait(timeout=5) + print("vLLM process terminated.") + except subprocess.TimeoutExpired: + print("vLLM process did not terminate gracefully, killing.") + _vllm_process.kill() + _vllm_process.wait() + print("vLLM process killed.") + _vllm_process = None + + +# Register cleanup on module load +atexit.register(cleanup_vllm) + + +def launch_vllm_server( + config: TrainingConfig, + model_path: str, +) -> Optional[subprocess.Popen]: + """ + Launch a vLLM server process using our custom vllm_api_server.py. + + Uses the custom server instead of standard vLLM because: + - Standard vLLM only has /v1/completions (OpenAI-compatible) + - Our custom server has /generate endpoint needed by VLLMServer class + - This allows proper tokens_and_logprobs_completion support + + Args: + config: Training configuration + model_path: Path to model checkpoint + + Returns: + Popen process object, or None if launch failed + """ + global _vllm_process + + # Use our custom vllm_api_server.py + script_dir = os.path.dirname(os.path.abspath(__file__)) + custom_server_path = os.path.join(script_dir, "vllm_api_server.py") + + vllm_command = [ + "python", + custom_server_path, + "--model", + model_path, + "--port", + str(config.vllm_port), + "--gpu-memory-utilization", + str(config.vllm_gpu_memory_utilization), + ] + + # Add served-model-name if using checkpoint path + if model_path != config.model_name: + vllm_command.extend(["--served-model-name", config.model_name]) + + print(f" Launching vLLM: {' '.join(vllm_command)}") + + try: + proc = subprocess.Popen(vllm_command) + print(f" vLLM launched with PID: {proc.pid}") + + # Check for immediate startup errors + try: + proc.communicate(timeout=2) + if proc.returncode is not None and proc.returncode != 0: + print(" WARNING: vLLM failed to start") + return None + except subprocess.TimeoutExpired: + print(" vLLM process started (check logs for details)") + + _vllm_process = proc + return proc + + except FileNotFoundError: + print(" ERROR: vLLM not found. 
Is it installed?") + return None + except Exception as e: + print(f" ERROR launching vLLM: {e}") + return None + + +def terminate_vllm_process() -> None: + """Terminate the running vLLM process if any.""" + global _vllm_process + + if _vllm_process is None: + return + + print(" Terminating vLLM process...") + _vllm_process.terminate() + try: + _vllm_process.wait(timeout=5) + except subprocess.TimeoutExpired: + print(" vLLM did not terminate gracefully, killing...") + _vllm_process.kill() + _vllm_process.wait() + _vllm_process = None + + +def check_vllm_process_health() -> None: + """Check if vLLM process terminated unexpectedly.""" + global _vllm_process + + if _vllm_process is not None and _vllm_process.poll() is not None: + print(f" WARNING: vLLM terminated unexpectedly (code: {_vllm_process.returncode})") + _vllm_process = None + + +def get_vllm_process() -> Optional[subprocess.Popen]: + """Get the current vLLM process.""" + return _vllm_process + + +def set_vllm_process(proc: Optional[subprocess.Popen]) -> None: + """Set the vLLM process (for external management).""" + global _vllm_process + _vllm_process = proc + + +def check_vllm_health(port: int) -> bool: + """ + Check if vLLM server is healthy and responding. + + Args: + port: Port the vLLM server is running on + + Returns: + True if server is healthy + """ + try: + response = requests.get(f"http://localhost:{port}/health", timeout=5) + return response.status_code == 200 + except Exception: + return False + + +def wait_for_vllm_ready(port: int, timeout: float = 120.0) -> bool: + """ + Wait for vLLM server to be ready. + + Args: + port: Port the vLLM server is running on + timeout: Maximum time to wait in seconds + + Returns: + True if server is ready, False if timeout + """ + print(f" Waiting for vLLM to be ready (port {port})...") + start_time = time.time() + + while time.time() - start_time < timeout: + if check_vllm_health(port): + print(" vLLM is ready!") + return True + time.sleep(2) + + print(f" WARNING: vLLM not ready after {timeout}s") + return False + + +def hotswap_lora_adapter( + adapter_name: str, + adapter_path: str, + port: int, +) -> bool: + """ + Hot-swap a LoRA adapter on a running vLLM server. + + Uses the vLLM /v1/load_lora_adapter endpoint to load a new adapter + without restarting the server. + + Args: + adapter_name: Name to identify the adapter + adapter_path: Path to the adapter checkpoint + port: vLLM server port + + Returns: + True if hot-swap succeeded + """ + try: + # Use vLLM's native LoRA loading endpoint + response = requests.post( + f"http://localhost:{port}/v1/load_lora_adapter", + json={ + "lora_name": adapter_name, + "lora_path": adapter_path, + }, + timeout=30, + ) + + if response.status_code == 200: + print(f" [LORA] βœ“ Hot-swapped adapter: {adapter_name} ({adapter_path})") + return True + else: + print(f" [LORA] βœ— Hot-swap failed: {response.status_code} - {response.text}") + return False + + except requests.exceptions.ConnectionError: + print(f" [LORA] βœ— Cannot connect to vLLM at port {port}") + return False + except Exception as e: + print(f" [LORA] βœ— Error during hot-swap: {e}") + return False +
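+
+
+if __name__ == "__main__":
+    # Illustrative smoke test (a sketch, not part of the trainer): check
+    # whether a vLLM server is already up on a port; the default port 9001
+    # below is an assumption taken from the comparison script.
+    import sys
+
+    _port = int(sys.argv[1]) if len(sys.argv) > 1 else 9001
+    if wait_for_vllm_ready(_port, timeout=30.0):
+        print(f"vLLM server on port {_port} is healthy")
+    else:
+        print(f"No healthy vLLM server found on port {_port}")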