vllm weight bridge

Jai Suphavadeeprasit 2026-01-18 23:01:48 -05:00
parent fe2fd3d824
commit 23b6552277
2 changed files with 448 additions and 32 deletions

@@ -223,9 +223,248 @@ CUDA_VISIBLE_DEVICES=0 LOGDIR=. python -u example_trainer/grpo.py \
---
## Alternative Modes
## How Each Mode Works (Data Flow Diagrams)
### Mode 1: Legacy (Checkpoint + Restart)
### Single-Copy Mode (`--weight-bridge-mode shared_vllm`) ⭐ RECOMMENDED
**The Magic**: Trainer and vLLM share the EXACT SAME GPU memory via CUDA IPC.
```
┌─────────────────────────────────────────────────────────────────────────────────────┐
│ SINGLE-COPY MODE - COMPLETE DATA FLOW │
│ │
│ STEP 1: GSM8k sends problem │
│ ┌──────────────────┐ │
│ │ GSM8k Server │──── "What is 15 × 7?" ────▶┌──────────────────┐ │
│ │ (Environment) │ │ Atropos API │ │
│ └──────────────────┘ │ (Batching) │ │
│ └────────┬─────────┘ │
│ │ │
│ STEP 2: Atropos forwards to vLLM │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────────────────────┐ │
│ │ GPU MEMORY │ │
│ │ │ │
│ │ ┌────────────────────────────────────────────────────────────────────────┐ │ │
│ │ │ MODEL WEIGHTS (ONE COPY - SHARED!) │ │ │
│ │ │ │ │ │
│ │ │ embed_tokens.weight, layers.*.qkv_proj, ..., lm_head.weight │ │ │
│ │ │ (address: 0x7f8a12340000) │ │ │
│ │ └────────────────────────────────────────────────────────────────────────┘ │ │
│ │ ▲ ▲ │ │
│ │ │ STEP 3: READ │ STEP 6: WRITE │ │
│ │ │ (generate tokens) │ (optimizer.step) │ │
│ │ ┌────────┴────────┐ ┌─────────┴─────────┐ │ │
│ │ │ vLLM Server │ │ Trainer │ │ │
│ │ │ │ │ (grpo.py) │ │ │
│ │ │ Generates: │ │ │ │ │
│ │ │ "15 × 7 = 105" │ │ STEP 5: Compute │ │ │
│ │ │ │ │ GRPO loss & │ │ │
│ │ └────────┬────────┘ │ gradients │ │ │
│ │ │ └─────────▲─────────┘ │ │
│ └───────────┼──────────────────────────────────────────────┼────────────────────┘ │
│ │ │ │
│ │ STEP 4: Return completion │ │
│ ▼ │ │
│ ┌──────────────────┐ │ │
│ │ GSM8k Server │───────────────────────────────────────┘ │
│ │ (Scoring) │ │
│ │ │ Scores: "15 × 7 = 105" ✓ reward=1.0 │
│ │ │ "15 × 7 = 100" ✗ reward=0.0 │
│ └──────────────────┘ │
│ │
│ STEP 7: IMMEDIATE UPDATE │
│ ┌─────────────────────────────────────────────────────────────────────────────┐ │
│ │ After optimizer.step(), vLLM's NEXT inference uses the NEW weights! │ │
│ │ NO SYNC NEEDED - it's the same memory! │ │
│ └─────────────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────────────┘
```
**Key Points:**
- ✅ ONE copy of weights in GPU memory
- ✅ 0ms sync latency (same memory!)
- ✅ Memory efficient (~1x model size)
- ⚠️ Requires same GPU for trainer and vLLM
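One practical implication of sharing storage, sketched below (illustrative only, not the trainer's actual update code): weight updates must happen in place. An in-place `optimizer.step()` leaves the storage address untouched, so vLLM reads the new values on its next forward pass; rebinding a parameter to a freshly allocated tensor would silently break the sharing.
```python
import torch

# Hypothetical parameter tensor; in single-copy mode its CUDA storage is the
# same storage vLLM maps through the exported IPC handle.
shared_weight = torch.zeros(4096, 4096, dtype=torch.bfloat16, device="cuda")

# In-place update (what optimizer.step() effectively does): the storage pointer
# is unchanged, so vLLM's next forward pass sees the new values immediately.
shared_weight.add_(torch.randn_like(shared_weight), alpha=-1e-5)

# Rebinding to a new tensor would break sharing: the trainer would own a
# private copy while vLLM kept reading the old storage.
# shared_weight = shared_weight - 1e-5 * torch.randn_like(shared_weight)
```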
---
### LoRA Mode (`--weight-bridge-mode lora_only`)
**The Idea**: Freeze the base model and train only small adapter layers, then hot-swap the adapters into vLLM.
```
┌─────────────────────────────────────────────────────────────────────────────────────┐
│ LORA MODE - COMPLETE DATA FLOW │
│ │
│ STEP 1: GSM8k sends problem │
│ ┌──────────────────┐ │
│ │ GSM8k Server │──── "What is 15 × 7?" ────▶┌──────────────────┐ │
│ │ (Environment) │ │ Atropos API │ │
│ └──────────────────┘ └────────┬─────────┘ │
│ │ │
│ STEP 2: Forward to vLLM ▼ │
│ ┌──────────────────────────────────────────────────────────────────────────────┐ │
│ │ vLLM GPU MEMORY │ │
│ │ ┌────────────────────────────────────────────────────────────────────────┐ │ │
│ │ │ BASE MODEL (frozen, ~6GB) │ │ │
│ │ │ + LORA ADAPTER A (current, ~50MB) │ │ │
│ │ └────────────────────────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ │ STEP 3: Inference with base + adapter A │ │
│ │ ▼ │ │
│ │ ┌────────────────────┐ │ │
│ │ │ vLLM Server │ ──── "15 × 7 = 105" ────▶ │ │
│ │ └────────────────────┘ │ │
│ └──────────────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────────────┐ │
│ │ TRAINER GPU MEMORY (separate!) │ │
│ │ ┌────────────────────────────────────────────────────────────────────────┐ │ │
│ │ │ BASE MODEL (frozen, ~6GB) │ │ │
│ │ │ + LORA ADAPTER B (training, ~50MB) ◀── gradients flow here only! │ │ │
│ │ └────────────────────────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ │ STEP 4-5: Receive rollout, compute loss, update adapter B │ │
│ │ ▼ │ │
│ │ ┌────────────────────┐ │ │
│ │ │ Trainer │ │ │
│ │ │ (grpo.py) │ │ │
│ │ └────────┬───────────┘ │ │
│ └───────────┼──────────────────────────────────────────────────────────────────┘ │
│ │ │
│ │ STEP 6: Every N steps, save adapter B to disk │
│ ▼ │
│ ┌──────────────────┐ STEP 7: POST /lora/load ┌──────────────────┐ │
│ │ adapter_step_N/ │ ─────────────────────────────────▶│ vLLM Server │ │
│ │ (50MB on disk) │ │ Swaps A → B │ │
│ └──────────────────┘ └──────────────────┘ │
│ │
│ STEP 8: Next inference uses NEW adapter B │
│ ┌─────────────────────────────────────────────────────────────────────────────┐ │
│ │ Sync latency: 1-5 seconds (save to disk + HTTP load) │ │
│ │ Memory: 2x base model + adapters │ │
│ └─────────────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────────────┘
```
**Key Points:**
- ✅ Small adapter files (~50MB vs ~28GB)
- ✅ Works on separate GPUs
- ✅ Easy to switch between adapters
- ⚠️ 1-5 second sync latency
- ⚠️ 2x base model memory (trainer + vLLM)
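A rough sketch of steps 6-7, assuming a PEFT-style model object; the endpoint path and JSON payload simply mirror the `POST /lora/load` call in the diagram above and are assumptions, not a confirmed API:
```python
import requests

def swap_adapter(peft_model, step: int, vllm_url: str = "http://localhost:9001") -> None:
    # STEP 6: save only the adapter weights (~50MB), never the frozen base model
    adapter_dir = f"adapter_step_{step}"
    peft_model.save_pretrained(adapter_dir)

    # STEP 7: ask vLLM to hot-swap from adapter A to the freshly saved adapter B
    # (payload field names are hypothetical; adjust to the server's actual schema)
    resp = requests.post(
        f"{vllm_url}/lora/load",
        json={"lora_name": adapter_dir, "lora_path": adapter_dir},
        timeout=30,
    )
    resp.raise_for_status()
```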
---
### Legacy Mode (`--weight-bridge-mode none`)
**The Simple Approach**: Save full checkpoints, restart vLLM to load new weights.
```
┌─────────────────────────────────────────────────────────────────────────────────────┐
│ LEGACY MODE - COMPLETE DATA FLOW │
│ │
│ STEP 1: GSM8k sends problem │
│ ┌──────────────────┐ │
│ │ GSM8k Server │──── "What is 15 × 7?" ────▶┌──────────────────┐ │
│ │ (Environment) │ │ Atropos API │ │
│ └──────────────────┘ └────────┬─────────┘ │
│ │ │
│ STEP 2: Forward to vLLM ▼ │
│ ┌──────────────────────────────────────────────────────────────────────────────┐ │
│ │ vLLM GPU MEMORY │ │
│ │ ┌────────────────────────────────────────────────────────────────────────┐ │ │
│ │ │ FULL MODEL - Version 1 (~28GB) │ │ │
│ │ └────────────────────────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ │ STEP 3: Inference │ │
│ │ ▼ │ │
│ │ ┌────────────────────┐ │ │
│ │ │ vLLM Server │ ──── "15 × 7 = 105" ────▶ │ │
│ │ └────────────────────┘ │ │
│ └──────────────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────────────────────────────┐ │
│ │ TRAINER GPU MEMORY (separate!) │ │
│ │ ┌────────────────────────────────────────────────────────────────────────┐ │ │
│ │ │ FULL MODEL - Version 2 (~28GB + gradients + optimizer) │ │ │
│ │ └────────────────────────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ │ STEP 4-5: Receive rollout, compute loss, update weights │ │
│ │ ▼ │ │
│ │ ┌────────────────────┐ │ │
│ │ │ Trainer │ │ │
│ │ │ (grpo.py) │ │ │
│ │ └────────┬───────────┘ │ │
│ └───────────┼──────────────────────────────────────────────────────────────────┘ │
│ │ │
│ │ STEP 6: Every N steps, save FULL checkpoint to disk (~28GB) │
│ ▼ │
│ ┌──────────────────┐ │
│ │ checkpoint/ │ │
│ │ step_N/ │ (28GB on disk!) │
│ │ - model.safetensors │
│ │ - config.json │
│ └────────┬─────────┘ │
│ │ │
│ │ STEP 7: RESTART vLLM with new checkpoint │
│ │ │
│ │ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ │ 1. Kill vLLM process │ │
│ │ │ 2. Start new vLLM with --model checkpoint/step_N/ │ │
│ │ │ 3. Wait for model to load (~30-60 seconds) │ │
│ │ │ 4. Resume training │ │
│ │ └─────────────────────────────────────────────────────────────────┘ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────────────────────┐ │
│ │ vLLM GPU MEMORY (restarted) │ │
│ │ ┌────────────────────────────────────────────────────────────────────────┐ │ │
│ │ │ FULL MODEL - Version 2 (loaded from checkpoint) │ │ │
│ │ └────────────────────────────────────────────────────────────────────────┘ │ │
│ └──────────────────────────────────────────────────────────────────────────────┘ │
│ │
│ STEP 8: Next inference uses updated model │
│ ┌─────────────────────────────────────────────────────────────────────────────┐ │
│ │ Sync latency: 30-60 seconds (save + restart + reload) │ │
│ │ Memory: 2x full model │ │
│ │ Disk: 28GB per checkpoint │ │
│ └─────────────────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────────────┘
```
**Key Points:**
- ✅ Simple to understand
- ✅ Works on any setup
- ✅ Good for debugging
- ⚠️ 30-60 second sync latency
- ⚠️ 2x GPU memory (trainer + vLLM)
- ⚠️ Large checkpoint files (~28GB each)
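A minimal sketch of steps 6-7 for this mode; the `vllm_api_server.py` entrypoint and flags are taken from elsewhere in this README, but the exact invocation here is illustrative:
```python
import subprocess
import time

def legacy_sync(model, tokenizer, step: int, vllm_proc, port: int = 9001):
    # STEP 6: write a full checkpoint (~28GB) to disk
    ckpt_dir = f"checkpoint/step_{step}"
    model.save_pretrained(ckpt_dir)
    tokenizer.save_pretrained(ckpt_dir)

    # STEP 7: kill the old vLLM process and restart it on the new checkpoint
    vllm_proc.terminate()
    vllm_proc.wait()
    new_proc = subprocess.Popen([
        "python", "vllm_api_server.py",   # assumed entrypoint name
        "--model", ckpt_dir,
        "--tensor-parallel-size", "1",
        "--port", str(port),
        "--dtype", "bfloat16",
    ])
    time.sleep(60)  # crude stand-in for "wait for model to load (~30-60 seconds)"
    return new_proc
```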
---
## Mode Comparison Summary
```
┌──────────────────────────────────────────────────────────────────────────────────┐
│ MODE COMPARISON AT A GLANCE │
├────────────────┬───────────────┬────────────────┬────────────────────────────────┤
│ │ SINGLE-COPY │ LORA │ LEGACY │
├────────────────┼───────────────┼────────────────┼────────────────────────────────┤
│ Sync Latency │ 0 ms ⚡ │ 1-5 sec │ 30-60 sec │
│ GPU Memory │ 1x model │ 2x model │ 2x model │
│ Disk Space │ 28GB/ckpt │ 50MB/adapter │ 28GB/ckpt │
│ Complexity │ Medium │ Medium │ Simple │
│ Same GPU? │ Required ⚠️ │ Optional │ Optional │
│ Best For │ Production │ Experiments │ Debugging │
└────────────────┴───────────────┴────────────────┴────────────────────────────────┘
```
---
## Alternative Mode Commands
### Legacy Mode (Checkpoint + Restart)
For simple setups or debugging. Saves checkpoints and restarts vLLM to load new weights.
@@ -239,7 +478,7 @@ python example_trainer/grpo.py \
--lr 1e-5
```
### Mode 2: LoRA Adapters
### LoRA Mode (Adapter Training)
Trains only adapter weights. Small checkpoints, lower memory.
@@ -274,6 +513,8 @@ python example_trainer/grpo.py \
|--------|---------|-------------|
| `--model-name` | (required) | HuggingFace model ID |
| `--weight-bridge-mode` | `none` | `none`, `shared_vllm`, or `lora_only` |
| `--single-copy` | `false` | Enable TRUE single-copy mode via CUDA IPC |
| `--vllm-config-path` | (auto-detect) | Explicit path to `vllm_bridge_config.json` |
| `--vllm-port` | `9001` | vLLM server port |
| `--training-steps` | `10` | Total optimization steps |
| `--batch-size` | `2` | Micro-batch size |
@@ -288,6 +529,150 @@ python example_trainer/grpo.py \
| `--tensor-parallel-size` | Number of GPUs (use 1 for single-copy) |
| `--port` | Server port (default: 9001) |
| `--dtype` | Model dtype (`bfloat16`, `float16`, `auto`) |
| `--gpu-memory-utilization` | Fraction of GPU memory for KV cache (default: 0.9) |
---
## The vLLM Bridge Config (vllm_bridge_config.json)
The `vllm_bridge_config.json` file is the critical communication mechanism between the vLLM inference server and the GRPO trainer in single-copy mode. Understanding this file is essential for debugging and advanced configurations.
### What It Is
When you start vLLM with `VLLM_ENABLE_SHARED_WEIGHTS=1`, the patched `GPUModelRunner` exports CUDA IPC (Inter-Process Communication) handles for all model tensors. These handles allow another process (the trainer) to access the exact same GPU memory—no copying required.
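On the export side, the idea reduces to sharing each parameter's CUDA storage and base64-encoding the resulting handles into the JSON config. A rough sketch is shown below; the real logic lives in the patched `GPUModelRunner`, and `_share_cuda_()` is a private PyTorch API, so treat this as illustrative:
```python
import base64
import torch

def export_ipc_entry(param: torch.Tensor) -> dict:
    # _share_cuda_() returns the 8-item tuple referenced under "Common Issues" below
    (device, handle, size, offset,
     ref_handle, ref_offset, event_handle, event_sync) = param.untyped_storage()._share_cuda_()
    return {
        "device_index": device,
        "ipc_handle_b64": base64.b64encode(handle).decode(),
        "storage_size": size,
        "storage_offset": offset,
        "ref_counter_handle_b64": base64.b64encode(ref_handle).decode(),
        "ref_counter_offset": ref_offset,
        "event_handle_b64": base64.b64encode(event_handle).decode(),
        "event_sync_required": event_sync,
        "shape": list(param.shape),
        "dtype": str(param.dtype),
    }
```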
### Why It's Important
1. **True Single-Copy Architecture**: Instead of loading the model twice (once for training, once for inference), both processes share the same tensors in GPU memory.
2. **Zero-Latency Weight Updates**: When `optimizer.step()` modifies the weights, vLLM immediately sees the changes—no serialization, no network transfer, no disk I/O.
3. **Memory Efficiency**: For a 7B model (~14GB in bf16), you save ~14GB of GPU memory compared to having two separate copies.
### File Location
The trainer searches for `vllm_bridge_config.json` in this order:
1. **Explicit path** (if `--vllm-config-path` is provided)
2. **`$LOGDIR/vllm_bridge_config.json`** (if `LOGDIR` env var is set)
3. **`./vllm_bridge_config.json`** (current directory)
4. **`/tmp/atropos_bridge/vllm_bridge_config.json`** (default fallback)
**Tip**: To avoid "Config not found" errors, always set `LOGDIR`:
```bash
export LOGDIR=.
```
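For reference, the search order above amounts to something like this (a sketch, not the trainer's actual helper):
```python
import os

def find_bridge_config(explicit_path=None):
    candidates = [
        explicit_path,                                  # 1. --vllm-config-path
        os.path.join(os.environ["LOGDIR"], "vllm_bridge_config.json")
        if "LOGDIR" in os.environ else None,            # 2. $LOGDIR
        "./vllm_bridge_config.json",                    # 3. current directory
        "/tmp/atropos_bridge/vllm_bridge_config.json",  # 4. default fallback
    ]
    for path in candidates:
        if path and os.path.isfile(path):
            return path
    return None
```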
### File Contents
The JSON file contains everything needed to reconstruct tensor references in another process:
```json
{
"model": "Qwen/Qwen2.5-3B-Instruct",
"tp_degree": 1,
"dp_shard_degree": 1,
"param_names": [
"model.embed_tokens.weight",
"model.layers.0.self_attn.qkv_proj.weight",
...
],
"param_mappings": {
"model.embed_tokens.weight": {
"vllm_name": "model.embed_tokens.weight",
"shape": [152064, 2048],
"dtype": "torch.bfloat16",
"device": "cuda:0"
},
...
},
"ipc_handles": {
"model.embed_tokens.weight": {
"device_index": 0,
"ipc_handle_b64": "AmPA0pN...",
"storage_size": 623902720,
"storage_offset": 0,
"ref_counter_handle_b64": "Y2JY...",
"ref_counter_offset": 0,
"event_handle_b64": "wRIs...",
"event_sync_required": true,
"shape": [152064, 2048],
"dtype": "torch.bfloat16"
},
...
},
"shared_weights_enabled": true,
"single_copy_enabled": true,
"num_params": 255
}
```
#### Field Descriptions
| Field | Description |
|-------|-------------|
| `model` | HuggingFace model identifier |
| `tp_degree` | Tensor parallel degree (must be 1 for single-copy) |
| `param_names` | List of all parameter names in the model |
| `param_mappings` | Shape, dtype, and device info for each parameter |
| `ipc_handles` | CUDA IPC handles for reconstructing shared tensors |
| `ipc_handle_b64` | The actual CUDA IPC handle (base64-encoded bytes) |
| `ref_counter_handle_b64` | Reference counter for CUDA memory (base64) |
| `event_handle_b64` | CUDA event handle for synchronization (base64) |
| `storage_size` | Size of the underlying storage in bytes |
### How the Trainer Uses It
1. **Load Config**: Trainer reads `vllm_bridge_config.json`
2. **Create Shell Model**: Uses `AutoModelForCausalLM.from_config()` with meta tensors (no memory allocation)
3. **Attach IPC Handles**: For each parameter, reconstructs the tensor using `torch.UntypedStorage._new_shared_cuda()` with the IPC handles
4. **Verify Shapes**: Ensures trainer's model architecture matches vLLM's sharding
```python
# Simplified version of what happens internally:
import base64
import torch

for name, info in config["ipc_handles"].items():
    dtype = getattr(torch, info["dtype"].replace("torch.", ""))
    # Reconstruct the shared CUDA storage from the decoded IPC handles
    storage = torch.UntypedStorage._new_shared_cuda(
        info["device_index"],
        base64.b64decode(info["ipc_handle_b64"]),
        info["storage_size"], info["storage_offset"],
        base64.b64decode(info["ref_counter_handle_b64"]),
        info["ref_counter_offset"],
        base64.b64decode(info["event_handle_b64"]),
        info["event_sync_required"],
    )
    # View the shared storage with the exported shape and dtype (no copy)
    tensor = torch.empty(0, dtype=dtype, device=f"cuda:{info['device_index']}")
    tensor.set_(storage, 0, info["shape"])
    # Point the model parameter at the shared tensor
    model.get_parameter(name).data = tensor
```
### Specifying the Config Path Explicitly
If auto-detection isn't working (e.g., in complex cluster setups), you can specify the path explicitly:
```bash
# If vLLM writes config to a non-standard location:
python -u example_trainer/grpo.py \
--model-name Qwen/Qwen2.5-3B-Instruct \
--weight-bridge-mode shared_vllm \
--single-copy \
--vllm-config-path /shared/nfs/vllm_bridge_config.json \
--training-steps 50
```
### Common Issues
| Symptom | Cause | Fix |
|---------|-------|-----|
| "Could not find vllm_bridge_config.json" | vLLM didn't export config | Check `VLLM_ENABLE_SHARED_WEIGHTS=1` was set BEFORE starting vLLM |
| Config exists but has empty `ipc_handles` | Patch didn't run | Ensure vLLM is using our custom `vllm_api_server.py` |
| "tuple of 8 items expected" | IPC handle format mismatch | Update to latest code (handles all 8 CUDA IPC tuple components) |
| "size mismatch" errors | Tensor parallel mismatch | Use `tensor-parallel-size 1` for single-copy mode |
---
@@ -298,11 +683,17 @@ python example_trainer/grpo.py \
**A:** vLLM didn't export the IPC handles. Check:
1. `VLLM_ENABLE_SHARED_WEIGHTS=1` was set **before** starting vLLM
2. Look for export messages in vllm.log:
2. `LOGDIR` is set to a valid, writable directory
3. Look for export messages in vllm.log:
```bash
grep "Exported" vllm.log
```
If the file exists but is in a different location, specify it explicitly:
```bash
python grpo.py ... --vllm-config-path /path/to/vllm_bridge_config.json
```
---
### Q: I get "CUDA out of memory" when starting the trainer
@@ -365,8 +756,7 @@ pkill -9 -u $USER -f "vllm|grpo|python|run-api"
| File | Description |
|------|-------------|
| `__init__.py` | Module exports and patch application |
| `patched_gpu_runner.py` | Patches GPUModelRunner to export IPC handles |
| `distributed_utils.py` | Distributed training utilities |
| `patched_gpu_runner.py` | Patches GPUModelRunner to export CUDA IPC handles |
---