diff --git a/example_trainer/README.md b/example_trainer/README.md index 6da70c93..b5e7c9a7 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -9,12 +9,14 @@ The trainer supports three weight synchronization modes: | Mode | Description | Sync Latency | Best For | |------|-------------|--------------|----------| | **Legacy** (`none`) | Save checkpoints, restart vLLM | ~30-60 seconds | Simple setups, debugging | -| **Shared vLLM** (`shared_vllm`) | Direct shared memory updates | ~0 ms | Production, maximum throughput | +| **Shared vLLM** (`shared_vllm`) | Direct shared memory updates via NCCL | ~0 ms | Production, maximum throughput | | **LoRA** (`lora_only`) | Train adapters, hot-swap | ~1-5 seconds | Memory-constrained, fast iteration | --- -## Quick Start with GSM8k +## Quick Start with GSM8k (Shared vLLM Mode) + +This is the **recommended** production setup for maximum training throughput. ### Prerequisites @@ -29,892 +31,455 @@ pip install datasets latex2sympy2_extended math_verify ### Architecture Overview ``` -┌─────────────────────────────────────────────────────────────────┐ -│ Training Setup │ -│ │ -│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │ -│ │ GSM8k Env │───▶│ Atropos API │◀───│ GRPO Trainer │ │ -│ │ (problems) │ │ (batching) │ │ (optimization) │ │ -│ └─────────────┘ └─────────────┘ └─────────────────────┘ │ -│ │ │ │ -│ │ │ │ -│ ▼ ▼ │ -│ ┌─────────────────────────────────────────────────────────┐ │ -│ │ vLLM Inference Server │ │ -│ │ (generates rollouts for scoring) │ │ -│ └─────────────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ SHARED VLLM TRAINING ARCHITECTURE │ +│ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────────────┐ │ +│ │ GSM8k Env │───▶│ Atropos API │◀───│ GRPO Trainer (GPU 2) │ │ +│ │ (problems) │ │ (batching) │ │ - Loads model for training │ │ +│ └─────────────┘ └─────────────┘ │ - Broadcasts weights via NCCL │ │ +│ │ └─────────────────────────────────┘ │ +│ │ │ │ +│ │ │ NCCL Broadcast │ +│ ▼ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ vLLM Inference Server (GPUs 0-1) │ │ +│ │ - Model weights in shared memory │ │ +│ │ - Weight updater threads receive NCCL updates │ │ +│ │ - Generates rollouts for scoring │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────────────────┘ ``` -### IMPORTANT: Startup Order (Same for ALL Modes!) +### Step-by-Step Guide (Tested & Working) -All three training modes use the **same startup order**: - -``` -1. Atropos API (python -m atroposlib.cli.run_api) - ↓ -2. vLLM Server (python example_trainer/vllm_api_server.py) - ↓ wait 60-90s for model load -3. GRPO Trainer (python example_trainer/grpo.py) - ↓ -4. GSM8k Environment (python environments/gsm8k_server.py serve ...) -``` - -**Why this order?** -- The API must be running before the trainer or environment tries to connect -- vLLM must be loaded before GSM8k tries to generate rollouts -- The trainer must be running before GSM8k sends scored batches -- GSM8k is started last because it immediately begins generating work - -### GSM8k CLI Arguments (Required for All Modes) - -When starting the GSM8k environment, always include these arguments: - -```bash -python environments/gsm8k_server.py serve \ - --slurm False \ - --openai.model_name Qwen/Qwen2.5-3B-Instruct \ - --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm \ - --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct -``` - -| Argument | Required | Description | -|----------|----------|-------------| -| `--slurm False` | Yes | Disable SLURM mode for local runs | -| `--openai.model_name` | Yes | Model name (must match vLLM) | -| `--openai.base_url` | Yes | vLLM server URL with `/v1` suffix | -| `--openai.server_type vllm` | **Yes** | Must be `vllm` for `/generate` endpoint | -| `--env.tokenizer_name` | Yes | Tokenizer for environment | - -**Note:** `--openai.server_type vllm` is required because only the `VLLMServer` class supports `tokens_and_logprobs_completion` which GSM8k needs. +**IMPORTANT: GPU Allocation** +- vLLM runs on GPUs 0-1 (tensor-parallel) +- Trainer runs on GPU 2 (separate to avoid OOM) --- -## Mode 1: Legacy (Checkpoint + Restart) +#### Step 1: Kill Any Existing Processes + +```bash +pkill -9 -u $USER -f "vllm|grpo|python|run-api" 2>/dev/null; sleep 3 +``` + +#### Step 2: Setup Directory + +```bash +cd ~/atropos_stuff/atropos +rm -f vllm_bridge_config.json vllm.log trainer.log api.log gsm8k.log +``` + +#### Step 3: Set Environment Variables + +```bash +export VLLM_ENABLE_SHARED_WEIGHTS=1 +export NUM_INFERENCE_NODES=0 +export MASTER_ADDR=localhost +export MASTER_PORT=29500 +``` + +#### Step 4: Start Atropos API -This mode saves checkpoints periodically and can restart vLLM with updated weights. - -### Startup Order (Same for All Modes!) - -``` -┌────────────────────────────────────────────────────────────────┐ -│ 1. Atropos API → Coordinates environments + trainer │ -│ 2. vLLM Server → Serves inference requests │ -│ 3. GRPO Trainer → Trains model, fetches batches │ -│ 4. GSM8k Environment → Generates problems, scores rollouts │ -└────────────────────────────────────────────────────────────────┘ -``` - -### Step-by-Step Guide - -**Step 1: Start the Atropos API** ```bash -cd atropos python -m atroposlib.cli.run_api > api.log 2>&1 & -sleep 5 +echo "Atropos API started" +sleep 3 ``` -**Step 2: Start the vLLM Server** +#### Step 5: Start GSM8K Environment + ```bash -cd atropos -python example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-3B-Instruct \ - --port 9001 \ - --gpu-memory-utilization 0.30 \ - > vllm.log 2>&1 & -sleep 90 # Wait for model to load - -# Verify vLLM is ready -curl -s http://localhost:9001/health && echo "vLLM ready!" +python environments/gsm8k_server.py > gsm8k.log 2>&1 & +echo "GSM8K environment started" +sleep 3 ``` -**Step 3: Start the GRPO Trainer** +#### Step 6: Start vLLM Server on GPUs 0-1 + ```bash -cd atropos -python example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ - --weight-bridge-mode none \ - --training-steps 100 \ - --vllm-restart-interval 10 \ - --vllm-port 9001 \ - --vllm-gpu-memory-utilization 0.30 \ - --batch-size 2 \ - --gradient-accumulation-steps 16 \ - --lr 1e-5 \ - --save-path checkpoints_legacy \ - --use-wandb \ - --wandb-project gsm8k-grpo \ - > trainer.log 2>&1 & -sleep 10 +CUDA_VISIBLE_DEVICES=0,1 python -u example_trainer/vllm_api_server.py \ + --model Qwen/Qwen2.5-14B-Instruct \ + --tensor-parallel-size 2 \ + --port 9001 \ + --dtype bfloat16 \ + > vllm.log 2>&1 & +echo "vLLM starting on GPUs 0,1..." ``` -**Step 4: Start the GSM8k Environment** +#### Step 7: Wait for vLLM to Load + ```bash -cd atropos -python environments/gsm8k_server.py serve \ - --slurm False \ - --openai.model_name Qwen/Qwen2.5-3B-Instruct \ - --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm \ - --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \ - > gsm8k.log 2>&1 & +tail -f vllm.log ``` -**Monitor Training:** +Wait until you see: `Uvicorn running on http://0.0.0.0:9001` + +Then press **Ctrl+C** to stop tailing. + +#### Step 8: Verify Shared Memory Setup + +```bash +grep -E "thread|updater|Exported|Shared memory" vllm.log +``` + +You should see: +``` +[vLLM Patch] ✓ Shared memory setup complete! +[vLLM Patch] ✓ Weight updater thread started (name: WeightUpdater_TP0) +[vLLM Patch] ✓ Weight updater thread started (name: WeightUpdater_TP1) +``` + +#### Step 9: Start Trainer on GPU 2 + +```bash +CUDA_VISIBLE_DEVICES=2 python -u example_trainer/grpo.py \ + --model-name Qwen/Qwen2.5-14B-Instruct \ + --weight-bridge-mode shared_vllm \ + --vllm-port 9001 \ + --lr 1e-6 \ + --batch-size 4 \ + --training-steps 100 \ + --use-shared-memory \ + 2>&1 | tee trainer.log +``` + +#### Step 10: Monitor Training + ```bash tail -f trainer.log ``` +You should see: +``` +[Bridge] ✓ Gloo group created +[Bridge] ✓ NCCL group created +[Bridge] ✓ All ranks synchronized and ready +[Bridge] Mapped 195/339 params from vLLM to trainer +Step 1/100 +``` + +--- + ### Quick Copy-Paste (All-in-One) ```bash -cd atropos && \ -pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \ -python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \ -python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \ -python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode none --training-steps 100 --vllm-port 9001 --vllm-gpu-memory-utilization 0.30 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-5 --save-path checkpoints_legacy --use-wandb --wandb-project gsm8k-grpo > trainer.log 2>&1 & sleep 10 && \ -python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \ -tail -f trainer.log +# Kill everything and setup +pkill -9 -u $USER -f "vllm|grpo|python|run-api" 2>/dev/null; sleep 3 +cd ~/atropos_stuff/atropos +rm -f vllm_bridge_config.json vllm.log trainer.log api.log gsm8k.log + +# Environment variables +export VLLM_ENABLE_SHARED_WEIGHTS=1 NUM_INFERENCE_NODES=0 MASTER_ADDR=localhost MASTER_PORT=29500 + +# Start services +python -m atroposlib.cli.run_api > api.log 2>&1 & +sleep 3 +python environments/gsm8k_server.py > gsm8k.log 2>&1 & +sleep 3 +CUDA_VISIBLE_DEVICES=0,1 python -u example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-14B-Instruct --tensor-parallel-size 2 --port 9001 --dtype bfloat16 > vllm.log 2>&1 & + +echo "Waiting for vLLM to load... (check: tail -f vllm.log)" +echo "Once ready, run the trainer command below:" +echo "" +echo "CUDA_VISIBLE_DEVICES=2 python -u example_trainer/grpo.py --model-name Qwen/Qwen2.5-14B-Instruct --weight-bridge-mode shared_vllm --vllm-port 9001 --lr 1e-6 --batch-size 4 --training-steps 100 --use-shared-memory 2>&1 | tee trainer.log" ``` -### What Happens - -1. vLLM server starts and loads `Qwen/Qwen2.5-3B-Instruct` -2. Trainer loads its own copy of the model for training -3. GSM8k env sends problems → vLLM generates solutions → scores sent to API -4. Trainer fetches scored batches from API, computes GRPO loss, updates weights -5. Every N steps: save checkpoint (weights stay in sync via external vLLM) -6. Repeat until done - -### Pros & Cons - -+ Simple conceptually -+ Easy to debug -+ Uses custom vLLM server with full endpoint support -- 2x GPU memory (trainer + vLLM both load model) -- Requires external vLLM to be running - --- -## Mode 2: Shared vLLM Bridge (In-Place Updates) +## How Shared vLLM Mode Works -This mode supports two sub-modes: +### The Problem +Traditional RL training requires syncing model weights between the trainer and inference server. This is slow: +- Save checkpoint → Load into vLLM → Restart server = **30-60 seconds per sync** -1. **HTTP Notification Mode** (default): Trainer notifies vLLM after weight updates -2. **NCCL Shared Memory Mode** (`--use-shared-memory`): Weights broadcast via NCCL to vLLM's daemon +### Two Solutions Available -### Step-by-Step Guide +#### Option 1: Broadcast Mode (`--use-shared-memory`) +Two copies of the model, but instant NCCL sync. Use when trainer is on **different GPUs**. -**Step 1: Start the Atropos API** -```bash -cd atropos -python -m atroposlib.cli.run_api > api.log 2>&1 & -sleep 5 +``` +Trainer (GPU 2) NCCL vLLM Workers (GPUs 0-1) + │ │ │ + │ optimizer.step() │ │ + │ ─────────────────────────────────────────────► │ + │ broadcast_weights() │ │ Thread receives + │ │ │ weights via NCCL + │ │ │ Copies to shared + │ │ │ memory tensors + │ │ │ + │ Next training step │ │ Ready for inference ``` -**Step 2: Start the vLLM Server with Bridge Support** +- **Memory**: 2x model size (trainer copy + vLLM copy) +- **Sync Latency**: ~0ms (NCCL broadcast) +- **GPU Layout**: Trainer on different GPUs than vLLM -For HTTP notification mode: -```bash -cd atropos -export LOGDIR=/tmp/atropos_bridge -export NUM_INFERENCE_NODES=0 -mkdir -p $LOGDIR +#### Option 2: Single-Copy Mode (`--single-copy`) ⭐ RECOMMENDED +TRUE shared memory - only ONE copy of the model! Use when trainer is on **same GPUs**. -python example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-3B-Instruct \ - --port 9001 \ - --gpu-memory-utilization 0.30 \ - > vllm.log 2>&1 & -sleep 90 - -curl -s http://localhost:9001/health && echo "vLLM ready!" +``` +┌────────────────────────────────────────────────────────────┐ +│ SAME GPU(s) │ +│ │ +│ ┌──────────────────────────────────────────────────┐ │ +│ │ SHARED MODEL TENSORS │ │ +│ │ (only ONE copy in GPU memory!) │ │ +│ └──────────────────────────────────────────────────┘ │ +│ ▲ ▲ │ +│ │ Reads/Writes │ Reads │ +│ ┌────────┴───────┐ ┌────────┴───────┐ │ +│ │ Trainer │ │ vLLM │ │ +│ │ (gradients) │ │ (inference) │ │ +│ └────────────────┘ └────────────────┘ │ +│ │ │ +│ │ optimizer.step() │ +│ │ (updates shared tensors in-place) │ +│ ▼ │ +│ vLLM immediately sees new weights! │ +└────────────────────────────────────────────────────────────┘ ``` -For NCCL shared memory mode (requires patched vLLM): -```bash -cd atropos -export LOGDIR=/tmp/atropos_bridge -export NUM_INFERENCE_NODES=0 -export VLLM_ENABLE_SHARED_WEIGHTS=1 # Enable shared memory patches -mkdir -p $LOGDIR +- **Memory**: 1x model size (truly shared via CUDA IPC!) +- **Sync Latency**: 0ms (same memory, no copy needed) +- **GPU Layout**: Trainer on SAME GPUs as vLLM (required!) -python example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-3B-Instruct \ - --port 9001 \ - --gpu-memory-utilization 0.30 \ - > vllm.log 2>&1 & -sleep 90 +### When to Use Which -curl -s http://localhost:9001/health && echo "vLLM ready!" -``` +| Mode | Memory | Sync | Use When | +|------|--------|------|----------| +| **Broadcast** (`--use-shared-memory`) | 2x model | ~0ms NCCL | Trainer on different GPUs | +| **Single-Copy** (`--single-copy`) | 1x model | 0ms | Trainer on same GPUs, memory constrained | -**Step 3: Start the GRPO Trainer in Shared Mode** - -For HTTP notification mode: -```bash -cd atropos -export LOGDIR=/tmp/atropos_bridge -export NUM_INFERENCE_NODES=0 - -python example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ - --weight-bridge-mode shared_vllm \ - --num-inference-nodes 0 \ - --training-steps 100 \ - --vllm-port 9001 \ - --batch-size 2 \ - --gradient-accumulation-steps 16 \ - --lr 1e-5 \ - --save-path checkpoints_shared \ - --use-wandb \ - --wandb-project gsm8k-grpo-shared \ - > trainer.log 2>&1 & -sleep 10 -``` - -For NCCL shared memory mode (add `--use-shared-memory`): -```bash -python example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ - --weight-bridge-mode shared_vllm \ - --use-shared-memory \ - --num-inference-nodes 0 \ - --training-steps 100 \ - --vllm-port 9001 \ - --batch-size 2 \ - --gradient-accumulation-steps 16 \ - --lr 1e-5 \ - --save-path checkpoints_shared \ - > trainer.log 2>&1 & -``` - -**Step 4: Start the GSM8k Environment** -```bash -cd atropos -python environments/gsm8k_server.py serve \ - --slurm False \ - --openai.model_name Qwen/Qwen2.5-3B-Instruct \ - --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm \ - --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \ - > gsm8k.log 2>&1 & -``` - -**Monitor Training:** -```bash -tail -f trainer.log -``` - -### Quick Copy-Paste (All-in-One) +### Single-Copy Mode Usage ```bash -cd atropos && \ -pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \ -export LOGDIR=/tmp/atropos_bridge && export NUM_INFERENCE_NODES=0 && mkdir -p $LOGDIR && \ -python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \ -python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \ -python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode shared_vllm --num-inference-nodes 0 --training-steps 100 --vllm-port 9001 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-5 --save-path checkpoints_shared --use-wandb --wandb-project gsm8k-grpo-shared > trainer.log 2>&1 & sleep 10 && \ -python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \ -tail -f trainer.log +# vLLM and Trainer on SAME GPUs (0,1) +CUDA_VISIBLE_DEVICES=0,1 python -u example_trainer/vllm_api_server.py \ + --model Qwen/Qwen2.5-14B-Instruct \ + --tensor-parallel-size 2 \ + --port 9001 \ + > vllm.log 2>&1 & + +# Wait for vLLM to load... + +# Trainer also on GPUs 0,1 - shares vLLM's tensors! +CUDA_VISIBLE_DEVICES=0,1 python -u example_trainer/grpo.py \ + --model-name Qwen/Qwen2.5-14B-Instruct \ + --weight-bridge-mode shared_vllm \ + --single-copy \ + --training-steps 100 \ + 2>&1 | tee trainer.log ``` -### What Happens (HTTP Notification Mode) - -1. vLLM server starts on port 9001 -2. Trainer initializes bridge in LOCAL MODE (HTTP-based) -3. Trainer loads its own model copy and trains normally -4. After each `optimizer.step()`: - - `bridge.notify_update()` sends HTTP POST to vLLM - - Periodic checkpoint saves sync weights to disk -5. Simple setup, suitable for debugging - -### What Happens (NCCL Shared Memory Mode) - -When using `--use-shared-memory` with `VLLM_ENABLE_SHARED_WEIGHTS=1`: - -1. vLLM patches GPUModelRunner to call `share_memory_()` on model weights -2. vLLM spawns a daemon process that joins NCCL groups with the trainer -3. Trainer broadcasts weights via NCCL after each optimizer step -4. Daemon copies weights into shared tensors → vLLM uses them immediately - -This provides true shared memory without separate model copies! - -### What Happens (Distributed Mode - num_inference_nodes>0) - -1. vLLM server starts, writes parameter mapping to `$LOGDIR/vllm_bridge_config.json` -2. Trainer reads mapping, joins NCCL process group with vLLM -3. Trainer's model parameters point to vLLM's GPU tensors (shared memory) -4. Training loop: - - Forward pass uses shared weights - - `optimizer.step()` modifies shared tensors in-place - - `bridge.notify_update()` broadcasts via Gloo - - vLLM immediately uses new weights for next inference -5. No restarts needed! - -### Environment Variables - -| Variable | Description | Example | -|----------|-------------|---------| -| `LOGDIR` | Directory for bridge coordination files | `/tmp/atropos_bridge` | -| `NUM_INFERENCE_NODES` | Number of vLLM nodes (0 = local) | `0` | -| `MASTER_ADDR` | Rendezvous address | `localhost` | -| `MASTER_PORT` | Rendezvous port | `26756` | - -### Pros & Cons - -+ ~0ms sync latency (instant updates) -+ 1x GPU memory (shared tensors) -+ Maximum training throughput -- More complex setup -- Requires compatible vLLM version - --- -## Mode 3: LoRA Adapters (Hot-Swap) +## Alternative Modes -This mode trains only LoRA adapter weights. Much smaller checkpoints, faster iteration. +### Mode 1: Legacy (Checkpoint + Restart) -### Step-by-Step Guide +For simple setups or debugging. Saves checkpoints and can restart vLLM. -**Step 1: Start the Atropos API** ```bash -cd atropos -python -m atroposlib.cli.run_api > api.log 2>&1 & -sleep 5 -``` - -**Step 2: Start the vLLM Server (Required for LoRA Hot-Swap)** -```bash -cd atropos -python example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-3B-Instruct \ - --port 9001 \ - --gpu-memory-utilization 0.30 \ - > vllm.log 2>&1 & -sleep 90 - -curl -s http://localhost:9001/health && echo "vLLM ready!" -``` - -**Step 3: Start the GRPO Trainer in LoRA Mode** -```bash -cd atropos python example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ - --weight-bridge-mode lora_only \ - --lora-r 16 \ - --lora-alpha 32 \ - --lora-dropout 0.05 \ - --lora-target-modules q_proj v_proj \ - --training-steps 100 \ - --vllm-port 9001 \ - --batch-size 2 \ - --gradient-accumulation-steps 16 \ - --lr 1e-4 \ - --save-path checkpoints_lora \ - --use-wandb \ - --wandb-project gsm8k-grpo-lora \ - > trainer.log 2>&1 & -sleep 10 + --model-name Qwen/Qwen2.5-3B-Instruct \ + --weight-bridge-mode none \ + --training-steps 100 \ + --vllm-restart-interval 10 \ + --batch-size 2 \ + --lr 1e-5 ``` -**Step 4: Start the GSM8k Environment** -```bash -cd atropos -python environments/gsm8k_server.py serve \ - --slurm False \ - --openai.model_name Qwen/Qwen2.5-3B-Instruct \ - --openai.base_url http://localhost:9001/v1 \ - --openai.server_type vllm \ - --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct \ - > gsm8k.log 2>&1 & -``` +### Mode 2: LoRA Adapters -**Monitor Training:** -```bash -tail -f trainer.log -``` - -### Quick Copy-Paste (All-in-One) +Trains only adapter weights. Small checkpoints, lower memory. ```bash -cd atropos && \ -pkill -f "grpo.py|vllm|gsm8k|run_api" 2>/dev/null; sleep 3 && \ -python -m atroposlib.cli.run_api > api.log 2>&1 & sleep 5 && \ -python example_trainer/vllm_api_server.py --model Qwen/Qwen2.5-3B-Instruct --port 9001 --gpu-memory-utilization 0.30 > vllm.log 2>&1 & sleep 90 && \ -python example_trainer/grpo.py --model-name Qwen/Qwen2.5-3B-Instruct --weight-bridge-mode lora_only --lora-r 16 --lora-alpha 32 --lora-dropout 0.05 --training-steps 100 --vllm-port 9001 --batch-size 2 --gradient-accumulation-steps 16 --lr 1e-4 --save-path checkpoints_lora --use-wandb --wandb-project gsm8k-grpo-lora > trainer.log 2>&1 & sleep 10 && \ -python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen/Qwen2.5-3B-Instruct --openai.base_url http://localhost:9001/v1 --openai.server_type vllm --env.tokenizer_name Qwen/Qwen2.5-3B-Instruct > gsm8k.log 2>&1 & \ -tail -f trainer.log +python example_trainer/grpo.py \ + --model-name Qwen/Qwen2.5-3B-Instruct \ + --weight-bridge-mode lora_only \ + --lora-r 16 \ + --lora-alpha 32 \ + --training-steps 100 \ + --batch-size 2 \ + --lr 1e-4 ``` -### What Happens - -1. Trainer loads base model, wraps with LoRA adapters (PEFT) -2. Only adapter parameters are trainable (~0.1% of total) -3. Training loop updates adapter weights only -4. Every N steps: save adapter checkpoint (small, ~10-50MB) -5. vLLM can hot-swap adapters via `/lora/load` endpoint - -### LoRA Configuration - -| Option | Default | Description | -|--------|---------|-------------| -| `--lora-r` | 16 | Rank of low-rank matrices | -| `--lora-alpha` | 32 | Scaling factor (typically 2x rank) | -| `--lora-dropout` | 0.05 | Dropout for regularization | -| `--lora-target-modules` | `q_proj v_proj` | Which layers to adapt | - -### Common Target Module Combinations - -```bash -# Minimal (fastest training) ---lora-target-modules q_proj v_proj - -# Attention only ---lora-target-modules q_proj k_proj v_proj o_proj - -# Full (most expressive) ---lora-target-modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj -``` - -### Pros & Cons - -+ Much faster training (fewer parameters) -+ Tiny checkpoints (~10-50MB vs ~6GB) -+ Can hot-swap adapters without full restart -+ Lower GPU memory (base model frozen) -- Less expressive than full fine-tuning -- May need higher learning rate - --- ## Configuration Reference -### All CLI Options +### Environment Variables -```bash -python example_trainer/grpo.py --help -``` +| Variable | Required | Description | Example | +|----------|----------|-------------|---------| +| `VLLM_ENABLE_SHARED_WEIGHTS` | Yes (shared mode) | Enable vLLM patching | `1` | +| `NUM_INFERENCE_NODES` | Yes | Number of vLLM nodes (0 = local) | `0` | +| `MASTER_ADDR` | Yes | Rendezvous address | `localhost` | +| `MASTER_PORT` | Yes | Rendezvous port | `29500` | +| `CUDA_VISIBLE_DEVICES` | Recommended | GPU allocation | `0,1` or `2` | -### Core Training Options +### Trainer CLI Options | Option | Default | Description | |--------|---------|-------------| | `--model-name` | (required) | HuggingFace model ID | -| `--lr` | `1e-5` | Learning rate | +| `--weight-bridge-mode` | `none` | `none`, `shared_vllm`, or `lora_only` | +| `--use-shared-memory` | `False` | Enable NCCL weight broadcasting | +| `--vllm-port` | `9001` | vLLM server port | | `--training-steps` | `10` | Total optimization steps | | `--batch-size` | `2` | Micro-batch size | -| `--gradient-accumulation-steps` | `32` | Gradient accumulation | -| `--seq-len` | `2048` | Max sequence length | +| `--lr` | `1e-5` | Learning rate | | `--save-path` | `trained_model_checkpoints` | Checkpoint directory | -### vLLM Options +### vLLM Server Options -| Option | Default | Description | -|--------|---------|-------------| -| `--vllm-port` | `9001` | vLLM server port | -| `--vllm-restart-interval` | `3` | Steps between syncs | - -### Weight Bridge Options - -| Option | Default | Description | -|--------|---------|-------------| -| `--weight-bridge-mode` | `none` | `none`, `shared_vllm`, or `lora_only` | -| `--trainer-rank` | `0` | Distributed rank | -| `--world-size` | `1` | Total processes | -| `--init-method` | `env://` | PyTorch distributed init | -| `--num-inference-nodes` | `0` | Number of vLLM nodes | - -### Logging Options - -| Option | Default | Description | -|--------|---------|-------------| -| `--use-wandb` | `False` | Enable W&B logging | -| `--wandb-project` | `None` | W&B project name | -| `--wandb-group` | `None` | W&B group name | +| Option | Description | +|--------|-------------| +| `--model` | HuggingFace model ID | +| `--tensor-parallel-size` | Number of GPUs for tensor parallelism | +| `--port` | Server port (default: 9001) | +| `--dtype` | Model dtype (`bfloat16`, `float16`, `auto`) | --- -## Shutdown / Cleanup +## FAQ & Troubleshooting -### Stop All Processes +### Q: The trainer is stuck at "Creating Gloo process group..." + +**A:** This means the trainer is waiting for the vLLM weight updater threads to connect. Check if the threads started: ```bash -# Graceful shutdown -pkill -f "gsm8k_server" -sleep 2 -pkill -f "grpo.py" -sleep 2 -pkill -f "vllm_api_server" -sleep 2 -pkill -f "run_api" - -echo "All processes stopped" +grep -E "thread|updater|ERROR" vllm.log ``` -### Check Running Processes +You should see: +``` +[vLLM Patch] ✓ Weight updater thread started (name: WeightUpdater_TP0) +[vLLM Patch] ✓ Weight updater thread started (name: WeightUpdater_TP1) +``` + +If not, ensure `VLLM_ENABLE_SHARED_WEIGHTS=1` was set **before** starting vLLM. + +--- + +### Q: I get "CUDA out of memory" when starting the trainer + +**A:** The trainer is trying to load the model on the same GPUs as vLLM. Use separate GPUs: ```bash -ps aux | grep -E "(grpo|vllm|gsm8k|run_api)" | grep -v grep +# vLLM on GPUs 0-1 +CUDA_VISIBLE_DEVICES=0,1 python -u example_trainer/vllm_api_server.py ... + +# Trainer on GPU 2 +CUDA_VISIBLE_DEVICES=2 python -u example_trainer/grpo.py ... ``` -### Check GPU Usage +--- +### Q: I see "daemonic processes are not allowed to have children" + +**A:** This was a bug in older versions. The fix uses **threads** instead of **processes** for the weight updater. Make sure you have the latest `patched_gpu_runner.py`. + +--- + +### Q: The `vllm_bridge_config.json` shows `param_mappings: {}` + +**A:** The vLLM patches didn't run. Check: + +1. `VLLM_ENABLE_SHARED_WEIGHTS=1` was set before starting vLLM +2. Look for `[vLLM Patch] ✓ Exported X params` in vllm.log + +```bash +grep "Exported" vllm.log +``` + +--- + +### Q: How do I verify the NCCL connection is working? + +**A:** Check the trainer log for these messages: + +``` +[Bridge] ✓ Gloo group created +[Bridge] ✓ NCCL group created +[Bridge] ✓ All ranks synchronized and ready +``` + +--- + +### Q: What's the difference between Gloo and NCCL? + +**A:** +- **Gloo**: CPU-based coordination protocol. Used for synchronization barriers. +- **NCCL**: GPU-based high-speed protocol. Used for broadcasting weight tensors. + +Both are needed: Gloo for coordination, NCCL for fast tensor transfers. + +--- + +### Q: How do I check GPU memory usage? + +**A:** ```bash nvidia-smi ``` ---- - -## Troubleshooting - -### "CUDA out of memory" - -Try reducing: -```bash ---batch-size 1 \ ---gradient-accumulation-steps 64 \ ---seq-len 1024 \ ---vllm-gpu-memory-utilization 0.25 -``` - -Or use LoRA mode which uses less memory. - -### "Connection refused" to Atropos API - -Make sure the API is running: -```bash -python -m atroposlib.cli.run_api -``` - -### vLLM fails to start - -Check if port 9001 is in use: -```bash -lsof -i :9001 -``` - -Kill existing processes or use a different port: -```bash -pkill -f "vllm_api_server" -# or use different port: ---vllm-port 9002 -``` - -### "NotImplementedError" or "404 Not Found" on `/generate` - -This means you're using the wrong server type. Make sure: - -1. You started `vllm_api_server.py` (NOT standard `vllm serve`) -2. GSM8k uses `--openai.server_type vllm` (NOT `openai`) - -```bash -# CORRECT -python example_trainer/vllm_api_server.py --model ... --port 9001 -python environments/gsm8k_server.py serve --openai.server_type vllm ... - -# WRONG - standard vLLM doesn't have /generate endpoint -python -m vllm.entrypoints.openai.api_server --model ... --port 9001 -``` - -### "Free memory on device is less than desired GPU memory utilization" - -Lower the vLLM memory utilization: - -```bash -python example_trainer/vllm_api_server.py \ - --model Qwen/Qwen2.5-3B-Instruct \ - --port 9001 \ - --gpu-memory-utilization 0.25 # Lower this -``` - -### Bridge mode: "Parameter mapping file not found" - -Ensure `$LOGDIR` is set and vLLM server is running: -```bash -export LOGDIR=/tmp/atropos_bridge -ls $LOGDIR/vllm_bridge_config.json -``` - -### LoRA mode: "PEFT library not available" - -Install PEFT: -```bash -pip install peft -``` - -### No trajectories collected / Workers timing out - -Check that all services are running in the correct order: -```bash -# Check processes -ps aux | grep -E "(run_api|vllm|grpo|gsm8k)" | grep -v grep - -# Check vLLM health -curl http://localhost:9001/health - -# Check API health -curl http://localhost:8000/health -``` - -If vLLM isn't ready, wait longer before starting GSM8k. +Expected for Qwen2.5-14B with shared mode: +- GPUs 0-1: ~168GB each (vLLM workers) +- GPU 2: ~29GB (trainer) --- -## Checkpoint Locations - -### Where Are Trained Models Saved? - -| Mode | Location | Contents | -|------|----------|----------| -| **Legacy** | `trained_model_checkpoints/step_N/` | Full model + tokenizer | -| **Legacy** | `trained_model_checkpoints/final_model/` | Final checkpoint | -| **Shared vLLM** | `trained_model_checkpoints/step_N/` | Full model + tokenizer | -| **LoRA** | `trained_model_checkpoints/adapter_step_N/` | LoRA adapters only (~10-50MB) | -| **LoRA** | `trained_model_checkpoints/final_adapter/` | Final adapter | - -### Customizing Save Path +### Q: How do I stop all processes? +**A:** ```bash -python example_trainer/grpo.py \ - --save-path /path/to/my/checkpoints \ - ... -``` - -### Loading Checkpoints for Inference - -```python -# Full model (Legacy/Shared modes) -from transformers import AutoModelForCausalLM, AutoTokenizer -model = AutoModelForCausalLM.from_pretrained("trained_model_checkpoints/final_model") -tokenizer = AutoTokenizer.from_pretrained("trained_model_checkpoints/final_model") - -# LoRA adapter -from peft import PeftModel, PeftConfig -from transformers import AutoModelForCausalLM - -base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct") -model = PeftModel.from_pretrained(base_model, "trained_model_checkpoints/final_adapter") +pkill -9 -u $USER -f "vllm|grpo|python|run-api" ``` --- -## vLLM Server Requirements +### Q: The training is slow / not progressing -When using `--openai.server_type vllm` or the shared_vllm bridge, your vLLM server must expose these endpoints: +**A:** Check if all services are running: -### Required Endpoints - -| Endpoint | Method | Purpose | Used By | -|----------|--------|---------|---------| -| `/health` | GET | Health check | All modes | -| `/generate` | POST | Native generation with token IDs + logprobs | VLLMServer class | - -### Required `/generate` Request Format - -The vLLM server must handle **both** prompt formats: - -```json -// String prompt (simple) -{ - "prompt": "Hello, world!", - "max_tokens": 100, - "temperature": 1.0, - "logprobs": 1 -} - -// Token ID prompt (used by atroposlib) -{ - "prompt": {"prompt_token_ids": [1, 2, 3, 4, 5]}, - "max_tokens": 100, - "temperature": 1.0, - "logprobs": 1 -} +```bash +ps aux | grep -E "(run_api|vllm|grpo|gsm8k)" | grep $USER ``` -### Required `/generate` Response Format - -```json -{ - "text": ["generated text here"], - "prompt": "original prompt", - "finish_reasons": ["stop"], - "logprobs": [ - [ - [{"12345": -0.5}], - [{"67890": -1.2}] - ] - ], - "prompt_token_ids": [1, 2, 3, 4, 5], - "token_ids": [[12345, 67890, ...]] -} +Check logs for errors: +```bash +tail -20 api.log +tail -20 gsm8k.log +tail -20 vllm.log +tail -20 trainer.log ``` -The `logprobs` field format: `List[List[List[Dict[token_id, logprob]]]]` -- Outer list: per completion (n samples) -- Middle list: per token in completion -- Inner list: contains single dict `{token_id: logprob}` - -### Optional Bridge Endpoints (for shared_vllm mode) - -| Endpoint | Method | Purpose | -|----------|--------|---------| -| `/bridge/info` | GET | Get bridge status | -| `/bridge/notify_update` | POST | Receive weight update notifications | -| `/bridge/state_dict_info` | GET | Get model parameter mappings | - -### Optional LoRA Endpoints (for lora_only mode) - -| Endpoint | Method | Purpose | -|----------|--------|---------| -| `/lora/status` | GET | Get active LoRA adapter | -| `/lora/load` | POST | Load new LoRA adapter | -| `/lora/unload` | POST | Unload current adapter | - -### Using Standard vLLM vs Custom Server - -| Server | Supports `/generate` with logprobs | Supports bridge | Supports LoRA hot-swap | -|--------|-----------------------------------|-----------------|------------------------| -| `vllm serve ...` | ❌ No | ❌ No | ❌ No | -| `vllm_api_server.py` | ✅ Yes | ✅ Yes | ✅ Yes | - -**Use `example_trainer/vllm_api_server.py` for full feature support.** - --- -## Benchmarking Speed & Memory +### Q: How do I use a smaller model for testing? -### Memory Usage Comparison +**A:** Use Qwen2.5-3B-Instruct with single GPU: ```bash -# Run this during training to monitor GPU memory -watch -n 1 nvidia-smi +# vLLM on GPU 0 +CUDA_VISIBLE_DEVICES=0 python -u example_trainer/vllm_api_server.py \ + --model Qwen/Qwen2.5-3B-Instruct \ + --port 9001 \ + > vllm.log 2>&1 & + +# Trainer on GPU 1 +CUDA_VISIBLE_DEVICES=1 python -u example_trainer/grpo.py \ + --model-name Qwen/Qwen2.5-3B-Instruct \ + --weight-bridge-mode shared_vllm \ + --use-shared-memory \ + --training-steps 10 \ + 2>&1 | tee trainer.log ``` -**Expected Memory Usage (Qwen2.5-3B-Instruct):** - -| Mode | Trainer GPU | vLLM GPU | Total | -|------|------------|----------|-------| -| **Legacy** | ~8GB | ~8GB | ~16GB (2x model) | -| **Shared vLLM** | ~8GB (shared) | ~8GB (shared) | ~8GB (1x model) | -| **LoRA** | ~10GB (frozen base) | ~8GB | ~18GB | - -### Speed Benchmarking - -Add these measurements to your training script or use the built-in wandb logging: - -```python -import time -import torch - -# Track step times -step_times = [] -sync_times = [] - -for step in range(training_steps): - # Measure training step time - step_start = time.time() - # ... training code ... - step_time = time.time() - step_start - step_times.append(step_time) - - # Measure sync time (Legacy mode only) - if step % vllm_restart_interval == 0: - sync_start = time.time() - # ... checkpoint + restart vLLM ... - sync_time = time.time() - sync_start - sync_times.append(sync_time) - -# Print summary -print(f"Avg step time: {sum(step_times)/len(step_times):.2f}s") -print(f"Avg sync time: {sum(sync_times)/len(sync_times):.2f}s" if sync_times else "No syncs") -``` - -### Benchmark Script - -Create a benchmark comparing modes: - -```bash -#!/bin/bash -# benchmark_modes.sh - -MODEL="Qwen/Qwen2.5-3B-Instruct" -STEPS=50 -BATCH=2 -ACCUM=16 - -echo "=== Benchmarking Legacy Mode ===" -time python example_trainer/grpo.py \ - --model-name $MODEL \ - --weight-bridge-mode none \ - --training-steps $STEPS \ - --batch-size $BATCH \ - --gradient-accumulation-steps $ACCUM \ - --vllm-restart-interval 10 \ - 2>&1 | tee benchmark_legacy.log - -echo "=== Benchmarking Shared vLLM Mode ===" -export LOGDIR=/tmp/bench_shared -export NUM_INFERENCE_NODES=0 -mkdir -p $LOGDIR - -# Start vLLM first -python example_trainer/vllm_api_server.py \ - --model $MODEL --port 9001 --gpu-memory-utilization 0.45 & -VLLM_PID=$! -sleep 60 # Wait for vLLM to load - -time python example_trainer/grpo.py \ - --model-name $MODEL \ - --weight-bridge-mode shared_vllm \ - --training-steps $STEPS \ - --batch-size $BATCH \ - --gradient-accumulation-steps $ACCUM \ - --num-inference-nodes 0 \ - 2>&1 | tee benchmark_shared.log - -kill $VLLM_PID - -echo "=== Benchmarking LoRA Mode ===" -time python example_trainer/grpo.py \ - --model-name $MODEL \ - --weight-bridge-mode lora_only \ - --training-steps $STEPS \ - --batch-size $BATCH \ - --gradient-accumulation-steps $ACCUM \ - --lora-r 16 \ - --vllm-restart-interval 25 \ - 2>&1 | tee benchmark_lora.log - -echo "=== Summary ===" -echo "Check benchmark_*.log for detailed timing" -``` - -### Expected Benchmark Results - -| Metric | Legacy | Shared vLLM | LoRA | -|--------|--------|-------------|------| -| **Step time** | ~2-5s | ~2-5s | ~1-3s | -| **Sync overhead** | ~30-60s every N steps | ~0ms | ~5-10s every N steps | -| **Total time (50 steps, sync every 10)** | ~15-20 min | ~3-5 min | ~5-8 min | -| **Peak GPU memory** | ~16GB | ~8GB | ~10GB | -| **Checkpoint size** | ~6GB | ~6GB | ~50MB | - -### WandB Metrics to Watch - -If using `--use-wandb`, these metrics are logged: - -| Metric | Description | -|--------|-------------| -| `train/loss` | GRPO loss | -| `train/grad_norm` | Gradient norm | -| `train/pos_logp` | Log prob of positive examples | -| `train/neg_logp` | Log prob of negative examples | -| `train/step_time` | Time per training step | -| `train/sync_time` | Time for weight sync (legacy/lora) | - --- ## Files in This Directory @@ -922,54 +487,80 @@ If using `--use-wandb`, these metrics are logged: | File | Description | |------|-------------| | `grpo.py` | Main trainer script with all modes | -| `vllm_api_server.py` | Custom vLLM server with bridge endpoints | -| `vllm_weight_bridge.py` | Shared memory bridge implementation | +| `vllm_api_server.py` | Custom vLLM server with shared memory patches | +| `vllm_weight_bridge.py` | NCCL bridge for weight synchronization | +| `vllm_patching/` | vLLM patches for shared memory support | | `requirements.txt` | Python dependencies | | `README.md` | This documentation | +### vllm_patching/ Directory + +| File | Description | +|------|-------------| +| `__init__.py` | Module exports | +| `patched_gpu_runner.py` | Patches GPUModelRunner for shared memory | +| `weight_updater.py` | Thread that receives NCCL weight broadcasts | +| `distributed_utils.py` | Process group initialization helpers | + --- -## Example Runs +## Performance Comparison -### Quick Test (Legacy Mode) +| Mode | Sync Latency | Memory (14B model) | Best For | +|------|--------------|-------------------|----------| +| **Legacy** | 30-60s | 2x model | Debugging | +| **Shared vLLM** | ~0ms | 1x model (shared) + trainer | Production | +| **LoRA** | 5-10s | 1x model + adapters | Memory-constrained | + +--- + +## Checkpoint Locations + +| Mode | Location | Size | +|------|----------|------| +| Legacy | `trained_model_checkpoints/step_N/` | ~28GB (14B model) | +| Shared vLLM | `trained_model_checkpoints/step_N/` | ~28GB | +| LoRA | `trained_model_checkpoints/adapter_step_N/` | ~50MB | + +--- + +## Example Training Runs + +### Quick Test (3B model, LoRA) ```bash -# Minimal test to verify setup works python example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-0.5B-Instruct \ - --training-steps 5 \ - --batch-size 1 \ - --gradient-accumulation-steps 4 + --model-name Qwen/Qwen2.5-3B-Instruct \ + --weight-bridge-mode lora_only \ + --training-steps 5 \ + --batch-size 1 ``` -### Full GSM8k Training (LoRA Mode) +### Production (14B model, Shared vLLM) ```bash -# Recommended for single-GPU training -python example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ - --weight-bridge-mode lora_only \ - --lora-r 32 \ - --lora-alpha 64 \ - --training-steps 500 \ - --batch-size 2 \ - --gradient-accumulation-steps 32 \ - --lr 5e-5 \ - --use-wandb \ - --wandb-project gsm8k-lora +# See Step-by-Step Guide above +CUDA_VISIBLE_DEVICES=2 python -u example_trainer/grpo.py \ + --model-name Qwen/Qwen2.5-14B-Instruct \ + --weight-bridge-mode shared_vllm \ + --use-shared-memory \ + --training-steps 1000 \ + --batch-size 4 \ + --lr 1e-6 ``` -### Production (Shared vLLM Mode) +### Multi-GPU Training (70B model) ```bash -# Maximum throughput setup -export LOGDIR=/tmp/atropos_bridge -export NUM_INFERENCE_NODES=0 +# vLLM on GPUs 0-3 (tensor parallel 4) +CUDA_VISIBLE_DEVICES=0,1,2,3 python -u example_trainer/vllm_api_server.py \ + --model Qwen/Qwen2.5-72B-Instruct \ + --tensor-parallel-size 4 \ + --port 9001 \ + > vllm.log 2>&1 & -python example_trainer/grpo.py \ - --model-name Qwen/Qwen2.5-3B-Instruct \ - --weight-bridge-mode shared_vllm \ - --training-steps 1000 \ - --batch-size 4 \ - --gradient-accumulation-steps 16 \ - --lr 1e-5 \ - --use-wandb \ - --wandb-project gsm8k-shared +# Trainer on GPUs 4-5 +CUDA_VISIBLE_DEVICES=4,5 python -u example_trainer/grpo.py \ + --model-name Qwen/Qwen2.5-72B-Instruct \ + --weight-bridge-mode shared_vllm \ + --use-shared-memory \ + --training-steps 100 \ + 2>&1 | tee trainer.log ``` diff --git a/example_trainer/grpo.py b/example_trainer/grpo.py index 39836de7..cca23b33 100644 --- a/example_trainer/grpo.py +++ b/example_trainer/grpo.py @@ -155,6 +155,18 @@ class TrainingConfig(BaseModel): "Weight updates are broadcast to vLLM's daemon process." ), ) + + # Single-copy mode (TRUE shared memory - no extra model copy) + single_copy: bool = Field( + False, + description=( + "Enable TRUE single-copy mode via CUDA IPC. " + "The trainer attaches to vLLM's model tensors directly, " + "meaning only ONE copy of the model exists in GPU memory. " + "Requires trainer and vLLM to be on the SAME GPU(s). " + "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1." + ), + ) def check_atropos_api(timeout: float = 30.0) -> bool: @@ -414,9 +426,143 @@ def setup_wandb(config: TrainingConfig) -> bool: return False +def _attach_to_vllm_shared_tensors( + config: TrainingConfig, + bridge_config_path: str, +) -> Optional[torch.nn.Module]: + """ + Attach to vLLM's shared tensors via CUDA IPC (true single-copy mode). + + This creates a model whose parameters point to the SAME GPU memory as vLLM, + meaning only ONE copy of the model exists in GPU memory. + + Args: + config: Training configuration + bridge_config_path: Path to vllm_bridge_config.json + + Returns: + Model with parameters pointing to vLLM's tensors, or None if not possible + """ + try: + with open(bridge_config_path, 'r') as f: + bridge_config = json.load(f) + except Exception as e: + print(f"[Setup] Could not read bridge config: {e}") + return None + + if not bridge_config.get("single_copy_enabled", False): + print("[Setup] Single-copy mode not available (no IPC handles exported)") + return None + + ipc_handles = bridge_config.get("ipc_handles", {}) + if not ipc_handles: + print("[Setup] No IPC handles found in bridge config") + return None + + print(f"[Setup] Attaching to vLLM's shared tensors ({len(ipc_handles)} tensors)...") + print("[Setup] TRUE SINGLE-COPY MODE - No additional model memory!") + + # Create model architecture (meta device - no memory allocation) + with torch.device('meta'): + model = AutoModelForCausalLM.from_pretrained( + config.model_name, + torch_dtype=torch.bfloat16, + ) + + # Map vLLM tensor names to HuggingFace model parameter names + hf_state_dict = {} + vllm_to_hf_mapping = _create_vllm_to_hf_mapping(model, ipc_handles) + + attached_count = 0 + for hf_name, vllm_name in vllm_to_hf_mapping.items(): + if vllm_name not in ipc_handles: + continue + + ipc_info = ipc_handles[vllm_name] + + try: + # Reconstruct tensor from IPC handle + handle_bytes = bytes.fromhex(ipc_info["handle"]) + storage_size = ipc_info["storage_size"] + device_index = ipc_info["device_index"] + + # Create storage from IPC handle + storage = torch.cuda.UntypedStorage._new_shared_cuda( + device_index, + handle_bytes, + storage_size, + ) + + # Reconstruct tensor + dtype = getattr(torch, ipc_info["dtype"].replace("torch.", "")) + tensor = torch.tensor([], dtype=dtype, device=f"cuda:{device_index}") + tensor.set_( + storage, + storage_offset=ipc_info["storage_offset"], + size=ipc_info["shape"], + stride=ipc_info["stride"], + ) + + # Make tensor require gradients for training + tensor.requires_grad_(True) + + hf_state_dict[hf_name] = tensor + attached_count += 1 + + except Exception as e: + print(f"[Setup] Failed to attach {hf_name}: {e}") + continue + + if attached_count == 0: + print("[Setup] Could not attach any tensors, falling back to regular loading") + return None + + print(f"[Setup] ✓ Attached {attached_count} tensors to vLLM's shared memory") + + # Load state dict into model + model.load_state_dict(hf_state_dict, strict=False, assign=True) + + return model + + +def _create_vllm_to_hf_mapping(model: torch.nn.Module, ipc_handles: dict) -> dict: + """ + Create mapping from HuggingFace parameter names to vLLM tensor names. + + vLLM uses slightly different naming conventions than HuggingFace. + This function creates the bidirectional mapping. + """ + hf_params = set(model.state_dict().keys()) + vllm_params = set(ipc_handles.keys()) + + mapping = {} + + for hf_name in hf_params: + # Try direct match first + if hf_name in vllm_params: + mapping[hf_name] = hf_name + continue + + # Try common transformations + # vLLM often uses 'model.' prefix + vllm_name = f"model.{hf_name}" if not hf_name.startswith("model.") else hf_name + if vllm_name in vllm_params: + mapping[hf_name] = vllm_name + continue + + # Remove 'model.' prefix if present + if hf_name.startswith("model."): + vllm_name = hf_name[6:] + if vllm_name in vllm_params: + mapping[hf_name] = vllm_name + + return mapping + + def load_model_and_tokenizer( config: TrainingConfig, bridge: Optional["VLLMWeightBridge"] = None, + single_copy: bool = False, ) -> Tuple[torch.nn.Module, "AutoTokenizer"]: """ Load or attach to model based on weight_bridge_mode. @@ -424,6 +570,7 @@ def load_model_and_tokenizer( Args: config: Training configuration bridge: Optional weight bridge for shared_vllm mode + single_copy: If True, try to attach to vLLM's shared tensors (no extra memory) Returns: Tuple of (model, tokenizer) @@ -431,8 +578,21 @@ def load_model_and_tokenizer( tokenizer = AutoTokenizer.from_pretrained(config.model_name) if config.weight_bridge_mode == "shared_vllm" and bridge is not None: - # Shared vLLM mode: load model, weights will be broadcast via NCCL - print("[Setup] Loading model for shared vLLM mode...") + # Try single-copy mode first if enabled + if single_copy or os.environ.get("VLLM_SINGLE_COPY", "0") == "1": + log_dir = os.environ.get("LOGDIR", ".") + bridge_config_path = os.path.join(log_dir, "vllm_bridge_config.json") + + model = _attach_to_vllm_shared_tensors(config, bridge_config_path) + if model is not None: + print("[Setup] ✓ Single-copy mode active - using vLLM's tensors directly!") + model.train() + return model, tokenizer + else: + print("[Setup] Single-copy failed, falling back to broadcast mode...") + + # Fallback: Load separate model, broadcast updates via NCCL + print("[Setup] Loading model for shared vLLM mode (broadcast)...") if config.use_shared_memory: print("[Setup] NCCL shared memory mode - updates broadcast to vLLM daemon") else: @@ -1101,7 +1261,11 @@ def train_shared_vllm(config: TrainingConfig): # Load model with bridge attachment print("[2/3] Loading model with shared weights...") - model, tokenizer = load_model_and_tokenizer(config, bridge=bridge) + model, tokenizer = load_model_and_tokenizer( + config, + bridge=bridge, + single_copy=config.single_copy + ) # maybe we can actually pick optimizer @@ -1560,6 +1724,18 @@ def parse_args() -> argparse.Namespace: "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1." ), ) + + parser.add_argument( + "--single-copy", + action="store_true", + help=( + "Enable TRUE single-copy mode (shared_vllm mode only). " + "Trainer attaches to vLLM's model tensors via CUDA IPC. " + "Only ONE copy of the model exists in GPU memory! " + "Requires trainer and vLLM to be on the SAME GPU(s). " + "vLLM must be started with VLLM_ENABLE_SHARED_WEIGHTS=1." + ), + ) return parser.parse_args() @@ -1591,6 +1767,7 @@ def config_from_args(args: argparse.Namespace) -> TrainingConfig: lora_dropout=args.lora_dropout, lora_target_modules=args.lora_target_modules, use_shared_memory=getattr(args, 'use_shared_memory', False), + single_copy=getattr(args, 'single_copy', False), ) diff --git a/example_trainer/vllm_patching/patched_gpu_runner.py b/example_trainer/vllm_patching/patched_gpu_runner.py index 7631683d..2545e1b3 100644 --- a/example_trainer/vllm_patching/patched_gpu_runner.py +++ b/example_trainer/vllm_patching/patched_gpu_runner.py @@ -230,13 +230,36 @@ def _create_patched_runner(BaseRunner: type) -> type: param_mappings = {} param_names = [] + ipc_handles = {} + for name, tensor in state_dict.items(): param_mappings[name] = { "vllm_name": name, "shape": list(tensor.shape), "dtype": str(tensor.dtype), + "device": str(tensor.device), } param_names.append(name) + + # Export CUDA IPC handles for true single-copy mode + if tensor.is_cuda: + try: + # Get the storage's IPC handle + storage = tensor.untyped_storage() + ipc_handle = storage._share_cuda_() + ipc_handles[name] = { + "handle": ipc_handle[0].hex() if isinstance(ipc_handle[0], bytes) else str(ipc_handle[0]), + "storage_size": ipc_handle[1], + "storage_offset": tensor.storage_offset(), + "shape": list(tensor.shape), + "stride": list(tensor.stride()), + "dtype": str(tensor.dtype), + "device_index": tensor.device.index, + } + except Exception as e: + print(f"[vLLM Patch] Could not get IPC handle for {name}: {e}", flush=True) + + print(f"[vLLM Patch] Exported {len(ipc_handles)} IPC handles for single-copy mode", flush=True) # Get model info model_name = "unknown" @@ -253,8 +276,10 @@ def _create_patched_runner(BaseRunner: type) -> type: "dp_shard_degree": 1, "param_mappings": param_mappings, "param_names": sorted(param_names), + "ipc_handles": ipc_handles, "shared_weights_enabled": True, "num_params": len(param_names), + "single_copy_enabled": len(ipc_handles) > 0, } try: