mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
changes based on torchtitan
parent 078dd4a333
commit 53b29472b4
7 changed files with 1535 additions and 1977 deletions
@@ -199,7 +199,10 @@ tail -f trainer.log

## Mode 2: Shared vLLM Bridge (In-Place Updates)

-This mode uses an HTTP-based notification system. The trainer notifies vLLM after weight updates.
+This mode supports two sub-modes:

1. **HTTP Notification Mode** (default): Trainer notifies vLLM after weight updates
2. **NCCL Shared Memory Mode** (`--use-shared-memory`): Weights broadcast via NCCL to vLLM's daemon

### Step-by-Step Guide

@@ -211,6 +214,8 @@ sleep 5
```

**Step 2: Start the vLLM Server with Bridge Support**

For HTTP notification mode:
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
@@ -227,7 +232,27 @@ sleep 90
curl -s http://localhost:9001/health && echo "vLLM ready!"
```

For NCCL shared memory mode (requires patched vLLM):
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
export NUM_INFERENCE_NODES=0
export VLLM_ENABLE_SHARED_WEIGHTS=1 # Enable shared memory patches
mkdir -p $LOGDIR

python example_trainer/vllm_api_server.py \
    --model Qwen/Qwen2.5-3B-Instruct \
    --port 9001 \
    --gpu-memory-utilization 0.30 \
    > vllm.log 2>&1 &
sleep 90

curl -s http://localhost:9001/health && echo "vLLM ready!"
```
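The fixed `sleep 90` followed by a single `curl` to `/health` can also be written as a polling loop. The following is a minimal sketch, assuming only that `http://localhost:9001/health` answers with HTTP 200 once the server is up (as the `curl` check above implies). The helper name `wait_for_health` and its timeout parameters are hypothetical and not part of the repository.

```python
import time
import urllib.error
import urllib.request


def wait_for_health(url: str, timeout_s: float = 120.0, interval_s: float = 2.0) -> bool:
    """Poll `url` until it returns HTTP 200 or `timeout_s` seconds have passed.

    Hypothetical helper: it only mirrors the `curl -s .../health` check.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=5) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            # Server not reachable yet, or it answered with an error status.
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


# Example call (needs the vLLM server from the step above to be running):
#     if wait_for_health("http://localhost:9001/health"):
#         print("vLLM ready!")
```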

**Step 3: Start the GRPO Trainer in Shared Mode**

For HTTP notification mode:
```bash
cd atropos
export LOGDIR=/tmp/atropos_bridge
@@ -249,6 +274,22 @@ python example_trainer/grpo.py \
sleep 10
```

For NCCL shared memory mode (add `--use-shared-memory`):
```bash
python example_trainer/grpo.py \
    --model-name Qwen/Qwen2.5-3B-Instruct \
    --weight-bridge-mode shared_vllm \
    --use-shared-memory \
    --num-inference-nodes 0 \
    --training-steps 100 \
    --vllm-port 9001 \
    --batch-size 2 \
    --gradient-accumulation-steps 16 \
    --lr 1e-5 \
    --save-path checkpoints_shared \
    > trainer.log 2>&1 &
```

**Step 4: Start the GSM8k Environment**
```bash
cd atropos
@@ -279,15 +320,26 @@ python environments/gsm8k_server.py serve --slurm False --openai.model_name Qwen
tail -f trainer.log
```

-### What Happens (Local Mode - num_inference_nodes=0)
+### What Happens (HTTP Notification Mode)

1. vLLM server starts on port 9001
-2. Trainer initializes bridge in LOCAL MODE (HTTP-based, no NCCL)
+2. Trainer initializes bridge in LOCAL MODE (HTTP-based)
3. Trainer loads its own model copy and trains normally
4. After each `optimizer.step()`:
   - `bridge.notify_update()` sends HTTP POST to vLLM
   - Periodic checkpoint saves sync weights to disk
-5. Much simpler than distributed mode!
+5. Simple setup, suitable for debugging
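
Step 4 of the list above (an HTTP POST to vLLM after each optimizer step) can be illustrated with the standard library alone. This is a sketch under stated assumptions, not the repository's `bridge.notify_update()`: the function below, the `/notify_update` path, and the `{"step": ...}` payload are all hypothetical placeholders, since the diff does not show the real endpoint or message format.

```python
import json
import urllib.request


def notify_update(base_url: str, step: int, timeout_s: float = 5.0) -> int:
    """POST a small JSON notice that weights changed and return the HTTP status.

    The "/notify_update" path and the payload fields are hypothetical.
    """
    payload = json.dumps({"step": step}).encode("utf-8")
    req = urllib.request.Request(
        base_url.rstrip("/") + "/notify_update",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout_s) as resp:
        return resp.status


# In a training loop the call would sit right after the optimizer step:
#     optimizer.step()
#     notify_update("http://localhost:9001", step)
```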

### What Happens (NCCL Shared Memory Mode)

When using `--use-shared-memory` with `VLLM_ENABLE_SHARED_WEIGHTS=1`:

1. vLLM patches GPUModelRunner to call `share_memory_()` on model weights
2. vLLM spawns a daemon process that joins NCCL groups with the trainer
3. Trainer broadcasts weights via NCCL after each optimizer step
4. Daemon copies weights into shared tensors → vLLM uses them immediately

This provides true shared memory without separate model copies!
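
The key idea in step 4 is that an in-place write into shared storage is visible to every handle on that storage, with no reload. The sketch below shows that property with Python's standard `multiprocessing.shared_memory` module. It is only an analogy: it uses neither PyTorch's `share_memory_()` nor NCCL, and the function name and the four example values are made up for illustration.

```python
import struct
from multiprocessing import shared_memory

NUM_WEIGHTS = 4  # illustrative size; each value is stored as a float64


def in_place_update_demo() -> list:
    """Write through one handle and read the result through another."""
    fmt = f"{NUM_WEIGHTS}d"
    # "Inference side": owns the shared block and starts with zeros.
    owner = shared_memory.SharedMemory(create=True, size=struct.calcsize(fmt))
    try:
        struct.pack_into(fmt, owner.buf, 0, *([0.0] * NUM_WEIGHTS))

        # "Daemon side": attaches to the same block by name and
        # overwrites the values in place.
        writer = shared_memory.SharedMemory(name=owner.name)
        try:
            struct.pack_into(fmt, writer.buf, 0, 0.1, 0.2, 0.3, 0.4)
        finally:
            writer.close()

        # The owner sees the new values without copying or reloading.
        return list(struct.unpack_from(fmt, owner.buf, 0))
    finally:
        owner.close()
        owner.unlink()


# Example call:
#     print(in_place_update_demo())
```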

### What Happens (Distributed Mode - num_inference_nodes>0)